mirror of
https://github.com/blakeblackshear/frigate.git
synced 2026-04-19 23:08:08 +02:00
Improve object classification (#18908)
* Ui improvements * Improve image cropping and model saving * Improve naming * Add logs for training * Improve model labeling * Don't set sub label for none object classification * Cleanup
This commit is contained in:
committed by
Blake Blackshear
parent
ceeb6543f5
commit
13fb7bc260
@@ -187,7 +187,7 @@ class CustomObjectClassificationProcessor(RealTimeProcessorApi):
|
||||
super().__init__(config, metrics)
|
||||
self.model_config = model_config
|
||||
self.model_dir = os.path.join(MODEL_CACHE_DIR, self.model_config.name)
|
||||
self.train_dir = os.path.join(self.model_dir, "train")
|
||||
self.train_dir = os.path.join(CLIPS_DIR, self.model_config.name, "train")
|
||||
self.interpreter: Interpreter = None
|
||||
self.sub_label_publisher = sub_label_publisher
|
||||
self.tensor_input_details: dict[str, Any] = None
|
||||
@@ -232,20 +232,23 @@ class CustomObjectClassificationProcessor(RealTimeProcessorApi):
|
||||
obj_data["box"][1],
|
||||
obj_data["box"][2],
|
||||
obj_data["box"][3],
|
||||
224,
|
||||
max(
|
||||
obj_data["box"][1] - obj_data["box"][0],
|
||||
obj_data["box"][3] - obj_data["box"][2],
|
||||
),
|
||||
1.0,
|
||||
)
|
||||
|
||||
rgb = cv2.cvtColor(frame, cv2.COLOR_YUV2RGB_I420)
|
||||
input = rgb[
|
||||
crop = rgb[
|
||||
y:y2,
|
||||
x:x2,
|
||||
]
|
||||
|
||||
if input.shape != (224, 224):
|
||||
input = cv2.resize(input, (224, 224))
|
||||
if crop.shape != (224, 224):
|
||||
crop = cv2.resize(crop, (224, 224))
|
||||
|
||||
input = np.expand_dims(input, axis=0)
|
||||
input = np.expand_dims(crop, axis=0)
|
||||
self.interpreter.set_tensor(self.tensor_input_details[0]["index"], input)
|
||||
self.interpreter.invoke()
|
||||
res: np.ndarray = self.interpreter.get_tensor(
|
||||
@@ -259,7 +262,7 @@ class CustomObjectClassificationProcessor(RealTimeProcessorApi):
|
||||
|
||||
write_classification_attempt(
|
||||
self.train_dir,
|
||||
cv2.cvtColor(frame, cv2.COLOR_RGB2BGR),
|
||||
cv2.cvtColor(crop, cv2.COLOR_RGB2BGR),
|
||||
now,
|
||||
self.labelmap[best_id],
|
||||
score,
|
||||
@@ -269,12 +272,15 @@ class CustomObjectClassificationProcessor(RealTimeProcessorApi):
|
||||
logger.debug(f"Score {score} is worse than previous score {previous_score}")
|
||||
return
|
||||
|
||||
self.sub_label_publisher.publish(
|
||||
EventMetadataTypeEnum.sub_label,
|
||||
(obj_data["id"], self.labelmap[best_id], score),
|
||||
)
|
||||
sub_label = self.labelmap[best_id]
|
||||
self.detected_objects[obj_data["id"]] = score
|
||||
|
||||
if sub_label != "none":
|
||||
self.sub_label_publisher.publish(
|
||||
EventMetadataTypeEnum.sub_label,
|
||||
(obj_data["id"], sub_label, score),
|
||||
)
|
||||
|
||||
def handle_request(self, topic, request_data):
|
||||
if topic == EmbeddingsRequestEnum.reload_classification_model.value:
|
||||
if request_data.get("model_name") == self.model_config.name:
|
||||
|
||||
Reference in New Issue
Block a user