Classification improvements (#19020)

* Move classification training to full process

* Sort class images
This commit is contained in:
Nicolas Mowen 2025-07-07 07:36:06 -06:00 committed by GitHub
parent 0f4cac736a
commit 2b4a773f9b
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 104 additions and 93 deletions

View File

@ -20,7 +20,19 @@ LEARNING_RATE = 0.001
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
def __generate_representative_dataset_factory(dataset_dir: str): class ClassificationTrainingProcess(FrigateProcess):
def __init__(self, model_name: str) -> None:
super().__init__(
stop_event=None,
name=f"model_training:{model_name}",
)
self.model_name = model_name
def run(self) -> None:
self.pre_run_setup()
self.__train_classification_model()
def __generate_representative_dataset_factory(self, dataset_dir: str):
def generate_representative_dataset(): def generate_representative_dataset():
image_paths = [] image_paths = []
for root, dirs, files in os.walk(dataset_dir): for root, dirs, files in os.walk(dataset_dir):
@ -38,9 +50,8 @@ def __generate_representative_dataset_factory(dataset_dir: str):
return generate_representative_dataset return generate_representative_dataset
@redirect_output_to_logger(logger, logging.DEBUG)
@redirect_output_to_logger(logger, logging.DEBUG) def __train_classification_model(self) -> bool:
def __train_classification_model(model_name: str) -> bool:
"""Train a classification model.""" """Train a classification model."""
# import in the function so that tensorflow is not initialized multiple times # import in the function so that tensorflow is not initialized multiple times
@ -49,9 +60,9 @@ def __train_classification_model(model_name: str) -> bool:
from tensorflow.keras.applications import MobileNetV2 from tensorflow.keras.applications import MobileNetV2
from tensorflow.keras.preprocessing.image import ImageDataGenerator from tensorflow.keras.preprocessing.image import ImageDataGenerator
logger.info(f"Kicking off classification training for {model_name}.") logger.info(f"Kicking off classification training for {self.model_name}.")
dataset_dir = os.path.join(CLIPS_DIR, model_name, "dataset") dataset_dir = os.path.join(CLIPS_DIR, self.model_name, "dataset")
model_dir = os.path.join(MODEL_CACHE_DIR, model_name) model_dir = os.path.join(MODEL_CACHE_DIR, self.model_name)
num_classes = len( num_classes = len(
[ [
d d
@ -109,8 +120,8 @@ def __train_classification_model(model_name: str) -> bool:
# convert model to tflite # convert model to tflite
converter = tf.lite.TFLiteConverter.from_keras_model(model) converter = tf.lite.TFLiteConverter.from_keras_model(model)
converter.optimizations = [tf.lite.Optimize.DEFAULT] converter.optimizations = [tf.lite.Optimize.DEFAULT]
converter.representative_dataset = __generate_representative_dataset_factory( converter.representative_dataset = (
dataset_dir self.__generate_representative_dataset_factory(dataset_dir)
) )
converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8] converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
converter.inference_input_type = tf.uint8 converter.inference_input_type = tf.uint8
@ -138,12 +149,7 @@ def kickoff_model_training(
# run training in sub process so that # run training in sub process so that
# tensorflow will free CPU / GPU memory # tensorflow will free CPU / GPU memory
# upon training completion # upon training completion
training_process = FrigateProcess( training_process = ClassificationTrainingProcess(model_name)
None,
target=__train_classification_model,
name=f"model_training:{model_name}",
args=(model_name,),
)
training_process.start() training_process.start()
training_process.join() training_process.join()

View File

@ -577,9 +577,14 @@ function DatasetGrid({
}: DatasetGridProps) { }: DatasetGridProps) {
const { t } = useTranslation(["views/classificationModel"]); const { t } = useTranslation(["views/classificationModel"]);
const classData = useMemo(
() => images.sort((a, b) => a.localeCompare(b)),
[images],
);
return ( return (
<div className="flex flex-wrap gap-2 overflow-y-auto p-2"> <div className="flex flex-wrap gap-2 overflow-y-auto p-2">
{images.map((image) => ( {classData.map((image) => (
<div <div
className={cn( className={cn(
"flex w-60 cursor-pointer flex-col gap-2 rounded-lg bg-card outline outline-[3px]", "flex w-60 cursor-pointer flex-col gap-2 rounded-lg bg-card outline outline-[3px]",