Update device to be an override, keep core logic in place if value of device is None.

This commit is contained in:
baudneo 2025-08-01 16:18:27 -06:00
parent d00a58c7f2
commit 1667aac581
No known key found for this signature in database
GPG Key ID: 51445F2ED08EBC7F
3 changed files with 25 additions and 18 deletions

View File

@ -47,10 +47,10 @@ class SemanticSearchConfig(FrigateBaseModel):
model_size: str = Field( model_size: str = Field(
default="small", title="The size of the embeddings model used." default="small", title="The size of the embeddings model used."
) )
device: str = Field( device: Optional[str] = Field(
default="CPU", default=None,
title="The device to use for semantic search.", title="The gpu id to use for semantic search.",
description="Use 'cpu' or 'gpu', to target a specific gpu use: '0', '1', etc.", description="This is an override, to target a specific gpu use: '0', '1', etc.",
) )
@ -92,10 +92,10 @@ class FaceRecognitionConfig(FrigateBaseModel):
blur_confidence_filter: bool = Field( blur_confidence_filter: bool = Field(
default=True, title="Apply blur quality filter to face confidence." default=True, title="Apply blur quality filter to face confidence."
) )
device: str = Field( device: Optional[str] = Field(
default="CPU", default=None,
title="The device to use for face recognition.", title="The gpu id to use for face recognition.",
description="Use 'cpu' or 'gpu', to target a specific gpu use: '0', '1', etc.", description="This is an override, to target a specific gpu use: '0', '1', etc.",
) )
@ -155,10 +155,10 @@ class LicensePlateRecognitionConfig(FrigateBaseModel):
default=False, default=False,
title="Save plates captured for LPR for debugging purposes.", title="Save plates captured for LPR for debugging purposes.",
) )
device: str = Field( device: Optional[str] = Field(
default="CPU", default=None,
title="The device to use for license plate recognition.", title="The gpu id to use for license plate recognition.",
description="Use 'cpu' or 'gpu', to target a specific gpu use: '0', '1', etc.", description="This is an override, to target a specific gpu use: '0', '1', etc.",
) )
@ -180,4 +180,4 @@ class CameraLicensePlateRecognitionConfig(FrigateBaseModel):
le=10, le=10,
) )
model_config = ConfigDict(extra="forbid", protected_namespaces=()) model_config = ConfigDict(extra="forbid", protected_namespaces=())

View File

@ -107,7 +107,8 @@ class Embeddings:
self.embedding = JinaV2Embedding( self.embedding = JinaV2Embedding(
model_size=self.config.semantic_search.model_size, model_size=self.config.semantic_search.model_size,
requestor=self.requestor, requestor=self.requestor,
device=self.config.semantic_search.device, device=config.semantic_search.device
or ("GPU" if config.semantic_search.model_size == "large" else "CPU"),
) )
self.text_embedding = lambda input_data: self.embedding( self.text_embedding = lambda input_data: self.embedding(
input_data, embedding_type="text" input_data, embedding_type="text"
@ -124,7 +125,8 @@ class Embeddings:
self.vision_embedding = JinaV1ImageEmbedding( self.vision_embedding = JinaV1ImageEmbedding(
model_size=config.semantic_search.model_size, model_size=config.semantic_search.model_size,
requestor=self.requestor, requestor=self.requestor,
device=self.config.semantic_search.device, device=config.semantic_search.device
or ("GPU" if config.semantic_search.model_size == "large" else "CPU"),
) )
def update_stats(self) -> None: def update_stats(self) -> None:
@ -414,4 +416,4 @@ class Embeddings:
finally: finally:
with self.reindex_lock: with self.reindex_lock:
self.reindex_running = False self.reindex_running = False
self.reindex_thread = None self.reindex_thread = None

View File

@ -148,7 +148,12 @@ class ArcfaceEmbedding(BaseEmbedding):
self.runner = ONNXModelRunner( self.runner = ONNXModelRunner(
os.path.join(self.download_path, self.model_file), os.path.join(self.download_path, self.model_file),
device=self.config.face_recognition.device, device=self.config.face_recognition.device
or (
"GPU"
if self.config.face_recognition.model_size == "large"
else "CPU"
),
) )
def _preprocess_inputs(self, raw_inputs): def _preprocess_inputs(self, raw_inputs):
@ -184,4 +189,4 @@ class ArcfaceEmbedding(BaseEmbedding):
frame = np.transpose(frame, (2, 0, 1)) frame = np.transpose(frame, (2, 0, 1))
frame = np.expand_dims(frame, axis=0) frame = np.expand_dims(frame, axis=0)
return [{"data": frame}] return [{"data": frame}]