mirror of
				https://github.com/blakeblackshear/frigate.git
				synced 2025-10-27 10:52:11 +01:00 
			
		
		
		
	Batch embeddings fixes (#14325)

* fixes
* more readable loops
* more robust key check and warning message
* ensure we get reindex progress on mount
* use correct var for length
This commit is contained in:
		
							parent
							
								
									66d0ad5803
								
							
						
					
					
						commit
						1ec459ea3a
					
				@ -145,13 +145,18 @@ class Embeddings:
 | 
			
		||||
        ]
 | 
			
		||||
        ids = list(event_thumbs.keys())
 | 
			
		||||
        embeddings = self.vision_embedding(images)
 | 
			
		||||
        items = [(ids[i], serialize(embeddings[i])) for i in range(len(ids))]
 | 
			
		||||
 | 
			
		||||
        items = []
 | 
			
		||||
 | 
			
		||||
        for i in range(len(ids)):
 | 
			
		||||
            items.append(ids[i])
 | 
			
		||||
            items.append(serialize(embeddings[i]))
 | 
			
		||||
 | 
			
		||||
        self.db.execute_sql(
 | 
			
		||||
            """
 | 
			
		||||
            INSERT OR REPLACE INTO vec_thumbnails(id, thumbnail_embedding)
 | 
			
		||||
            VALUES {}
 | 
			
		||||
            """.format(", ".join(["(?, ?)"] * len(items))),
 | 
			
		||||
            """.format(", ".join(["(?, ?)"] * len(ids))),
 | 
			
		||||
            items,
 | 
			
		||||
        )
 | 
			
		||||
        return embeddings
 | 
			
		||||
@ -171,13 +176,18 @@ class Embeddings:
 | 
			
		||||
    def batch_upsert_description(self, event_descriptions: dict[str, str]) -> ndarray:
 | 
			
		||||
        embeddings = self.text_embedding(list(event_descriptions.values()))
 | 
			
		||||
        ids = list(event_descriptions.keys())
 | 
			
		||||
        items = [(ids[i], serialize(embeddings[i])) for i in range(len(ids))]
 | 
			
		||||
 | 
			
		||||
        items = []
 | 
			
		||||
 | 
			
		||||
        for i in range(len(ids)):
 | 
			
		||||
            items.append(ids[i])
 | 
			
		||||
            items.append(serialize(embeddings[i]))
 | 
			
		||||
 | 
			
		||||
        self.db.execute_sql(
 | 
			
		||||
            """
 | 
			
		||||
            INSERT OR REPLACE INTO vec_descriptions(id, description_embedding)
 | 
			
		||||
            VALUES {}
 | 
			
		||||
            """.format(", ".join(["(?, ?)"] * len(items))),
 | 
			
		||||
            """.format(", ".join(["(?, ?)"] * len(ids))),
 | 
			
		||||
            items,
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
@ -196,16 +206,6 @@ class Embeddings:
 | 
			
		||||
            os.remove(os.path.join(CONFIG_DIR, ".search_stats.json"))
 | 
			
		||||
 | 
			
		||||
        st = time.time()
 | 
			
		||||
        totals = {
 | 
			
		||||
            "thumbnails": 0,
 | 
			
		||||
            "descriptions": 0,
 | 
			
		||||
            "processed_objects": 0,
 | 
			
		||||
            "total_objects": 0,
 | 
			
		||||
            "time_remaining": 0,
 | 
			
		||||
            "status": "indexing",
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        self.requestor.send_data(UPDATE_EMBEDDINGS_REINDEX_PROGRESS, totals)
 | 
			
		||||
 | 
			
		||||
        # Get total count of events to process
 | 
			
		||||
        total_events = (
 | 
			
		||||
@ -216,11 +216,21 @@ class Embeddings:
 | 
			
		||||
            )
 | 
			
		||||
            .count()
 | 
			
		||||
        )
 | 
			
		||||
        totals["total_objects"] = total_events
 | 
			
		||||
 | 
			
		||||
        batch_size = 32
 | 
			
		||||
        current_page = 1
 | 
			
		||||
 | 
			
		||||
        totals = {
 | 
			
		||||
            "thumbnails": 0,
 | 
			
		||||
            "descriptions": 0,
 | 
			
		||||
            "processed_objects": total_events - 1 if total_events < batch_size else 0,
 | 
			
		||||
            "total_objects": total_events,
 | 
			
		||||
            "time_remaining": 0 if total_events < batch_size else -1,
 | 
			
		||||
            "status": "indexing",
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        self.requestor.send_data(UPDATE_EMBEDDINGS_REINDEX_PROGRESS, totals)
 | 
			
		||||
 | 
			
		||||
        events = (
 | 
			
		||||
            Event.select()
 | 
			
		||||
            .where(
 | 
			
		||||
 | 
			
		||||
@ -164,8 +164,15 @@ class GenericONNXEmbedding:
 | 
			
		||||
            return []
 | 
			
		||||
 | 
			
		||||
        if self.model_type == "text":
 | 
			
		||||
            max_length = max(len(self.tokenizer.encode(text)) for text in inputs)
 | 
			
		||||
            processed_inputs = [
 | 
			
		||||
                self.tokenizer(text, padding=True, truncation=True, return_tensors="np")
 | 
			
		||||
                self.tokenizer(
 | 
			
		||||
                    text,
 | 
			
		||||
                    padding="max_length",
 | 
			
		||||
                    truncation=True,
 | 
			
		||||
                    max_length=max_length,
 | 
			
		||||
                    return_tensors="np",
 | 
			
		||||
                )
 | 
			
		||||
                for text in inputs
 | 
			
		||||
            ]
 | 
			
		||||
        else:
 | 
			
		||||
@ -183,8 +190,11 @@ class GenericONNXEmbedding:
 | 
			
		||||
                if key in input_names:
 | 
			
		||||
                    onnx_inputs[key].append(value[0])
 | 
			
		||||
 | 
			
		||||
        for key in onnx_inputs.keys():
 | 
			
		||||
            onnx_inputs[key] = np.array(onnx_inputs[key])
 | 
			
		||||
        for key in input_names:
 | 
			
		||||
            if onnx_inputs.get(key):
 | 
			
		||||
                onnx_inputs[key] = np.stack(onnx_inputs[key])
 | 
			
		||||
            else:
 | 
			
		||||
                logger.warning(f"Expected input '{key}' not found in onnx_inputs")
 | 
			
		||||
 | 
			
		||||
        embeddings = self.runner.run(onnx_inputs)[0]
 | 
			
		||||
        return [embedding for embedding in embeddings]
 | 
			
		||||
 | 
			
		||||
@ -193,6 +193,17 @@ export default function Explore() {
 | 
			
		||||
 | 
			
		||||
  // embeddings reindex progress
 | 
			
		||||
 | 
			
		||||
  const { send: sendReindexCommand } = useWs(
 | 
			
		||||
    "embeddings_reindex_progress",
 | 
			
		||||
    "embeddingsReindexProgress",
 | 
			
		||||
  );
 | 
			
		||||
 | 
			
		||||
  useEffect(() => {
 | 
			
		||||
    sendReindexCommand("embeddingsReindexProgress");
 | 
			
		||||
    // only run on mount
 | 
			
		||||
    // eslint-disable-next-line react-hooks/exhaustive-deps
 | 
			
		||||
  }, []);
 | 
			
		||||
 | 
			
		||||
  const { payload: reindexProgress } = useEmbeddingsReindexProgress();
 | 
			
		||||
 | 
			
		||||
  const embeddingsReindexing = useMemo(() => {
 | 
			
		||||
@ -210,10 +221,10 @@ export default function Explore() {
 | 
			
		||||
 | 
			
		||||
  // model states
 | 
			
		||||
 | 
			
		||||
  const { send: sendCommand } = useWs("model_state", "modelState");
 | 
			
		||||
  const { send: sendModelCommand } = useWs("model_state", "modelState");
 | 
			
		||||
 | 
			
		||||
  useEffect(() => {
 | 
			
		||||
    sendCommand("modelState");
 | 
			
		||||
    sendModelCommand("modelState");
 | 
			
		||||
    // only run on mount
 | 
			
		||||
    // eslint-disable-next-line react-hooks/exhaustive-deps
 | 
			
		||||
  }, []);
 | 
			
		||||
@ -299,14 +310,18 @@ export default function Explore() {
 | 
			
		||||
                  />
 | 
			
		||||
                </div>
 | 
			
		||||
                <div className="flex w-96 flex-col gap-2 py-5">
 | 
			
		||||
                  {reindexProgress.time_remaining >= 0 && (
 | 
			
		||||
                  {reindexProgress.time_remaining !== null && (
 | 
			
		||||
                    <div className="mb-3 flex flex-col items-center justify-center gap-1">
 | 
			
		||||
                      <div className="text-primary-variant">
 | 
			
		||||
                        Estimated time remaining:
 | 
			
		||||
                        {reindexProgress.time_remaining === -1
 | 
			
		||||
                          ? "Starting up..."
 | 
			
		||||
                          : "Estimated time remaining:"}
 | 
			
		||||
                      </div>
 | 
			
		||||
                      {formatSecondsToDuration(
 | 
			
		||||
                        reindexProgress.time_remaining,
 | 
			
		||||
                      ) || "Finishing shortly"}
 | 
			
		||||
                      {reindexProgress.time_remaining >= 0 &&
 | 
			
		||||
                        (formatSecondsToDuration(
 | 
			
		||||
                          reindexProgress.time_remaining,
 | 
			
		||||
                        ) ||
 | 
			
		||||
                          "Finishing shortly")}
 | 
			
		||||
                    </div>
 | 
			
		||||
                  )}
 | 
			
		||||
                  <div className="flex flex-row items-center justify-center gap-3">
 | 
			
		||||
 | 
			
		||||
		Loading…
	
		Reference in New Issue
	
	Block a user