mirror of
https://github.com/blakeblackshear/frigate.git
synced 2026-04-19 23:08:08 +02:00
Implement Wizard for Creating Classification Models (#20622)
* Implement extraction of images for classification state models * Add object classification dataset preparation * Add first step wizard * Update i18n * Add state classification image selection step * Improve box handling * Add object selector * Improve object cropping implementation * Fix state classification selection * Finalize training and image selection step * Cleanup * Design optimizations * Cleanup mobile styling * Update no models screen * Cleanups and fixes * Fix bugs * Improve model training and creation process * Cleanup * Dynamically add metrics for new model * Add loading when hitting continue * Improve image selection mechanism * Remove unused translation keys * Adjust wording * Add retry button for image generation * Make no models view more specific * Adjust plus icon * Adjust form label * Start with correct type selected * Cleanup sizing and more font colors * Small tweaks * Add tips and more info * Cleanup dialog sizing * Add cursor rule for frontend * Cleanup * remove underline * Lazy loading
This commit is contained in:
@@ -2,12 +2,15 @@
|
||||
|
||||
import logging
|
||||
import os
|
||||
import random
|
||||
from collections import defaultdict
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
|
||||
from frigate.comms.embeddings_updater import EmbeddingsRequestEnum, EmbeddingsRequestor
|
||||
from frigate.comms.inter_process import InterProcessRequestor
|
||||
from frigate.config import FfmpegConfig
|
||||
from frigate.const import (
|
||||
CLIPS_DIR,
|
||||
MODEL_CACHE_DIR,
|
||||
@@ -15,7 +18,10 @@ from frigate.const import (
|
||||
UPDATE_MODEL_STATE,
|
||||
)
|
||||
from frigate.log import redirect_output_to_logger
|
||||
from frigate.models import Event, Recordings, ReviewSegment
|
||||
from frigate.types import ModelStatusTypesEnum
|
||||
from frigate.util.image import get_image_from_recording
|
||||
from frigate.util.path import get_event_thumbnail_bytes
|
||||
from frigate.util.process import FrigateProcess
|
||||
|
||||
BATCH_SIZE = 16
|
||||
@@ -69,6 +75,7 @@ class ClassificationTrainingProcess(FrigateProcess):
|
||||
logger.info(f"Kicking off classification training for {self.model_name}.")
|
||||
dataset_dir = os.path.join(CLIPS_DIR, self.model_name, "dataset")
|
||||
model_dir = os.path.join(MODEL_CACHE_DIR, self.model_name)
|
||||
os.makedirs(model_dir, exist_ok=True)
|
||||
num_classes = len(
|
||||
[
|
||||
d
|
||||
@@ -139,7 +146,6 @@ class ClassificationTrainingProcess(FrigateProcess):
|
||||
f.write(tflite_model)
|
||||
|
||||
|
||||
@staticmethod
|
||||
def kickoff_model_training(
|
||||
embeddingRequestor: EmbeddingsRequestor, model_name: str
|
||||
) -> None:
|
||||
@@ -172,3 +178,520 @@ def kickoff_model_training(
|
||||
},
|
||||
)
|
||||
requestor.stop()
|
||||
|
||||
|
||||
@staticmethod
|
||||
def collect_state_classification_examples(
|
||||
model_name: str, cameras: dict[str, tuple[float, float, float, float]]
|
||||
) -> None:
|
||||
"""
|
||||
Collect representative state classification examples from review items.
|
||||
|
||||
This function:
|
||||
1. Queries review items from specified cameras
|
||||
2. Selects 100 balanced timestamps across the data
|
||||
3. Extracts keyframes from recordings (cropped to specified regions)
|
||||
4. Selects 20 most visually distinct images
|
||||
5. Saves them to the dataset directory
|
||||
|
||||
Args:
|
||||
model_name: Name of the classification model
|
||||
cameras: Dict mapping camera names to normalized crop coordinates [x1, y1, x2, y2] (0-1)
|
||||
"""
|
||||
dataset_dir = os.path.join(CLIPS_DIR, model_name, "dataset")
|
||||
temp_dir = os.path.join(dataset_dir, "temp")
|
||||
os.makedirs(temp_dir, exist_ok=True)
|
||||
|
||||
# Step 1: Get review items for the cameras
|
||||
camera_names = list(cameras.keys())
|
||||
review_items = list(
|
||||
ReviewSegment.select()
|
||||
.where(ReviewSegment.camera.in_(camera_names))
|
||||
.where(ReviewSegment.end_time.is_null(False))
|
||||
.order_by(ReviewSegment.start_time.asc())
|
||||
)
|
||||
|
||||
if not review_items:
|
||||
logger.warning(f"No review items found for cameras: {camera_names}")
|
||||
return
|
||||
|
||||
# Step 2: Create balanced timestamp selection (100 samples)
|
||||
timestamps = _select_balanced_timestamps(review_items, target_count=100)
|
||||
|
||||
# Step 3: Extract keyframes from recordings with crops applied
|
||||
keyframes = _extract_keyframes(
|
||||
"/usr/lib/ffmpeg/7.0/bin/ffmpeg", timestamps, temp_dir, cameras
|
||||
)
|
||||
|
||||
# Step 4: Select 24 most visually distinct images (they're already cropped)
|
||||
distinct_images = _select_distinct_images(keyframes, target_count=24)
|
||||
|
||||
# Step 5: Save to train directory for later classification
|
||||
train_dir = os.path.join(CLIPS_DIR, model_name, "train")
|
||||
os.makedirs(train_dir, exist_ok=True)
|
||||
|
||||
saved_count = 0
|
||||
for idx, image_path in enumerate(distinct_images):
|
||||
dest_path = os.path.join(train_dir, f"example_{idx:03d}.jpg")
|
||||
try:
|
||||
img = cv2.imread(image_path)
|
||||
|
||||
if img is not None:
|
||||
cv2.imwrite(dest_path, img)
|
||||
saved_count += 1
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to save image {image_path}: {e}")
|
||||
|
||||
import shutil
|
||||
|
||||
try:
|
||||
shutil.rmtree(temp_dir)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to clean up temp directory: {e}")
|
||||
|
||||
|
||||
def _select_balanced_timestamps(
|
||||
review_items: list[ReviewSegment], target_count: int = 100
|
||||
) -> list[dict]:
|
||||
"""
|
||||
Select balanced timestamps from review items.
|
||||
|
||||
Strategy:
|
||||
- Group review items by camera and time of day
|
||||
- Sample evenly across groups to ensure diversity
|
||||
- For each selected review item, pick a random timestamp within its duration
|
||||
|
||||
Returns:
|
||||
List of dicts with keys: camera, timestamp, review_item
|
||||
"""
|
||||
# Group by camera and hour of day for temporal diversity
|
||||
grouped = defaultdict(list)
|
||||
|
||||
for item in review_items:
|
||||
camera = item.camera
|
||||
# Group by 6-hour blocks for temporal diversity
|
||||
hour_block = int(item.start_time // (6 * 3600))
|
||||
key = f"{camera}_{hour_block}"
|
||||
grouped[key].append(item)
|
||||
|
||||
# Calculate how many samples per group
|
||||
num_groups = len(grouped)
|
||||
if num_groups == 0:
|
||||
return []
|
||||
|
||||
samples_per_group = max(1, target_count // num_groups)
|
||||
timestamps = []
|
||||
|
||||
# Sample from each group
|
||||
for group_items in grouped.values():
|
||||
# Take samples_per_group items from this group
|
||||
sample_size = min(samples_per_group, len(group_items))
|
||||
sampled_items = random.sample(group_items, sample_size)
|
||||
|
||||
for item in sampled_items:
|
||||
# Pick a random timestamp within the review item's duration
|
||||
duration = item.end_time - item.start_time
|
||||
if duration <= 0:
|
||||
continue
|
||||
|
||||
# Sample from middle 80% to avoid edge artifacts
|
||||
offset = random.uniform(duration * 0.1, duration * 0.9)
|
||||
timestamp = item.start_time + offset
|
||||
|
||||
timestamps.append(
|
||||
{
|
||||
"camera": item.camera,
|
||||
"timestamp": timestamp,
|
||||
"review_item": item,
|
||||
}
|
||||
)
|
||||
|
||||
# If we don't have enough, sample more from larger groups
|
||||
while len(timestamps) < target_count and len(timestamps) < len(review_items):
|
||||
for group_items in grouped.values():
|
||||
if len(timestamps) >= target_count:
|
||||
break
|
||||
|
||||
# Pick a random item not already sampled
|
||||
item = random.choice(group_items)
|
||||
duration = item.end_time - item.start_time
|
||||
if duration <= 0:
|
||||
continue
|
||||
|
||||
offset = random.uniform(duration * 0.1, duration * 0.9)
|
||||
timestamp = item.start_time + offset
|
||||
|
||||
# Check if we already have a timestamp near this one
|
||||
if not any(abs(t["timestamp"] - timestamp) < 1.0 for t in timestamps):
|
||||
timestamps.append(
|
||||
{
|
||||
"camera": item.camera,
|
||||
"timestamp": timestamp,
|
||||
"review_item": item,
|
||||
}
|
||||
)
|
||||
|
||||
return timestamps[:target_count]
|
||||
|
||||
|
||||
def _extract_keyframes(
|
||||
ffmpeg_path: str,
|
||||
timestamps: list[dict],
|
||||
output_dir: str,
|
||||
camera_crops: dict[str, tuple[float, float, float, float]],
|
||||
) -> list[str]:
|
||||
"""
|
||||
Extract keyframes from recordings at specified timestamps and crop to specified regions.
|
||||
|
||||
Args:
|
||||
ffmpeg_path: Path to ffmpeg binary
|
||||
timestamps: List of timestamp dicts from _select_balanced_timestamps
|
||||
output_dir: Directory to save extracted frames
|
||||
camera_crops: Dict mapping camera names to normalized crop coordinates [x1, y1, x2, y2] (0-1)
|
||||
|
||||
Returns:
|
||||
List of paths to successfully extracted and cropped keyframe images
|
||||
"""
|
||||
keyframe_paths = []
|
||||
|
||||
for idx, ts_info in enumerate(timestamps):
|
||||
camera = ts_info["camera"]
|
||||
timestamp = ts_info["timestamp"]
|
||||
|
||||
if camera not in camera_crops:
|
||||
logger.warning(f"No crop coordinates for camera {camera}")
|
||||
continue
|
||||
|
||||
norm_x1, norm_y1, norm_x2, norm_y2 = camera_crops[camera]
|
||||
|
||||
try:
|
||||
recording = (
|
||||
Recordings.select()
|
||||
.where(
|
||||
(timestamp >= Recordings.start_time)
|
||||
& (timestamp <= Recordings.end_time)
|
||||
& (Recordings.camera == camera)
|
||||
)
|
||||
.order_by(Recordings.start_time.desc())
|
||||
.limit(1)
|
||||
.get()
|
||||
)
|
||||
except Exception:
|
||||
continue
|
||||
|
||||
relative_time = timestamp - recording.start_time
|
||||
|
||||
try:
|
||||
config = FfmpegConfig(path="/usr/lib/ffmpeg/7.0")
|
||||
image_data = get_image_from_recording(
|
||||
config,
|
||||
recording.path,
|
||||
relative_time,
|
||||
codec="mjpeg",
|
||||
height=None,
|
||||
)
|
||||
|
||||
if image_data:
|
||||
nparr = np.frombuffer(image_data, np.uint8)
|
||||
img = cv2.imdecode(nparr, cv2.IMREAD_COLOR)
|
||||
|
||||
if img is not None:
|
||||
height, width = img.shape[:2]
|
||||
|
||||
x1 = int(norm_x1 * width)
|
||||
y1 = int(norm_y1 * height)
|
||||
x2 = int(norm_x2 * width)
|
||||
y2 = int(norm_y2 * height)
|
||||
|
||||
x1_clipped = max(0, min(x1, width))
|
||||
y1_clipped = max(0, min(y1, height))
|
||||
x2_clipped = max(0, min(x2, width))
|
||||
y2_clipped = max(0, min(y2, height))
|
||||
|
||||
if x2_clipped > x1_clipped and y2_clipped > y1_clipped:
|
||||
cropped = img[y1_clipped:y2_clipped, x1_clipped:x2_clipped]
|
||||
resized = cv2.resize(cropped, (224, 224))
|
||||
|
||||
output_path = os.path.join(output_dir, f"frame_{idx:04d}.jpg")
|
||||
cv2.imwrite(output_path, resized)
|
||||
keyframe_paths.append(output_path)
|
||||
|
||||
except Exception as e:
|
||||
logger.debug(
|
||||
f"Failed to extract frame from {recording.path} at {relative_time}s: {e}"
|
||||
)
|
||||
continue
|
||||
|
||||
return keyframe_paths
|
||||
|
||||
|
||||
def _select_distinct_images(
|
||||
image_paths: list[str], target_count: int = 20
|
||||
) -> list[str]:
|
||||
"""
|
||||
Select the most visually distinct images from a set of keyframes.
|
||||
|
||||
Uses a greedy algorithm based on image histograms:
|
||||
1. Start with a random image
|
||||
2. Iteratively add the image that is most different from already selected images
|
||||
3. Difference is measured using histogram comparison
|
||||
|
||||
Args:
|
||||
image_paths: List of paths to candidate images
|
||||
target_count: Number of distinct images to select
|
||||
|
||||
Returns:
|
||||
List of paths to selected images
|
||||
"""
|
||||
if len(image_paths) <= target_count:
|
||||
return image_paths
|
||||
|
||||
histograms = {}
|
||||
valid_paths = []
|
||||
|
||||
for path in image_paths:
|
||||
try:
|
||||
img = cv2.imread(path)
|
||||
|
||||
if img is None:
|
||||
continue
|
||||
|
||||
hsv = cv2.cvtColor(img, cv2.COLOR_BGR2HSV)
|
||||
hist = cv2.calcHist(
|
||||
[hsv], [0, 1, 2], None, [8, 8, 8], [0, 180, 0, 256, 0, 256]
|
||||
)
|
||||
hist = cv2.normalize(hist, hist).flatten()
|
||||
histograms[path] = hist
|
||||
valid_paths.append(path)
|
||||
except Exception as e:
|
||||
logger.debug(f"Failed to process image {path}: {e}")
|
||||
continue
|
||||
|
||||
if len(valid_paths) <= target_count:
|
||||
return valid_paths
|
||||
|
||||
selected = []
|
||||
first_image = random.choice(valid_paths)
|
||||
selected.append(first_image)
|
||||
remaining = [p for p in valid_paths if p != first_image]
|
||||
|
||||
while len(selected) < target_count and remaining:
|
||||
max_min_distance = -1
|
||||
best_candidate = None
|
||||
|
||||
for candidate in remaining:
|
||||
min_distance = float("inf")
|
||||
|
||||
for selected_img in selected:
|
||||
distance = cv2.compareHist(
|
||||
histograms[candidate],
|
||||
histograms[selected_img],
|
||||
cv2.HISTCMP_BHATTACHARYYA,
|
||||
)
|
||||
min_distance = min(min_distance, distance)
|
||||
|
||||
if min_distance > max_min_distance:
|
||||
max_min_distance = min_distance
|
||||
best_candidate = candidate
|
||||
|
||||
if best_candidate:
|
||||
selected.append(best_candidate)
|
||||
remaining.remove(best_candidate)
|
||||
else:
|
||||
break
|
||||
|
||||
return selected
|
||||
|
||||
|
||||
@staticmethod
|
||||
def collect_object_classification_examples(
|
||||
model_name: str,
|
||||
label: str,
|
||||
) -> None:
|
||||
"""
|
||||
Collect representative object classification examples from event thumbnails.
|
||||
|
||||
This function:
|
||||
1. Queries events for the specified label
|
||||
2. Selects 100 balanced events across different cameras and times
|
||||
3. Retrieves thumbnails for selected events (with 33% center crop applied)
|
||||
4. Selects 24 most visually distinct thumbnails
|
||||
5. Saves to dataset directory
|
||||
|
||||
Args:
|
||||
model_name: Name of the classification model
|
||||
label: Object label to collect (e.g., "person", "car")
|
||||
cameras: List of camera names to collect examples from
|
||||
"""
|
||||
dataset_dir = os.path.join(CLIPS_DIR, model_name, "dataset")
|
||||
temp_dir = os.path.join(dataset_dir, "temp")
|
||||
os.makedirs(temp_dir, exist_ok=True)
|
||||
|
||||
# Step 1: Query events for the specified label and cameras
|
||||
events = list(
|
||||
Event.select().where((Event.label == label)).order_by(Event.start_time.asc())
|
||||
)
|
||||
|
||||
if not events:
|
||||
logger.warning(f"No events found for label '{label}'")
|
||||
return
|
||||
|
||||
logger.debug(f"Found {len(events)} events")
|
||||
|
||||
# Step 2: Select balanced events (100 samples)
|
||||
selected_events = _select_balanced_events(events, target_count=100)
|
||||
logger.debug(f"Selected {len(selected_events)} events")
|
||||
|
||||
# Step 3: Extract thumbnails from events
|
||||
thumbnails = _extract_event_thumbnails(selected_events, temp_dir)
|
||||
logger.debug(f"Successfully extracted {len(thumbnails)} thumbnails")
|
||||
|
||||
# Step 4: Select 24 most visually distinct thumbnails
|
||||
distinct_images = _select_distinct_images(thumbnails, target_count=24)
|
||||
logger.debug(f"Selected {len(distinct_images)} distinct images")
|
||||
|
||||
# Step 5: Save to train directory for later classification
|
||||
train_dir = os.path.join(CLIPS_DIR, model_name, "train")
|
||||
os.makedirs(train_dir, exist_ok=True)
|
||||
|
||||
saved_count = 0
|
||||
for idx, image_path in enumerate(distinct_images):
|
||||
dest_path = os.path.join(train_dir, f"example_{idx:03d}.jpg")
|
||||
try:
|
||||
img = cv2.imread(image_path)
|
||||
|
||||
if img is not None:
|
||||
cv2.imwrite(dest_path, img)
|
||||
saved_count += 1
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to save image {image_path}: {e}")
|
||||
|
||||
import shutil
|
||||
|
||||
try:
|
||||
shutil.rmtree(temp_dir)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to clean up temp directory: {e}")
|
||||
|
||||
logger.debug(
|
||||
f"Successfully collected {saved_count} classification examples in {train_dir}"
|
||||
)
|
||||
|
||||
|
||||
def _select_balanced_events(
|
||||
events: list[Event], target_count: int = 100
|
||||
) -> list[Event]:
|
||||
"""
|
||||
Select balanced events from the event list.
|
||||
|
||||
Strategy:
|
||||
- Group events by camera and time of day
|
||||
- Sample evenly across groups to ensure diversity
|
||||
- Prioritize events with higher scores
|
||||
|
||||
Returns:
|
||||
List of selected events
|
||||
"""
|
||||
grouped = defaultdict(list)
|
||||
|
||||
for event in events:
|
||||
camera = event.camera
|
||||
hour_block = int(event.start_time // (6 * 3600))
|
||||
key = f"{camera}_{hour_block}"
|
||||
grouped[key].append(event)
|
||||
|
||||
num_groups = len(grouped)
|
||||
if num_groups == 0:
|
||||
return []
|
||||
|
||||
samples_per_group = max(1, target_count // num_groups)
|
||||
selected = []
|
||||
|
||||
for group_events in grouped.values():
|
||||
sorted_events = sorted(
|
||||
group_events,
|
||||
key=lambda e: e.data.get("score", 0) if e.data else 0,
|
||||
reverse=True,
|
||||
)
|
||||
|
||||
sample_size = min(samples_per_group, len(sorted_events))
|
||||
selected.extend(sorted_events[:sample_size])
|
||||
|
||||
if len(selected) < target_count:
|
||||
remaining = [e for e in events if e not in selected]
|
||||
remaining_sorted = sorted(
|
||||
remaining,
|
||||
key=lambda e: e.data.get("score", 0) if e.data else 0,
|
||||
reverse=True,
|
||||
)
|
||||
needed = target_count - len(selected)
|
||||
selected.extend(remaining_sorted[:needed])
|
||||
|
||||
return selected[:target_count]
|
||||
|
||||
|
||||
def _extract_event_thumbnails(events: list[Event], output_dir: str) -> list[str]:
|
||||
"""
|
||||
Extract thumbnails from events and save to disk.
|
||||
|
||||
Args:
|
||||
events: List of Event objects
|
||||
output_dir: Directory to save thumbnails
|
||||
|
||||
Returns:
|
||||
List of paths to successfully extracted thumbnail images
|
||||
"""
|
||||
thumbnail_paths = []
|
||||
|
||||
for idx, event in enumerate(events):
|
||||
try:
|
||||
thumbnail_bytes = get_event_thumbnail_bytes(event)
|
||||
|
||||
if thumbnail_bytes:
|
||||
nparr = np.frombuffer(thumbnail_bytes, np.uint8)
|
||||
img = cv2.imdecode(nparr, cv2.IMREAD_COLOR)
|
||||
|
||||
if img is not None:
|
||||
height, width = img.shape[:2]
|
||||
|
||||
crop_size = 1.0
|
||||
if event.data and "box" in event.data and "region" in event.data:
|
||||
box = event.data["box"]
|
||||
region = event.data["region"]
|
||||
|
||||
if len(box) == 4 and len(region) == 4:
|
||||
box_w, box_h = box[2], box[3]
|
||||
region_w, region_h = region[2], region[3]
|
||||
|
||||
box_area = (box_w * box_h) / (region_w * region_h)
|
||||
|
||||
if box_area < 0.05:
|
||||
crop_size = 0.4
|
||||
elif box_area < 0.10:
|
||||
crop_size = 0.5
|
||||
elif box_area < 0.20:
|
||||
crop_size = 0.65
|
||||
elif box_area < 0.35:
|
||||
crop_size = 0.80
|
||||
else:
|
||||
crop_size = 0.95
|
||||
|
||||
crop_width = int(width * crop_size)
|
||||
crop_height = int(height * crop_size)
|
||||
|
||||
x1 = (width - crop_width) // 2
|
||||
y1 = (height - crop_height) // 2
|
||||
x2 = x1 + crop_width
|
||||
y2 = y1 + crop_height
|
||||
|
||||
cropped = img[y1:y2, x1:x2]
|
||||
resized = cv2.resize(cropped, (224, 224))
|
||||
output_path = os.path.join(output_dir, f"thumbnail_{idx:04d}.jpg")
|
||||
cv2.imwrite(output_path, resized)
|
||||
thumbnail_paths.append(output_path)
|
||||
|
||||
except Exception as e:
|
||||
logger.debug(f"Failed to extract thumbnail for event {event.id}: {e}")
|
||||
continue
|
||||
|
||||
return thumbnail_paths
|
||||
|
||||
Reference in New Issue
Block a user