Embeddings fixes (#14269)
* Add debugging logs for more info
* Improve timeout handling
* Fix event cleanup
* Handle zmq error and empty data
* Don't run download
* Remove unneeded embeddings creations
* Update timeouts
* Init models immediately
* Fix order of init
* Cleanup

parent f67ec241d4
commit dd6276e706

@@ -581,12 +581,12 @@ class FrigateApp:
         self.init_recording_manager()
         self.init_review_segment_manager()
         self.init_go2rtc()
+        self.start_detectors()
+        self.init_embeddings_manager()
         self.bind_database()
         self.check_db_data_migrations()
         self.init_inter_process_communicator()
         self.init_dispatcher()
-        self.start_detectors()
-        self.init_embeddings_manager()
         self.init_embeddings_client()
         self.start_video_output_processor()
         self.start_ptz_autotracker()
@@ -64,6 +64,9 @@ class Dispatcher:
         self.onvif = onvif
         self.ptz_metrics = ptz_metrics
         self.comms = communicators
+        self.camera_activity = {}
+        self.model_state = {}
+        self.embeddings_reindex = {}
 
         self._camera_settings_handlers: dict[str, Callable] = {
             "audio": self._on_audio_command,
@@ -85,10 +88,6 @@ class Dispatcher:
         for comm in self.comms:
             comm.subscribe(self._receive)
 
-        self.camera_activity = {}
-        self.model_state = {}
-        self.embeddings_reindex = {}
-
     def _receive(self, topic: str, payload: str) -> Optional[Any]:
         """Handle receiving of payload from communicators."""
 
@@ -22,7 +22,7 @@ class EmbeddingsResponder:
 
     def check_for_request(self, process: Callable) -> None:
         while True:  # load all messages that are queued
-            has_message, _, _ = zmq.select([self.socket], [], [], 1)
+            has_message, _, _ = zmq.select([self.socket], [], [], 0.1)
 
             if not has_message:
                 break
@@ -54,8 +54,11 @@ class EmbeddingsRequestor:
 
     def send_data(self, topic: str, data: any) -> str:
         """Sends data and then waits for reply."""
-        self.socket.send_json((topic, data))
-        return self.socket.recv_json()
+        try:
+            self.socket.send_json((topic, data))
+            return self.socket.recv_json()
+        except zmq.ZMQError:
+            return ""
 
     def stop(self) -> None:
         self.socket.close()
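
A minimal standalone sketch of the request/reply pattern in the two hunks above, assuming pyzmq; the inproc address and the echo handler are illustrative, not Frigate's:

import zmq

ctx = zmq.Context.instance()
rep = ctx.socket(zmq.REP)
rep.bind("inproc://embeddings-sketch")
req = ctx.socket(zmq.REQ)
req.connect("inproc://embeddings-sketch")


def check_for_request(process) -> None:
    # responder side: drain every queued request, waiting at most 0.1s per pass
    # so the caller's loop is never blocked for long on an idle socket
    while True:
        has_message, _, _ = zmq.select([rep], [], [], 0.1)
        if not has_message:
            break
        topic, data = rep.recv_json()
        rep.send_json(process(topic, data))


def send_data(topic: str, data) -> str:
    # requestor side: swallow ZMQError (e.g. a socket closed during shutdown)
    # and return an empty string so callers can bail out on falsy data
    try:
        req.send_json((topic, data))
        return req.recv_json()
    except zmq.ZMQError:
        return ""


req.send_json(("echo", "hello"))             # REQ send is queued without blocking
check_for_request(lambda topic, data: data)  # REP drains the queue and replies
print(req.recv_json())                       # -> "hello"
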
@@ -39,7 +39,7 @@ class EventMetadataSubscriber(Subscriber):
         super().__init__(topic)
 
     def check_for_update(
-        self, timeout: float = None
+        self, timeout: float = 1
     ) -> Optional[tuple[EventMetadataTypeEnum, str, RegenerateDescriptionEnum]]:
         return super().check_for_update(timeout)
 
@@ -28,3 +28,26 @@ class SqliteVecQueueDatabase(SqliteQueueDatabase):
     def delete_embeddings_description(self, event_ids: list[str]) -> None:
         ids = ",".join(["?" for _ in event_ids])
         self.execute_sql(f"DELETE FROM vec_descriptions WHERE id IN ({ids})", event_ids)
+
+    def drop_embeddings_tables(self) -> None:
+        self.execute_sql("""
+            DROP TABLE vec_descriptions;
+        """)
+        self.execute_sql("""
+            DROP TABLE vec_thumbnails;
+        """)
+
+    def create_embeddings_tables(self) -> None:
+        """Create vec0 virtual table for embeddings"""
+        self.execute_sql("""
+            CREATE VIRTUAL TABLE IF NOT EXISTS vec_thumbnails USING vec0(
+                id TEXT PRIMARY KEY,
+                thumbnail_embedding FLOAT[768]
+            );
+        """)
+        self.execute_sql("""
+            CREATE VIRTUAL TABLE IF NOT EXISTS vec_descriptions USING vec0(
+                id TEXT PRIMARY KEY,
+                description_embedding FLOAT[768]
+            );
+        """)
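
As a sanity check on the schema above, a small standalone sketch, assuming the sqlite-vec Python bindings (sqlite_vec.load, serialize_float32); the table mirrors the diff, while the sample vector and row id are illustrative:

import sqlite3

import sqlite_vec
from sqlite_vec import serialize_float32

db = sqlite3.connect(":memory:")
db.enable_load_extension(True)
sqlite_vec.load(db)
db.enable_load_extension(False)

# same DDL as create_embeddings_tables() above
db.execute("""
    CREATE VIRTUAL TABLE IF NOT EXISTS vec_thumbnails USING vec0(
        id TEXT PRIMARY KEY,
        thumbnail_embedding FLOAT[768]
    );
""")

# upsert one embedding, then run a nearest-neighbour query against it
db.execute(
    "INSERT OR REPLACE INTO vec_thumbnails(id, thumbnail_embedding) VALUES(?, ?)",
    ("event-1", serialize_float32([0.1] * 768)),
)
rows = db.execute(
    """
    SELECT id, distance
    FROM vec_thumbnails
    WHERE thumbnail_embedding MATCH ?
    ORDER BY distance
    LIMIT 3
    """,
    (serialize_float32([0.1] * 768),),
).fetchall()
print(rows)  # [('event-1', 0.0)]
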
@@ -19,7 +19,6 @@ from frigate.models import Event
 from frigate.util.builtin import serialize
 from frigate.util.services import listen
 
-from .embeddings import Embeddings
 from .maintainer import EmbeddingMaintainer
 from .util import ZScoreNormalization
 
@@ -57,12 +56,6 @@ def manage_embeddings(config: FrigateConfig) -> None:
     models = [Event]
     db.bind(models)
 
-    embeddings = Embeddings(config.semantic_search, db)
-
-    # Check if we need to re-index events
-    if config.semantic_search.reindex:
-        embeddings.reindex()
-
     maintainer = EmbeddingMaintainer(
         db,
         config,
@@ -114,19 +107,25 @@ class EmbeddingsContext:
                 query_embedding = row[0]
             else:
                 # If no embedding found, generate it and return it
-                query_embedding = serialize(
-                    self.requestor.send_data(
-                        EmbeddingsRequestEnum.embed_thumbnail.value,
-                        {"id": query.id, "thumbnail": query.thumbnail},
-                    )
-                )
+                data = self.requestor.send_data(
+                    EmbeddingsRequestEnum.embed_thumbnail.value,
+                    {"id": str(query.id), "thumbnail": str(query.thumbnail)},
+                )
+
+                if not data:
+                    return []
+
+                query_embedding = serialize(data)
         else:
-            query_embedding = serialize(
-                self.requestor.send_data(
-                    EmbeddingsRequestEnum.generate_search.value, query
-                )
-            )
+            data = self.requestor.send_data(
+                EmbeddingsRequestEnum.generate_search.value, query
+            )
+
+            if not data:
+                return []
+
+            query_embedding = serialize(data)
 
         sql_query = """
             SELECT
                 id,
@@ -155,12 +154,15 @@ class EmbeddingsContext:
     def search_description(
         self, query_text: str, event_ids: list[str] = None
     ) -> list[tuple[str, float]]:
-        query_embedding = serialize(
-            self.requestor.send_data(
-                EmbeddingsRequestEnum.generate_search.value, query_text
-            )
-        )
+        data = self.requestor.send_data(
+            EmbeddingsRequestEnum.generate_search.value, query_text
+        )
+
+        if not data:
+            return []
+
+        query_embedding = serialize(data)
 
         # Prepare the base SQL query
         sql_query = """
             SELECT
@@ -63,7 +63,7 @@ class Embeddings:
         self.requestor = InterProcessRequestor()
 
         # Create tables if they don't exist
-        self._create_tables()
+        self.db.create_embeddings_tables()
 
         models = [
             "jinaai/jina-clip-v1-text_model_fp16.onnx",
@@ -96,6 +96,7 @@ class Embeddings:
             },
             embedding_function=jina_text_embedding_function,
             model_type="text",
+            requestor=self.requestor,
             device="CPU",
         )
 
@@ -108,34 +109,10 @@ class Embeddings:
             },
             embedding_function=jina_vision_embedding_function,
             model_type="vision",
+            requestor=self.requestor,
             device=self.config.device,
         )
 
-    def _create_tables(self):
-        # Create vec0 virtual table for thumbnail embeddings
-        self.db.execute_sql("""
-            CREATE VIRTUAL TABLE IF NOT EXISTS vec_thumbnails USING vec0(
-                id TEXT PRIMARY KEY,
-                thumbnail_embedding FLOAT[768]
-            );
-        """)
-
-        # Create vec0 virtual table for description embeddings
-        self.db.execute_sql("""
-            CREATE VIRTUAL TABLE IF NOT EXISTS vec_descriptions USING vec0(
-                id TEXT PRIMARY KEY,
-                description_embedding FLOAT[768]
-            );
-        """)
-
-    def _drop_tables(self):
-        self.db.execute_sql("""
-            DROP TABLE vec_descriptions;
-        """)
-        self.db.execute_sql("""
-            DROP TABLE vec_thumbnails;
-        """)
-
     def upsert_thumbnail(self, event_id: str, thumbnail: bytes):
         # Convert thumbnail bytes to PIL Image
         image = Image.open(io.BytesIO(thumbnail)).convert("RGB")
@@ -153,7 +130,6 @@ class Embeddings:
 
     def upsert_description(self, event_id: str, description: str):
         embedding = self.text_embedding([description])[0]
-
         self.db.execute_sql(
             """
             INSERT OR REPLACE INTO vec_descriptions(id, description_embedding)
@@ -167,8 +143,10 @@ class Embeddings:
     def reindex(self) -> None:
         logger.info("Indexing tracked object embeddings...")
 
-        self._drop_tables()
-        self._create_tables()
+        self.db.drop_embeddings_tables()
+        logger.debug("Dropped embeddings tables.")
+        self.db.create_embeddings_tables()
+        logger.debug("Created embeddings tables.")
 
         st = time.time()
         totals = {
@@ -15,6 +15,7 @@ from PIL import Image
 from transformers import AutoFeatureExtractor, AutoTokenizer
 from transformers.utils.logging import disable_progress_bar
 
+from frigate.comms.inter_process import InterProcessRequestor
 from frigate.const import MODEL_CACHE_DIR, UPDATE_MODEL_STATE
 from frigate.types import ModelStatusTypesEnum
 from frigate.util.downloader import ModelDownloader
@@ -41,12 +42,14 @@ class GenericONNXEmbedding:
         download_urls: Dict[str, str],
         embedding_function: Callable[[List[np.ndarray]], np.ndarray],
         model_type: str,
+        requestor: InterProcessRequestor,
         tokenizer_file: Optional[str] = None,
         device: str = "AUTO",
     ):
         self.model_name = model_name
         self.model_file = model_file
         self.tokenizer_file = tokenizer_file
+        self.requestor = requestor
         self.download_urls = download_urls
         self.embedding_function = embedding_function
         self.model_type = model_type  # 'text' or 'vision'
@@ -58,15 +61,32 @@ class GenericONNXEmbedding:
         self.tokenizer = None
         self.feature_extractor = None
         self.session = None
+        files_names = list(self.download_urls.keys()) + (
+            [self.tokenizer_file] if self.tokenizer_file else []
+        )
 
-        self.downloader = ModelDownloader(
-            model_name=self.model_name,
-            download_path=self.download_path,
-            file_names=list(self.download_urls.keys())
-            + ([self.tokenizer_file] if self.tokenizer_file else []),
-            download_func=self._download_model,
-        )
-        self.downloader.ensure_model_files()
+        if not all(
+            os.path.exists(os.path.join(self.download_path, n)) for n in files_names
+        ):
+            logger.debug(f"starting model download for {self.model_name}")
+            self.downloader = ModelDownloader(
+                model_name=self.model_name,
+                download_path=self.download_path,
+                file_names=files_names,
+                requestor=self.requestor,
+                download_func=self._download_model,
+            )
+            self.downloader.ensure_model_files()
+        else:
+            self.downloader = None
+            ModelDownloader.mark_files_state(
+                self.requestor,
+                self.model_name,
+                files_names,
+                ModelStatusTypesEnum.downloaded,
+            )
+            self._load_model_and_tokenizer()
+            logger.debug(f"models are already downloaded for {self.model_name}")
 
     def _download_model(self, path: str):
         try:
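
The constructor change above boils down to: compute the expected file list, only spin up a downloader when something is missing, and otherwise report the files as downloaded and load the model right away. A rough standalone sketch of that flow, with illustrative stand-ins (fetch, publish_state) rather than Frigate's APIs:

import os
import threading
from typing import Callable, List, Optional


def ensure_files(
    download_path: str,
    file_names: List[str],
    fetch: Callable[[str], None],               # downloads one file (stand-in)
    publish_state: Callable[[str, str], None],  # reports per-file state (stand-in)
) -> Optional[threading.Event]:
    """Start a background download only if files are missing.

    Returns the download-complete event, or None when nothing had to be fetched
    and the caller can load the model immediately.
    """
    missing = [
        name
        for name in file_names
        if not os.path.exists(os.path.join(download_path, name))
    ]
    if not missing:
        for name in file_names:
            publish_state(name, "downloaded")
        return None

    done = threading.Event()

    def _worker() -> None:
        for name in missing:
            publish_state(name, "downloading")
            fetch(name)
            publish_state(name, "downloaded")
        done.set()

    threading.Thread(target=_worker, daemon=True).start()
    return done
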
@@ -102,7 +122,8 @@ class GenericONNXEmbedding:
 
     def _load_model_and_tokenizer(self):
         if self.session is None:
-            self.downloader.wait_for_download()
+            if self.downloader:
+                self.downloader.wait_for_download()
             if self.model_type == "text":
                 self.tokenizer = self._load_tokenizer()
             else:
@@ -125,13 +146,12 @@ class GenericONNXEmbedding:
             f"{MODEL_CACHE_DIR}/{self.model_name}",
         )
 
-    def _load_model(self, path: str):
+    def _load_model(self, path: str) -> Optional[ort.InferenceSession]:
         if os.path.exists(path):
             return ort.InferenceSession(
                 path, providers=self.providers, provider_options=self.provider_options
             )
         else:
             logger.warning(f"{self.model_name} model file {path} not found.")
             return None
 
     def _process_image(self, image):
@@ -41,10 +41,14 @@ class EmbeddingMaintainer(threading.Thread):
         config: FrigateConfig,
         stop_event: MpEvent,
     ) -> None:
-        threading.Thread.__init__(self)
-        self.name = "embeddings_maintainer"
+        super().__init__(name="embeddings_maintainer")
         self.config = config
+        self.embeddings = Embeddings(config.semantic_search, db)
+
+        # Check if we need to re-index events
+        if config.semantic_search.reindex:
+            self.embeddings.reindex()
 
         self.event_subscriber = EventUpdateSubscriber()
         self.event_end_subscriber = EventEndSubscriber()
         self.event_metadata_subscriber = EventMetadataSubscriber(
@@ -76,26 +80,33 @@ class EmbeddingMaintainer(threading.Thread):
     def _process_requests(self) -> None:
         """Process embeddings requests"""
 
-        def handle_request(topic: str, data: str) -> str:
-            if topic == EmbeddingsRequestEnum.embed_description.value:
-                return serialize(
-                    self.embeddings.upsert_description(data["id"], data["description"]),
-                    pack=False,
-                )
-            elif topic == EmbeddingsRequestEnum.embed_thumbnail.value:
-                thumbnail = base64.b64decode(data["thumbnail"])
-                return serialize(
-                    self.embeddings.upsert_thumbnail(data["id"], thumbnail),
-                    pack=False,
-                )
-            elif topic == EmbeddingsRequestEnum.generate_search.value:
-                return serialize(self.embeddings.text_embedding([data])[0], pack=False)
+        def _handle_request(topic: str, data: str) -> str:
+            try:
+                if topic == EmbeddingsRequestEnum.embed_description.value:
+                    return serialize(
+                        self.embeddings.upsert_description(
+                            data["id"], data["description"]
+                        ),
+                        pack=False,
+                    )
+                elif topic == EmbeddingsRequestEnum.embed_thumbnail.value:
+                    thumbnail = base64.b64decode(data["thumbnail"])
+                    return serialize(
+                        self.embeddings.upsert_thumbnail(data["id"], thumbnail),
+                        pack=False,
+                    )
+                elif topic == EmbeddingsRequestEnum.generate_search.value:
+                    return serialize(
+                        self.embeddings.text_embedding([data])[0], pack=False
+                    )
+            except Exception as e:
+                logger.error(f"Unable to handle embeddings request {e}")
 
-        self.embeddings_responder.check_for_request(handle_request)
+        self.embeddings_responder.check_for_request(_handle_request)
 
     def _process_updates(self) -> None:
         """Process event updates"""
-        update = self.event_subscriber.check_for_update()
+        update = self.event_subscriber.check_for_update(timeout=0.1)
 
         if update is None:
             return
@@ -124,7 +135,7 @@ class EmbeddingMaintainer(threading.Thread):
     def _process_finalized(self) -> None:
         """Process the end of an event."""
         while True:
-            ended = self.event_end_subscriber.check_for_update()
+            ended = self.event_end_subscriber.check_for_update(timeout=0.1)
 
             if ended == None:
                 break
@@ -161,9 +172,6 @@ class EmbeddingMaintainer(threading.Thread):
                 or set(event.zones) & set(camera_config.genai.required_zones)
             )
         ):
-            logger.debug(
-                f"Description generation for {event}, has_snapshot: {event.has_snapshot}"
-            )
             if event.has_snapshot and camera_config.genai.use_snapshot:
                 with open(
                     os.path.join(CLIPS_DIR, f"{event.camera}-{event.id}.jpg"),
@@ -217,7 +225,7 @@ class EmbeddingMaintainer(threading.Thread):
     def _process_event_metadata(self):
         # Check for regenerate description requests
         (topic, event_id, source) = self.event_metadata_subscriber.check_for_update(
-            timeout=1
+            timeout=0.1
         )
 
         if topic is None:
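
The 1s (or blocking) timeouts above presumably mattered because the maintainer's run loop services several channels one after another, so a single idle subscriber could stall the rest. A small sketch of that loop shape, with queue.Queue standing in for the ZMQ subscribers used in the real code:

import queue
import threading

updates: "queue.Queue[str]" = queue.Queue()
finalized: "queue.Queue[str]" = queue.Queue()
metadata: "queue.Queue[str]" = queue.Queue()
stop_event = threading.Event()


def check_for_update(q: "queue.Queue[str]", timeout: float = 0.1):
    """Return one queued item, or None once `timeout` expires."""
    try:
        return q.get(timeout=timeout)
    except queue.Empty:
        return None


def run() -> None:
    # each idle channel now costs at most ~0.1s per pass instead of 1s,
    # so a quiet queue no longer delays processing of the busy ones
    while not stop_event.is_set():
        for channel in (updates, finalized, metadata):
            item = check_for_update(channel)
            if item is not None:
                print("processed", item)
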
@@ -8,11 +8,9 @@ from enum import Enum
 from multiprocessing.synchronize import Event as MpEvent
 from pathlib import Path
 
-from playhouse.sqliteq import SqliteQueueDatabase
-
 from frigate.config import FrigateConfig
 from frigate.const import CLIPS_DIR
-from frigate.embeddings.embeddings import Embeddings
+from frigate.db.sqlitevecq import SqliteVecQueueDatabase
 from frigate.models import Event, Timeline
 
 logger = logging.getLogger(__name__)
@@ -25,7 +23,7 @@ class EventCleanupType(str, Enum):
 
 class EventCleanup(threading.Thread):
     def __init__(
-        self, config: FrigateConfig, stop_event: MpEvent, db: SqliteQueueDatabase
+        self, config: FrigateConfig, stop_event: MpEvent, db: SqliteVecQueueDatabase
     ):
         super().__init__(name="event_cleanup")
         self.config = config
@@ -35,9 +33,6 @@ class EventCleanup(threading.Thread):
         self.removed_camera_labels: list[str] = None
         self.camera_labels: dict[str, dict[str, any]] = {}
 
-        if self.config.semantic_search.enabled:
-            self.embeddings = Embeddings(self.config.semantic_search, self.db)
-
     def get_removed_camera_labels(self) -> list[Event]:
         """Get a list of distinct labels for removed cameras."""
         if self.removed_camera_labels is None:
@@ -234,8 +229,8 @@ class EventCleanup(threading.Thread):
                 Event.delete().where(Event.id << chunk).execute()
 
                 if self.config.semantic_search.enabled:
-                    self.embeddings.delete_description(chunk)
-                    self.embeddings.delete_thumbnail(chunk)
+                    self.db.delete_embeddings_description(chunk)
+                    self.db.delete_embeddings_thumbnail(chunk)
                     logger.debug(f"Deleted {len(events_to_delete)} embeddings")
 
         logger.info("Exiting event cleanup...")
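
A standalone sketch of the chunked, parameterised delete that the cleanup thread now delegates to the database class; plain sqlite3 stands in for Frigate's queue database and the chunk size is illustrative:

import sqlite3
from typing import List

CHUNK_SIZE = 50  # illustrative; the real chunking happens in the caller


def delete_embeddings(db: sqlite3.Connection, event_ids: List[str]) -> None:
    for i in range(0, len(event_ids), CHUNK_SIZE):
        chunk = event_ids[i : i + CHUNK_SIZE]
        placeholders = ",".join("?" for _ in chunk)
        db.execute(
            f"DELETE FROM vec_descriptions WHERE id IN ({placeholders})", chunk
        )
        db.execute(
            f"DELETE FROM vec_thumbnails WHERE id IN ({placeholders})", chunk
        )
    db.commit()
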
@@ -44,6 +44,7 @@ class ModelDownloader:
         download_path: str,
         file_names: List[str],
         download_func: Callable[[str], None],
+        requestor: InterProcessRequestor,
         silent: bool = False,
     ):
         self.model_name = model_name
@@ -51,19 +52,17 @@ class ModelDownloader:
         self.file_names = file_names
         self.download_func = download_func
         self.silent = silent
-        self.requestor = InterProcessRequestor()
+        self.requestor = requestor
         self.download_thread = None
         self.download_complete = threading.Event()
 
     def ensure_model_files(self):
-        for file in self.file_names:
-            self.requestor.send_data(
-                UPDATE_MODEL_STATE,
-                {
-                    "model": f"{self.model_name}-{file}",
-                    "state": ModelStatusTypesEnum.downloading,
-                },
-            )
+        self.mark_files_state(
+            self.requestor,
+            self.model_name,
+            self.file_names,
+            ModelStatusTypesEnum.downloading,
+        )
         self.download_thread = threading.Thread(
             target=self._download_models,
             name=f"_download_model_{self.model_name}",
@@ -119,5 +118,21 @@ class ModelDownloader:
             if not silent:
                 logger.info(f"Downloading complete: {url}")
 
+    @staticmethod
+    def mark_files_state(
+        requestor: InterProcessRequestor,
+        model_name: str,
+        files: list[str],
+        state: ModelStatusTypesEnum,
+    ) -> None:
+        for file_name in files:
+            requestor.send_data(
+                UPDATE_MODEL_STATE,
+                {
+                    "model": f"{model_name}-{file_name}",
+                    "state": state,
+                },
+            )
+
     def wait_for_download(self):
         self.download_complete.wait()
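
A compact sketch of the two ideas in the downloader change: the requestor is injected by the caller instead of being constructed inside the class, and the state helper is a @staticmethod so model state can be published even when no downloader is ever created. The Requestor protocol and topic string are illustrative stand-ins for Frigate's InterProcessRequestor and UPDATE_MODEL_STATE:

from typing import List, Protocol


class Requestor(Protocol):
    def send_data(self, topic: str, payload: dict) -> None: ...


class Downloader:
    def __init__(self, model_name: str, file_names: List[str], requestor: Requestor):
        # injected rather than created here, so the caller's existing
        # inter-process connection is reused instead of opening another one
        self.model_name = model_name
        self.file_names = file_names
        self.requestor = requestor

    @staticmethod
    def mark_files_state(
        requestor: Requestor, model_name: str, files: List[str], state: str
    ) -> None:
        # callable without an instance, e.g. when every file already exists
        # on disk and no download thread is needed
        for file_name in files:
            requestor.send_data(
                "model_state",
                {"model": f"{model_name}-{file_name}", "state": state},
            )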