From dd6276e706cbdb6bf841f610b1827054b07a89c6 Mon Sep 17 00:00:00 2001 From: Nicolas Mowen Date: Thu, 10 Oct 2024 15:37:43 -0600 Subject: [PATCH] Embeddings fixes (#14269) * Add debugging logs for more info * Improve timeout handling * Fix event cleanup * Handle zmq error and empty data * Don't run download * Remove unneeded embeddings creations * Update timouts * Init models immediately * Fix order of init * Cleanup --- frigate/app.py | 4 +- frigate/comms/dispatcher.py | 7 ++-- frigate/comms/embeddings_updater.py | 9 +++-- frigate/comms/event_metadata_updater.py | 2 +- frigate/db/sqlitevecq.py | 23 +++++++++++ frigate/embeddings/__init__.py | 42 ++++++++++--------- frigate/embeddings/embeddings.py | 36 ++++------------- frigate/embeddings/functions/onnx.py | 42 ++++++++++++++----- frigate/embeddings/maintainer.py | 54 ++++++++++++++----------- frigate/events/cleanup.py | 13 ++---- frigate/util/downloader.py | 33 ++++++++++----- 11 files changed, 154 insertions(+), 111 deletions(-) diff --git a/frigate/app.py b/frigate/app.py index 1fcf91551..0cf76699c 100644 --- a/frigate/app.py +++ b/frigate/app.py @@ -581,12 +581,12 @@ class FrigateApp: self.init_recording_manager() self.init_review_segment_manager() self.init_go2rtc() + self.start_detectors() + self.init_embeddings_manager() self.bind_database() self.check_db_data_migrations() self.init_inter_process_communicator() self.init_dispatcher() - self.start_detectors() - self.init_embeddings_manager() self.init_embeddings_client() self.start_video_output_processor() self.start_ptz_autotracker() diff --git a/frigate/comms/dispatcher.py b/frigate/comms/dispatcher.py index c1a9f7e86..12dfe2731 100644 --- a/frigate/comms/dispatcher.py +++ b/frigate/comms/dispatcher.py @@ -64,6 +64,9 @@ class Dispatcher: self.onvif = onvif self.ptz_metrics = ptz_metrics self.comms = communicators + self.camera_activity = {} + self.model_state = {} + self.embeddings_reindex = {} self._camera_settings_handlers: dict[str, Callable] = { "audio": self._on_audio_command, @@ -85,10 +88,6 @@ class Dispatcher: for comm in self.comms: comm.subscribe(self._receive) - self.camera_activity = {} - self.model_state = {} - self.embeddings_reindex = {} - def _receive(self, topic: str, payload: str) -> Optional[Any]: """Handle receiving of payload from communicators.""" diff --git a/frigate/comms/embeddings_updater.py b/frigate/comms/embeddings_updater.py index 8a7617630..9a13525f8 100644 --- a/frigate/comms/embeddings_updater.py +++ b/frigate/comms/embeddings_updater.py @@ -22,7 +22,7 @@ class EmbeddingsResponder: def check_for_request(self, process: Callable) -> None: while True: # load all messages that are queued - has_message, _, _ = zmq.select([self.socket], [], [], 1) + has_message, _, _ = zmq.select([self.socket], [], [], 0.1) if not has_message: break @@ -54,8 +54,11 @@ class EmbeddingsRequestor: def send_data(self, topic: str, data: any) -> str: """Sends data and then waits for reply.""" - self.socket.send_json((topic, data)) - return self.socket.recv_json() + try: + self.socket.send_json((topic, data)) + return self.socket.recv_json() + except zmq.ZMQError: + return "" def stop(self) -> None: self.socket.close() diff --git a/frigate/comms/event_metadata_updater.py b/frigate/comms/event_metadata_updater.py index aeede6d8e..87e1889ce 100644 --- a/frigate/comms/event_metadata_updater.py +++ b/frigate/comms/event_metadata_updater.py @@ -39,7 +39,7 @@ class EventMetadataSubscriber(Subscriber): super().__init__(topic) def check_for_update( - self, timeout: float = None + self, 
timeout: float = 1 ) -> Optional[tuple[EventMetadataTypeEnum, str, RegenerateDescriptionEnum]]: return super().check_for_update(timeout) diff --git a/frigate/db/sqlitevecq.py b/frigate/db/sqlitevecq.py index 858070c38..398adbd2d 100644 --- a/frigate/db/sqlitevecq.py +++ b/frigate/db/sqlitevecq.py @@ -28,3 +28,26 @@ class SqliteVecQueueDatabase(SqliteQueueDatabase): def delete_embeddings_description(self, event_ids: list[str]) -> None: ids = ",".join(["?" for _ in event_ids]) self.execute_sql(f"DELETE FROM vec_descriptions WHERE id IN ({ids})", event_ids) + + def drop_embeddings_tables(self) -> None: + self.execute_sql(""" + DROP TABLE vec_descriptions; + """) + self.execute_sql(""" + DROP TABLE vec_thumbnails; + """) + + def create_embeddings_tables(self) -> None: + """Create vec0 virtual table for embeddings""" + self.execute_sql(""" + CREATE VIRTUAL TABLE IF NOT EXISTS vec_thumbnails USING vec0( + id TEXT PRIMARY KEY, + thumbnail_embedding FLOAT[768] + ); + """) + self.execute_sql(""" + CREATE VIRTUAL TABLE IF NOT EXISTS vec_descriptions USING vec0( + id TEXT PRIMARY KEY, + description_embedding FLOAT[768] + ); + """) diff --git a/frigate/embeddings/__init__.py b/frigate/embeddings/__init__.py index e7dcf1053..7f2e1a10c 100644 --- a/frigate/embeddings/__init__.py +++ b/frigate/embeddings/__init__.py @@ -19,7 +19,6 @@ from frigate.models import Event from frigate.util.builtin import serialize from frigate.util.services import listen -from .embeddings import Embeddings from .maintainer import EmbeddingMaintainer from .util import ZScoreNormalization @@ -57,12 +56,6 @@ def manage_embeddings(config: FrigateConfig) -> None: models = [Event] db.bind(models) - embeddings = Embeddings(config.semantic_search, db) - - # Check if we need to re-index events - if config.semantic_search.reindex: - embeddings.reindex() - maintainer = EmbeddingMaintainer( db, config, @@ -114,19 +107,25 @@ class EmbeddingsContext: query_embedding = row[0] else: # If no embedding found, generate it and return it - query_embedding = serialize( - self.requestor.send_data( - EmbeddingsRequestEnum.embed_thumbnail.value, - {"id": query.id, "thumbnail": query.thumbnail}, - ) + data = self.requestor.send_data( + EmbeddingsRequestEnum.embed_thumbnail.value, + {"id": str(query.id), "thumbnail": str(query.thumbnail)}, ) + + if not data: + return [] + + query_embedding = serialize(data) else: - query_embedding = serialize( - self.requestor.send_data( - EmbeddingsRequestEnum.generate_search.value, query - ) + data = self.requestor.send_data( + EmbeddingsRequestEnum.generate_search.value, query ) + if not data: + return [] + + query_embedding = serialize(data) + sql_query = """ SELECT id, @@ -155,12 +154,15 @@ class EmbeddingsContext: def search_description( self, query_text: str, event_ids: list[str] = None ) -> list[tuple[str, float]]: - query_embedding = serialize( - self.requestor.send_data( - EmbeddingsRequestEnum.generate_search.value, query_text - ) + data = self.requestor.send_data( + EmbeddingsRequestEnum.generate_search.value, query_text ) + if not data: + return [] + + query_embedding = serialize(data) + # Prepare the base SQL query sql_query = """ SELECT diff --git a/frigate/embeddings/embeddings.py b/frigate/embeddings/embeddings.py index dda4d95fd..e9d8ab833 100644 --- a/frigate/embeddings/embeddings.py +++ b/frigate/embeddings/embeddings.py @@ -63,7 +63,7 @@ class Embeddings: self.requestor = InterProcessRequestor() # Create tables if they don't exist - self._create_tables() + self.db.create_embeddings_tables() models = 
[ "jinaai/jina-clip-v1-text_model_fp16.onnx", @@ -96,6 +96,7 @@ class Embeddings: }, embedding_function=jina_text_embedding_function, model_type="text", + requestor=self.requestor, device="CPU", ) @@ -108,34 +109,10 @@ class Embeddings: }, embedding_function=jina_vision_embedding_function, model_type="vision", + requestor=self.requestor, device=self.config.device, ) - def _create_tables(self): - # Create vec0 virtual table for thumbnail embeddings - self.db.execute_sql(""" - CREATE VIRTUAL TABLE IF NOT EXISTS vec_thumbnails USING vec0( - id TEXT PRIMARY KEY, - thumbnail_embedding FLOAT[768] - ); - """) - - # Create vec0 virtual table for description embeddings - self.db.execute_sql(""" - CREATE VIRTUAL TABLE IF NOT EXISTS vec_descriptions USING vec0( - id TEXT PRIMARY KEY, - description_embedding FLOAT[768] - ); - """) - - def _drop_tables(self): - self.db.execute_sql(""" - DROP TABLE vec_descriptions; - """) - self.db.execute_sql(""" - DROP TABLE vec_thumbnails; - """) - def upsert_thumbnail(self, event_id: str, thumbnail: bytes): # Convert thumbnail bytes to PIL Image image = Image.open(io.BytesIO(thumbnail)).convert("RGB") @@ -153,7 +130,6 @@ class Embeddings: def upsert_description(self, event_id: str, description: str): embedding = self.text_embedding([description])[0] - self.db.execute_sql( """ INSERT OR REPLACE INTO vec_descriptions(id, description_embedding) @@ -167,8 +143,10 @@ class Embeddings: def reindex(self) -> None: logger.info("Indexing tracked object embeddings...") - self._drop_tables() - self._create_tables() + self.db.drop_embeddings_tables() + logger.debug("Dropped embeddings tables.") + self.db.create_embeddings_tables() + logger.debug("Created embeddings tables.") st = time.time() totals = { diff --git a/frigate/embeddings/functions/onnx.py b/frigate/embeddings/functions/onnx.py index 08901b6a2..34a81528a 100644 --- a/frigate/embeddings/functions/onnx.py +++ b/frigate/embeddings/functions/onnx.py @@ -15,6 +15,7 @@ from PIL import Image from transformers import AutoFeatureExtractor, AutoTokenizer from transformers.utils.logging import disable_progress_bar +from frigate.comms.inter_process import InterProcessRequestor from frigate.const import MODEL_CACHE_DIR, UPDATE_MODEL_STATE from frigate.types import ModelStatusTypesEnum from frigate.util.downloader import ModelDownloader @@ -41,12 +42,14 @@ class GenericONNXEmbedding: download_urls: Dict[str, str], embedding_function: Callable[[List[np.ndarray]], np.ndarray], model_type: str, + requestor: InterProcessRequestor, tokenizer_file: Optional[str] = None, device: str = "AUTO", ): self.model_name = model_name self.model_file = model_file self.tokenizer_file = tokenizer_file + self.requestor = requestor self.download_urls = download_urls self.embedding_function = embedding_function self.model_type = model_type # 'text' or 'vision' @@ -58,15 +61,32 @@ class GenericONNXEmbedding: self.tokenizer = None self.feature_extractor = None self.session = None - - self.downloader = ModelDownloader( - model_name=self.model_name, - download_path=self.download_path, - file_names=list(self.download_urls.keys()) - + ([self.tokenizer_file] if self.tokenizer_file else []), - download_func=self._download_model, + files_names = list(self.download_urls.keys()) + ( + [self.tokenizer_file] if self.tokenizer_file else [] ) - self.downloader.ensure_model_files() + + if not all( + os.path.exists(os.path.join(self.download_path, n)) for n in files_names + ): + logger.debug(f"starting model download for {self.model_name}") + self.downloader = 
ModelDownloader( + model_name=self.model_name, + download_path=self.download_path, + file_names=files_names, + requestor=self.requestor, + download_func=self._download_model, + ) + self.downloader.ensure_model_files() + else: + self.downloader = None + ModelDownloader.mark_files_state( + self.requestor, + self.model_name, + files_names, + ModelStatusTypesEnum.downloaded, + ) + self._load_model_and_tokenizer() + logger.debug(f"models are already downloaded for {self.model_name}") def _download_model(self, path: str): try: @@ -102,7 +122,8 @@ class GenericONNXEmbedding: def _load_model_and_tokenizer(self): if self.session is None: - self.downloader.wait_for_download() + if self.downloader: + self.downloader.wait_for_download() if self.model_type == "text": self.tokenizer = self._load_tokenizer() else: @@ -125,13 +146,12 @@ class GenericONNXEmbedding: f"{MODEL_CACHE_DIR}/{self.model_name}", ) - def _load_model(self, path: str): + def _load_model(self, path: str) -> Optional[ort.InferenceSession]: if os.path.exists(path): return ort.InferenceSession( path, providers=self.providers, provider_options=self.provider_options ) else: - logger.warning(f"{self.model_name} model file {path} not found.") return None def _process_image(self, image): diff --git a/frigate/embeddings/maintainer.py b/frigate/embeddings/maintainer.py index 68c3e3686..238efcfdf 100644 --- a/frigate/embeddings/maintainer.py +++ b/frigate/embeddings/maintainer.py @@ -41,10 +41,14 @@ class EmbeddingMaintainer(threading.Thread): config: FrigateConfig, stop_event: MpEvent, ) -> None: - threading.Thread.__init__(self) - self.name = "embeddings_maintainer" + super().__init__(name="embeddings_maintainer") self.config = config self.embeddings = Embeddings(config.semantic_search, db) + + # Check if we need to re-index events + if config.semantic_search.reindex: + self.embeddings.reindex() + self.event_subscriber = EventUpdateSubscriber() self.event_end_subscriber = EventEndSubscriber() self.event_metadata_subscriber = EventMetadataSubscriber( @@ -76,26 +80,33 @@ class EmbeddingMaintainer(threading.Thread): def _process_requests(self) -> None: """Process embeddings requests""" - def handle_request(topic: str, data: str) -> str: - if topic == EmbeddingsRequestEnum.embed_description.value: - return serialize( - self.embeddings.upsert_description(data["id"], data["description"]), - pack=False, - ) - elif topic == EmbeddingsRequestEnum.embed_thumbnail.value: - thumbnail = base64.b64decode(data["thumbnail"]) - return serialize( - self.embeddings.upsert_thumbnail(data["id"], thumbnail), - pack=False, - ) - elif topic == EmbeddingsRequestEnum.generate_search.value: - return serialize(self.embeddings.text_embedding([data])[0], pack=False) + def _handle_request(topic: str, data: str) -> str: + try: + if topic == EmbeddingsRequestEnum.embed_description.value: + return serialize( + self.embeddings.upsert_description( + data["id"], data["description"] + ), + pack=False, + ) + elif topic == EmbeddingsRequestEnum.embed_thumbnail.value: + thumbnail = base64.b64decode(data["thumbnail"]) + return serialize( + self.embeddings.upsert_thumbnail(data["id"], thumbnail), + pack=False, + ) + elif topic == EmbeddingsRequestEnum.generate_search.value: + return serialize( + self.embeddings.text_embedding([data])[0], pack=False + ) + except Exception as e: + logger.error(f"Unable to handle embeddings request {e}") - self.embeddings_responder.check_for_request(handle_request) + self.embeddings_responder.check_for_request(_handle_request) def _process_updates(self) 
-> None: """Process event updates""" - update = self.event_subscriber.check_for_update() + update = self.event_subscriber.check_for_update(timeout=0.1) if update is None: return @@ -124,7 +135,7 @@ class EmbeddingMaintainer(threading.Thread): def _process_finalized(self) -> None: """Process the end of an event.""" while True: - ended = self.event_end_subscriber.check_for_update() + ended = self.event_end_subscriber.check_for_update(timeout=0.1) if ended == None: break @@ -161,9 +172,6 @@ class EmbeddingMaintainer(threading.Thread): or set(event.zones) & set(camera_config.genai.required_zones) ) ): - logger.debug( - f"Description generation for {event}, has_snapshot: {event.has_snapshot}" - ) if event.has_snapshot and camera_config.genai.use_snapshot: with open( os.path.join(CLIPS_DIR, f"{event.camera}-{event.id}.jpg"), @@ -217,7 +225,7 @@ class EmbeddingMaintainer(threading.Thread): def _process_event_metadata(self): # Check for regenerate description requests (topic, event_id, source) = self.event_metadata_subscriber.check_for_update( - timeout=1 + timeout=0.1 ) if topic is None: diff --git a/frigate/events/cleanup.py b/frigate/events/cleanup.py index 828b295b4..8fabf2b21 100644 --- a/frigate/events/cleanup.py +++ b/frigate/events/cleanup.py @@ -8,11 +8,9 @@ from enum import Enum from multiprocessing.synchronize import Event as MpEvent from pathlib import Path -from playhouse.sqliteq import SqliteQueueDatabase - from frigate.config import FrigateConfig from frigate.const import CLIPS_DIR -from frigate.embeddings.embeddings import Embeddings +from frigate.db.sqlitevecq import SqliteVecQueueDatabase from frigate.models import Event, Timeline logger = logging.getLogger(__name__) @@ -25,7 +23,7 @@ class EventCleanupType(str, Enum): class EventCleanup(threading.Thread): def __init__( - self, config: FrigateConfig, stop_event: MpEvent, db: SqliteQueueDatabase + self, config: FrigateConfig, stop_event: MpEvent, db: SqliteVecQueueDatabase ): super().__init__(name="event_cleanup") self.config = config @@ -35,9 +33,6 @@ class EventCleanup(threading.Thread): self.removed_camera_labels: list[str] = None self.camera_labels: dict[str, dict[str, any]] = {} - if self.config.semantic_search.enabled: - self.embeddings = Embeddings(self.config.semantic_search, self.db) - def get_removed_camera_labels(self) -> list[Event]: """Get a list of distinct labels for removed cameras.""" if self.removed_camera_labels is None: @@ -234,8 +229,8 @@ class EventCleanup(threading.Thread): Event.delete().where(Event.id << chunk).execute() if self.config.semantic_search.enabled: - self.embeddings.delete_description(chunk) - self.embeddings.delete_thumbnail(chunk) + self.db.delete_embeddings_description(chunk) + self.db.delete_embeddings_thumbnail(chunk) logger.debug(f"Deleted {len(events_to_delete)} embeddings") logger.info("Exiting event cleanup...") diff --git a/frigate/util/downloader.py b/frigate/util/downloader.py index 642dc7c8f..ce5030566 100644 --- a/frigate/util/downloader.py +++ b/frigate/util/downloader.py @@ -44,6 +44,7 @@ class ModelDownloader: download_path: str, file_names: List[str], download_func: Callable[[str], None], + requestor: InterProcessRequestor, silent: bool = False, ): self.model_name = model_name @@ -51,19 +52,17 @@ class ModelDownloader: self.file_names = file_names self.download_func = download_func self.silent = silent - self.requestor = InterProcessRequestor() + self.requestor = requestor self.download_thread = None self.download_complete = threading.Event() def ensure_model_files(self): - 
for file in self.file_names: - self.requestor.send_data( - UPDATE_MODEL_STATE, - { - "model": f"{self.model_name}-{file}", - "state": ModelStatusTypesEnum.downloading, - }, - ) + self.mark_files_state( + self.requestor, + self.model_name, + self.file_names, + ModelStatusTypesEnum.downloading, + ) self.download_thread = threading.Thread( target=self._download_models, name=f"_download_model_{self.model_name}", @@ -119,5 +118,21 @@ class ModelDownloader: if not silent: logger.info(f"Downloading complete: {url}") + @staticmethod + def mark_files_state( + requestor: InterProcessRequestor, + model_name: str, + files: list[str], + state: ModelStatusTypesEnum, + ) -> None: + for file_name in files: + requestor.send_data( + UPDATE_MODEL_STATE, + { + "model": f"{model_name}-{file_name}", + "state": state, + }, + ) + def wait_for_download(self): self.download_complete.wait()
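
The comms hunks above shorten the responder's `zmq.select` poll to 0.1 s and make the requestor catch `zmq.ZMQError` and hand back an empty reply, which the callers in `frigate/embeddings/__init__.py` then treat as "no result". A minimal sketch of that polling and error-handling pattern follows; the handler and message shapes are illustrative only, not Frigate's actual sockets or topics.

# Minimal sketch of the request/reply pattern applied in
# frigate/comms/embeddings_updater.py: the responder polls with a short
# timeout so it never blocks its loop, and the requestor returns an empty
# reply on ZMQError instead of raising. Handler and payloads are illustrative.
import zmq


def drain_requests(socket: zmq.Socket, handler) -> None:
    """Responder side: service every queued request, but wait at most
    0.1 s for a new one before returning to the caller's main loop."""
    while True:
        readable, _, _ = zmq.select([socket], [], [], 0.1)
        if not readable:
            break
        topic, data = socket.recv_json()
        socket.send_json(handler(topic, data))


def send_request(socket: zmq.Socket, topic: str, data) -> str:
    """Requestor side: an empty string signals failure to the caller,
    which can then bail out (see the `if not data: return []` checks)."""
    try:
        socket.send_json((topic, data))
        return socket.recv_json()
    except zmq.ZMQError:
        return ""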
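
The downloader and ONNX-wrapper hunks implement a "only download when something is missing" flow: if every model file already exists, the per-file state is reported as downloaded and the model is loaded immediately; otherwise the files are fetched on a background thread and a `threading.Event` gates consumers until they arrive. A rough sketch of that flow under assumed names — `LazyModelFiles` and `publish_state` are placeholders, not Frigate APIs; the real code uses `ModelDownloader` plus a shared `InterProcessRequestor`.

# Rough sketch of the flow from frigate/util/downloader.py and
# frigate/embeddings/functions/onnx.py. LazyModelFiles and publish_state
# are placeholder names for illustration only.
import os
import threading
from typing import Callable, List


class LazyModelFiles:
    def __init__(
        self,
        download_path: str,
        file_names: List[str],
        download_one: Callable[[str], None],
        publish_state: Callable[[str, str], None],
    ) -> None:
        self.download_path = download_path
        self.file_names = file_names
        self.download_one = download_one
        self.publish_state = publish_state
        self.ready = threading.Event()

    def ensure(self) -> None:
        missing = [
            name
            for name in self.file_names
            if not os.path.exists(os.path.join(self.download_path, name))
        ]

        if not missing:
            # Everything is already on disk: publish "downloaded" for each
            # file and let the caller load the model immediately.
            for name in self.file_names:
                self.publish_state(name, "downloaded")
            self.ready.set()
            return

        # Publish "downloading" up front, then fetch in a background thread
        # so model initialization does not block process startup.
        for name in self.file_names:
            self.publish_state(name, "downloading")

        def _worker() -> None:
            for name in missing:
                self.download_one(name)
            self.ready.set()

        threading.Thread(target=_worker, daemon=True).start()

    def wait(self) -> None:
        """Block until every file is present (mirrors wait_for_download)."""
        self.ready.wait()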