Embeddings fixes (#14269)

* Add debugging logs for more info

* Improve timeout handling

* Fix event cleanup

* Handle zmq error and empty data

* Don't run download

* Remove unneeded embeddings creations

* Update timeouts

* Init models immediately

* Fix order of init

* Cleanup
This commit is contained in:
Nicolas Mowen 2024-10-10 15:37:43 -06:00 committed by GitHub
parent f67ec241d4
commit dd6276e706
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
11 changed files with 154 additions and 111 deletions

View File

@ -581,12 +581,12 @@ class FrigateApp:
self.init_recording_manager() self.init_recording_manager()
self.init_review_segment_manager() self.init_review_segment_manager()
self.init_go2rtc() self.init_go2rtc()
self.start_detectors()
self.init_embeddings_manager()
self.bind_database() self.bind_database()
self.check_db_data_migrations() self.check_db_data_migrations()
self.init_inter_process_communicator() self.init_inter_process_communicator()
self.init_dispatcher() self.init_dispatcher()
self.start_detectors()
self.init_embeddings_manager()
self.init_embeddings_client() self.init_embeddings_client()
self.start_video_output_processor() self.start_video_output_processor()
self.start_ptz_autotracker() self.start_ptz_autotracker()

View File

@ -64,6 +64,9 @@ class Dispatcher:
self.onvif = onvif self.onvif = onvif
self.ptz_metrics = ptz_metrics self.ptz_metrics = ptz_metrics
self.comms = communicators self.comms = communicators
self.camera_activity = {}
self.model_state = {}
self.embeddings_reindex = {}
self._camera_settings_handlers: dict[str, Callable] = { self._camera_settings_handlers: dict[str, Callable] = {
"audio": self._on_audio_command, "audio": self._on_audio_command,
@ -85,10 +88,6 @@ class Dispatcher:
for comm in self.comms: for comm in self.comms:
comm.subscribe(self._receive) comm.subscribe(self._receive)
self.camera_activity = {}
self.model_state = {}
self.embeddings_reindex = {}
def _receive(self, topic: str, payload: str) -> Optional[Any]: def _receive(self, topic: str, payload: str) -> Optional[Any]:
"""Handle receiving of payload from communicators.""" """Handle receiving of payload from communicators."""

View File

@ -22,7 +22,7 @@ class EmbeddingsResponder:
def check_for_request(self, process: Callable) -> None: def check_for_request(self, process: Callable) -> None:
while True: # load all messages that are queued while True: # load all messages that are queued
has_message, _, _ = zmq.select([self.socket], [], [], 1) has_message, _, _ = zmq.select([self.socket], [], [], 0.1)
if not has_message: if not has_message:
break break
@ -54,8 +54,11 @@ class EmbeddingsRequestor:
def send_data(self, topic: str, data: any) -> str: def send_data(self, topic: str, data: any) -> str:
"""Sends data and then waits for reply.""" """Sends data and then waits for reply."""
self.socket.send_json((topic, data)) try:
return self.socket.recv_json() self.socket.send_json((topic, data))
return self.socket.recv_json()
except zmq.ZMQError:
return ""
def stop(self) -> None: def stop(self) -> None:
self.socket.close() self.socket.close()

View File

@ -39,7 +39,7 @@ class EventMetadataSubscriber(Subscriber):
super().__init__(topic) super().__init__(topic)
def check_for_update( def check_for_update(
self, timeout: float = None self, timeout: float = 1
) -> Optional[tuple[EventMetadataTypeEnum, str, RegenerateDescriptionEnum]]: ) -> Optional[tuple[EventMetadataTypeEnum, str, RegenerateDescriptionEnum]]:
return super().check_for_update(timeout) return super().check_for_update(timeout)

View File

@ -28,3 +28,26 @@ class SqliteVecQueueDatabase(SqliteQueueDatabase):
def delete_embeddings_description(self, event_ids: list[str]) -> None: def delete_embeddings_description(self, event_ids: list[str]) -> None:
ids = ",".join(["?" for _ in event_ids]) ids = ",".join(["?" for _ in event_ids])
self.execute_sql(f"DELETE FROM vec_descriptions WHERE id IN ({ids})", event_ids) self.execute_sql(f"DELETE FROM vec_descriptions WHERE id IN ({ids})", event_ids)
def drop_embeddings_tables(self) -> None:
self.execute_sql("""
DROP TABLE vec_descriptions;
""")
self.execute_sql("""
DROP TABLE vec_thumbnails;
""")
def create_embeddings_tables(self) -> None:
"""Create vec0 virtual table for embeddings"""
self.execute_sql("""
CREATE VIRTUAL TABLE IF NOT EXISTS vec_thumbnails USING vec0(
id TEXT PRIMARY KEY,
thumbnail_embedding FLOAT[768]
);
""")
self.execute_sql("""
CREATE VIRTUAL TABLE IF NOT EXISTS vec_descriptions USING vec0(
id TEXT PRIMARY KEY,
description_embedding FLOAT[768]
);
""")

View File

@ -19,7 +19,6 @@ from frigate.models import Event
from frigate.util.builtin import serialize from frigate.util.builtin import serialize
from frigate.util.services import listen from frigate.util.services import listen
from .embeddings import Embeddings
from .maintainer import EmbeddingMaintainer from .maintainer import EmbeddingMaintainer
from .util import ZScoreNormalization from .util import ZScoreNormalization
@ -57,12 +56,6 @@ def manage_embeddings(config: FrigateConfig) -> None:
models = [Event] models = [Event]
db.bind(models) db.bind(models)
embeddings = Embeddings(config.semantic_search, db)
# Check if we need to re-index events
if config.semantic_search.reindex:
embeddings.reindex()
maintainer = EmbeddingMaintainer( maintainer = EmbeddingMaintainer(
db, db,
config, config,
@ -114,19 +107,25 @@ class EmbeddingsContext:
query_embedding = row[0] query_embedding = row[0]
else: else:
# If no embedding found, generate it and return it # If no embedding found, generate it and return it
query_embedding = serialize( data = self.requestor.send_data(
self.requestor.send_data( EmbeddingsRequestEnum.embed_thumbnail.value,
EmbeddingsRequestEnum.embed_thumbnail.value, {"id": str(query.id), "thumbnail": str(query.thumbnail)},
{"id": query.id, "thumbnail": query.thumbnail},
)
) )
if not data:
return []
query_embedding = serialize(data)
else: else:
query_embedding = serialize( data = self.requestor.send_data(
self.requestor.send_data( EmbeddingsRequestEnum.generate_search.value, query
EmbeddingsRequestEnum.generate_search.value, query
)
) )
if not data:
return []
query_embedding = serialize(data)
sql_query = """ sql_query = """
SELECT SELECT
id, id,
@ -155,12 +154,15 @@ class EmbeddingsContext:
def search_description( def search_description(
self, query_text: str, event_ids: list[str] = None self, query_text: str, event_ids: list[str] = None
) -> list[tuple[str, float]]: ) -> list[tuple[str, float]]:
query_embedding = serialize( data = self.requestor.send_data(
self.requestor.send_data( EmbeddingsRequestEnum.generate_search.value, query_text
EmbeddingsRequestEnum.generate_search.value, query_text
)
) )
if not data:
return []
query_embedding = serialize(data)
# Prepare the base SQL query # Prepare the base SQL query
sql_query = """ sql_query = """
SELECT SELECT

View File

@ -63,7 +63,7 @@ class Embeddings:
self.requestor = InterProcessRequestor() self.requestor = InterProcessRequestor()
# Create tables if they don't exist # Create tables if they don't exist
self._create_tables() self.db.create_embeddings_tables()
models = [ models = [
"jinaai/jina-clip-v1-text_model_fp16.onnx", "jinaai/jina-clip-v1-text_model_fp16.onnx",
@ -96,6 +96,7 @@ class Embeddings:
}, },
embedding_function=jina_text_embedding_function, embedding_function=jina_text_embedding_function,
model_type="text", model_type="text",
requestor=self.requestor,
device="CPU", device="CPU",
) )
@ -108,34 +109,10 @@ class Embeddings:
}, },
embedding_function=jina_vision_embedding_function, embedding_function=jina_vision_embedding_function,
model_type="vision", model_type="vision",
requestor=self.requestor,
device=self.config.device, device=self.config.device,
) )
def _create_tables(self):
# Create vec0 virtual table for thumbnail embeddings
self.db.execute_sql("""
CREATE VIRTUAL TABLE IF NOT EXISTS vec_thumbnails USING vec0(
id TEXT PRIMARY KEY,
thumbnail_embedding FLOAT[768]
);
""")
# Create vec0 virtual table for description embeddings
self.db.execute_sql("""
CREATE VIRTUAL TABLE IF NOT EXISTS vec_descriptions USING vec0(
id TEXT PRIMARY KEY,
description_embedding FLOAT[768]
);
""")
def _drop_tables(self):
self.db.execute_sql("""
DROP TABLE vec_descriptions;
""")
self.db.execute_sql("""
DROP TABLE vec_thumbnails;
""")
def upsert_thumbnail(self, event_id: str, thumbnail: bytes): def upsert_thumbnail(self, event_id: str, thumbnail: bytes):
# Convert thumbnail bytes to PIL Image # Convert thumbnail bytes to PIL Image
image = Image.open(io.BytesIO(thumbnail)).convert("RGB") image = Image.open(io.BytesIO(thumbnail)).convert("RGB")
@ -153,7 +130,6 @@ class Embeddings:
def upsert_description(self, event_id: str, description: str): def upsert_description(self, event_id: str, description: str):
embedding = self.text_embedding([description])[0] embedding = self.text_embedding([description])[0]
self.db.execute_sql( self.db.execute_sql(
""" """
INSERT OR REPLACE INTO vec_descriptions(id, description_embedding) INSERT OR REPLACE INTO vec_descriptions(id, description_embedding)
@ -167,8 +143,10 @@ class Embeddings:
def reindex(self) -> None: def reindex(self) -> None:
logger.info("Indexing tracked object embeddings...") logger.info("Indexing tracked object embeddings...")
self._drop_tables() self.db.drop_embeddings_tables()
self._create_tables() logger.debug("Dropped embeddings tables.")
self.db.create_embeddings_tables()
logger.debug("Created embeddings tables.")
st = time.time() st = time.time()
totals = { totals = {

View File

@ -15,6 +15,7 @@ from PIL import Image
from transformers import AutoFeatureExtractor, AutoTokenizer from transformers import AutoFeatureExtractor, AutoTokenizer
from transformers.utils.logging import disable_progress_bar from transformers.utils.logging import disable_progress_bar
from frigate.comms.inter_process import InterProcessRequestor
from frigate.const import MODEL_CACHE_DIR, UPDATE_MODEL_STATE from frigate.const import MODEL_CACHE_DIR, UPDATE_MODEL_STATE
from frigate.types import ModelStatusTypesEnum from frigate.types import ModelStatusTypesEnum
from frigate.util.downloader import ModelDownloader from frigate.util.downloader import ModelDownloader
@ -41,12 +42,14 @@ class GenericONNXEmbedding:
download_urls: Dict[str, str], download_urls: Dict[str, str],
embedding_function: Callable[[List[np.ndarray]], np.ndarray], embedding_function: Callable[[List[np.ndarray]], np.ndarray],
model_type: str, model_type: str,
requestor: InterProcessRequestor,
tokenizer_file: Optional[str] = None, tokenizer_file: Optional[str] = None,
device: str = "AUTO", device: str = "AUTO",
): ):
self.model_name = model_name self.model_name = model_name
self.model_file = model_file self.model_file = model_file
self.tokenizer_file = tokenizer_file self.tokenizer_file = tokenizer_file
self.requestor = requestor
self.download_urls = download_urls self.download_urls = download_urls
self.embedding_function = embedding_function self.embedding_function = embedding_function
self.model_type = model_type # 'text' or 'vision' self.model_type = model_type # 'text' or 'vision'
@ -58,15 +61,32 @@ class GenericONNXEmbedding:
self.tokenizer = None self.tokenizer = None
self.feature_extractor = None self.feature_extractor = None
self.session = None self.session = None
files_names = list(self.download_urls.keys()) + (
self.downloader = ModelDownloader( [self.tokenizer_file] if self.tokenizer_file else []
model_name=self.model_name,
download_path=self.download_path,
file_names=list(self.download_urls.keys())
+ ([self.tokenizer_file] if self.tokenizer_file else []),
download_func=self._download_model,
) )
self.downloader.ensure_model_files()
if not all(
os.path.exists(os.path.join(self.download_path, n)) for n in files_names
):
logger.debug(f"starting model download for {self.model_name}")
self.downloader = ModelDownloader(
model_name=self.model_name,
download_path=self.download_path,
file_names=files_names,
requestor=self.requestor,
download_func=self._download_model,
)
self.downloader.ensure_model_files()
else:
self.downloader = None
ModelDownloader.mark_files_state(
self.requestor,
self.model_name,
files_names,
ModelStatusTypesEnum.downloaded,
)
self._load_model_and_tokenizer()
logger.debug(f"models are already downloaded for {self.model_name}")
def _download_model(self, path: str): def _download_model(self, path: str):
try: try:
@ -102,7 +122,8 @@ class GenericONNXEmbedding:
def _load_model_and_tokenizer(self): def _load_model_and_tokenizer(self):
if self.session is None: if self.session is None:
self.downloader.wait_for_download() if self.downloader:
self.downloader.wait_for_download()
if self.model_type == "text": if self.model_type == "text":
self.tokenizer = self._load_tokenizer() self.tokenizer = self._load_tokenizer()
else: else:
@ -125,13 +146,12 @@ class GenericONNXEmbedding:
f"{MODEL_CACHE_DIR}/{self.model_name}", f"{MODEL_CACHE_DIR}/{self.model_name}",
) )
def _load_model(self, path: str): def _load_model(self, path: str) -> Optional[ort.InferenceSession]:
if os.path.exists(path): if os.path.exists(path):
return ort.InferenceSession( return ort.InferenceSession(
path, providers=self.providers, provider_options=self.provider_options path, providers=self.providers, provider_options=self.provider_options
) )
else: else:
logger.warning(f"{self.model_name} model file {path} not found.")
return None return None
def _process_image(self, image): def _process_image(self, image):

View File

@ -41,10 +41,14 @@ class EmbeddingMaintainer(threading.Thread):
config: FrigateConfig, config: FrigateConfig,
stop_event: MpEvent, stop_event: MpEvent,
) -> None: ) -> None:
threading.Thread.__init__(self) super().__init__(name="embeddings_maintainer")
self.name = "embeddings_maintainer"
self.config = config self.config = config
self.embeddings = Embeddings(config.semantic_search, db) self.embeddings = Embeddings(config.semantic_search, db)
# Check if we need to re-index events
if config.semantic_search.reindex:
self.embeddings.reindex()
self.event_subscriber = EventUpdateSubscriber() self.event_subscriber = EventUpdateSubscriber()
self.event_end_subscriber = EventEndSubscriber() self.event_end_subscriber = EventEndSubscriber()
self.event_metadata_subscriber = EventMetadataSubscriber( self.event_metadata_subscriber = EventMetadataSubscriber(
@ -76,26 +80,33 @@ class EmbeddingMaintainer(threading.Thread):
def _process_requests(self) -> None: def _process_requests(self) -> None:
"""Process embeddings requests""" """Process embeddings requests"""
def handle_request(topic: str, data: str) -> str: def _handle_request(topic: str, data: str) -> str:
if topic == EmbeddingsRequestEnum.embed_description.value: try:
return serialize( if topic == EmbeddingsRequestEnum.embed_description.value:
self.embeddings.upsert_description(data["id"], data["description"]), return serialize(
pack=False, self.embeddings.upsert_description(
) data["id"], data["description"]
elif topic == EmbeddingsRequestEnum.embed_thumbnail.value: ),
thumbnail = base64.b64decode(data["thumbnail"]) pack=False,
return serialize( )
self.embeddings.upsert_thumbnail(data["id"], thumbnail), elif topic == EmbeddingsRequestEnum.embed_thumbnail.value:
pack=False, thumbnail = base64.b64decode(data["thumbnail"])
) return serialize(
elif topic == EmbeddingsRequestEnum.generate_search.value: self.embeddings.upsert_thumbnail(data["id"], thumbnail),
return serialize(self.embeddings.text_embedding([data])[0], pack=False) pack=False,
)
elif topic == EmbeddingsRequestEnum.generate_search.value:
return serialize(
self.embeddings.text_embedding([data])[0], pack=False
)
except Exception as e:
logger.error(f"Unable to handle embeddings request {e}")
self.embeddings_responder.check_for_request(handle_request) self.embeddings_responder.check_for_request(_handle_request)
def _process_updates(self) -> None: def _process_updates(self) -> None:
"""Process event updates""" """Process event updates"""
update = self.event_subscriber.check_for_update() update = self.event_subscriber.check_for_update(timeout=0.1)
if update is None: if update is None:
return return
@ -124,7 +135,7 @@ class EmbeddingMaintainer(threading.Thread):
def _process_finalized(self) -> None: def _process_finalized(self) -> None:
"""Process the end of an event.""" """Process the end of an event."""
while True: while True:
ended = self.event_end_subscriber.check_for_update() ended = self.event_end_subscriber.check_for_update(timeout=0.1)
if ended == None: if ended == None:
break break
@ -161,9 +172,6 @@ class EmbeddingMaintainer(threading.Thread):
or set(event.zones) & set(camera_config.genai.required_zones) or set(event.zones) & set(camera_config.genai.required_zones)
) )
): ):
logger.debug(
f"Description generation for {event}, has_snapshot: {event.has_snapshot}"
)
if event.has_snapshot and camera_config.genai.use_snapshot: if event.has_snapshot and camera_config.genai.use_snapshot:
with open( with open(
os.path.join(CLIPS_DIR, f"{event.camera}-{event.id}.jpg"), os.path.join(CLIPS_DIR, f"{event.camera}-{event.id}.jpg"),
@ -217,7 +225,7 @@ class EmbeddingMaintainer(threading.Thread):
def _process_event_metadata(self): def _process_event_metadata(self):
# Check for regenerate description requests # Check for regenerate description requests
(topic, event_id, source) = self.event_metadata_subscriber.check_for_update( (topic, event_id, source) = self.event_metadata_subscriber.check_for_update(
timeout=1 timeout=0.1
) )
if topic is None: if topic is None:

View File

@ -8,11 +8,9 @@ from enum import Enum
from multiprocessing.synchronize import Event as MpEvent from multiprocessing.synchronize import Event as MpEvent
from pathlib import Path from pathlib import Path
from playhouse.sqliteq import SqliteQueueDatabase
from frigate.config import FrigateConfig from frigate.config import FrigateConfig
from frigate.const import CLIPS_DIR from frigate.const import CLIPS_DIR
from frigate.embeddings.embeddings import Embeddings from frigate.db.sqlitevecq import SqliteVecQueueDatabase
from frigate.models import Event, Timeline from frigate.models import Event, Timeline
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -25,7 +23,7 @@ class EventCleanupType(str, Enum):
class EventCleanup(threading.Thread): class EventCleanup(threading.Thread):
def __init__( def __init__(
self, config: FrigateConfig, stop_event: MpEvent, db: SqliteQueueDatabase self, config: FrigateConfig, stop_event: MpEvent, db: SqliteVecQueueDatabase
): ):
super().__init__(name="event_cleanup") super().__init__(name="event_cleanup")
self.config = config self.config = config
@ -35,9 +33,6 @@ class EventCleanup(threading.Thread):
self.removed_camera_labels: list[str] = None self.removed_camera_labels: list[str] = None
self.camera_labels: dict[str, dict[str, any]] = {} self.camera_labels: dict[str, dict[str, any]] = {}
if self.config.semantic_search.enabled:
self.embeddings = Embeddings(self.config.semantic_search, self.db)
def get_removed_camera_labels(self) -> list[Event]: def get_removed_camera_labels(self) -> list[Event]:
"""Get a list of distinct labels for removed cameras.""" """Get a list of distinct labels for removed cameras."""
if self.removed_camera_labels is None: if self.removed_camera_labels is None:
@ -234,8 +229,8 @@ class EventCleanup(threading.Thread):
Event.delete().where(Event.id << chunk).execute() Event.delete().where(Event.id << chunk).execute()
if self.config.semantic_search.enabled: if self.config.semantic_search.enabled:
self.embeddings.delete_description(chunk) self.db.delete_embeddings_description(chunk)
self.embeddings.delete_thumbnail(chunk) self.db.delete_embeddings_thumbnail(chunk)
logger.debug(f"Deleted {len(events_to_delete)} embeddings") logger.debug(f"Deleted {len(events_to_delete)} embeddings")
logger.info("Exiting event cleanup...") logger.info("Exiting event cleanup...")

View File

@ -44,6 +44,7 @@ class ModelDownloader:
download_path: str, download_path: str,
file_names: List[str], file_names: List[str],
download_func: Callable[[str], None], download_func: Callable[[str], None],
requestor: InterProcessRequestor,
silent: bool = False, silent: bool = False,
): ):
self.model_name = model_name self.model_name = model_name
@ -51,19 +52,17 @@ class ModelDownloader:
self.file_names = file_names self.file_names = file_names
self.download_func = download_func self.download_func = download_func
self.silent = silent self.silent = silent
self.requestor = InterProcessRequestor() self.requestor = requestor
self.download_thread = None self.download_thread = None
self.download_complete = threading.Event() self.download_complete = threading.Event()
def ensure_model_files(self): def ensure_model_files(self):
for file in self.file_names: self.mark_files_state(
self.requestor.send_data( self.requestor,
UPDATE_MODEL_STATE, self.model_name,
{ self.file_names,
"model": f"{self.model_name}-{file}", ModelStatusTypesEnum.downloading,
"state": ModelStatusTypesEnum.downloading, )
},
)
self.download_thread = threading.Thread( self.download_thread = threading.Thread(
target=self._download_models, target=self._download_models,
name=f"_download_model_{self.model_name}", name=f"_download_model_{self.model_name}",
@ -119,5 +118,21 @@ class ModelDownloader:
if not silent: if not silent:
logger.info(f"Downloading complete: {url}") logger.info(f"Downloading complete: {url}")
@staticmethod
def mark_files_state(
requestor: InterProcessRequestor,
model_name: str,
files: list[str],
state: ModelStatusTypesEnum,
) -> None:
for file_name in files:
requestor.send_data(
UPDATE_MODEL_STATE,
{
"model": f"{model_name}-{file_name}",
"state": state,
},
)
def wait_for_download(self): def wait_for_download(self):
self.download_complete.wait() self.download_complete.wait()