Implement API to train classification models (#18475)

This commit is contained in:
Nicolas Mowen
2025-05-29 17:51:32 -06:00
committed by Blake Blackshear
parent 6dc36fcbb4
commit 2c7b71b16e
8 changed files with 219 additions and 19 deletions

View File

@@ -7,7 +7,7 @@ import shutil
from typing import Any
import cv2
from fastapi import APIRouter, Depends, Request, UploadFile
from fastapi import APIRouter, BackgroundTasks, Depends, Request, UploadFile
from fastapi.responses import JSONResponse
from pathvalidate import sanitize_filename
from peewee import DoesNotExist
@@ -19,10 +19,12 @@ from frigate.api.defs.request.classification_body import (
RenameFaceBody,
)
from frigate.api.defs.tags import Tags
from frigate.config import FrigateConfig
from frigate.config.camera import DetectConfig
from frigate.const import FACE_DIR
from frigate.const import FACE_DIR, MODEL_CACHE_DIR
from frigate.embeddings import EmbeddingsContext
from frigate.models import Event
from frigate.util.classification import train_classification_model
from frigate.util.path import get_event_snapshot
logger = logging.getLogger(__name__)
@@ -442,3 +444,32 @@ def transcribe_audio(request: Request, body: AudioTranscriptionBody):
},
status_code=500,
)
# custom classification training
@router.post("/classification/{name}/train")
async def train_configured_model(
request: Request, name: str, background_tasks: BackgroundTasks
):
config: FrigateConfig = request.app.frigate_config
if name not in config.classification.custom:
return JSONResponse(
content=(
{
"success": False,
"message": f"{name} is not a known classification model.",
}
),
status_code=404,
)
background_tasks.add_task(
train_classification_model, os.path.join(MODEL_CACHE_DIR, name)
)
return JSONResponse(
content={"success": True, "message": "Started classification model training."},
status_code=200,
)