mirror of
https://github.com/blakeblackshear/frigate.git
synced 2025-08-04 13:47:37 +02:00
Bug Fixing
This commit is contained in:
parent
3cc1382439
commit
16ffabf51f
@ -27,29 +27,15 @@ logger = logging.getLogger(__name__)
|
|||||||
|
|
||||||
DETECTOR_KEY = "hailo8l"
|
DETECTOR_KEY = "hailo8l"
|
||||||
|
|
||||||
def get_device_architecture():
|
|
||||||
"""Get the device architecture from hailortcli."""
|
|
||||||
try:
|
|
||||||
result = subprocess.run(['hailortcli', 'fw-control', 'identify'], capture_output=True, text=True)
|
|
||||||
for line in result.stdout.split('\n'):
|
|
||||||
if "Device Architecture" in line:
|
|
||||||
return line.split(':')[1].strip().lower()
|
|
||||||
except Exception:
|
|
||||||
return "unknown"
|
|
||||||
|
|
||||||
class ModelConfig(BaseModel):
|
class ModelConfig(BaseModel):
|
||||||
path: Optional[str] = Field(default=None, title="Model Path")
|
path: str = Field(default=None, title="Model Path")
|
||||||
type: str = Field(default="yolov8s", title="Model Type")
|
|
||||||
width: int = Field(default=640, title="Model Width")
|
|
||||||
height: int = Field(default=640, title="Model Height")
|
|
||||||
score_threshold: float = Field(default=0.3, title="Score Threshold")
|
|
||||||
max_detections: int = Field(default=30, title="Maximum Detections")
|
|
||||||
input_tensor: str = Field(default="input_tensor", title="Input Tensor Name")
|
|
||||||
input_pixel_format: str = Field(default="RGB", title="Input Pixel Format")
|
|
||||||
|
|
||||||
class HailoDetectorConfig(BaseDetectorConfig):
|
class HailoDetectorConfig(BaseDetectorConfig):
|
||||||
type: Literal[DETECTOR_KEY]
|
type: Literal[DETECTOR_KEY]
|
||||||
device: str = Field(default="PCIe", title="Device Type")
|
device: str = Field(default="PCIe", title="Device Type")
|
||||||
|
url: Optional[str] = Field(default=None, title="Model URL")
|
||||||
|
dir: Optional[str] = Field(default=None, title="Model Directory")
|
||||||
model: ModelConfig
|
model: ModelConfig
|
||||||
|
|
||||||
class HailoAsyncInference:
|
class HailoAsyncInference:
|
||||||
@ -60,8 +46,10 @@ class HailoAsyncInference:
|
|||||||
params = VDevice.create_params()
|
params = VDevice.create_params()
|
||||||
params.scheduling_algorithm = HailoSchedulingAlgorithm.ROUND_ROBIN
|
params.scheduling_algorithm = HailoSchedulingAlgorithm.ROUND_ROBIN
|
||||||
self.target = VDevice(params)
|
self.target = VDevice(params)
|
||||||
self.hef = HEF(self.config.model.path)
|
|
||||||
self.infer_model = self.target.create_infer_model(self.config.model.path)
|
# Initialize HEF
|
||||||
|
self.hef = HEF(self.model_path)
|
||||||
|
self.infer_model = self.target.create_infer_model(self.model_path)
|
||||||
self.infer_model.set_batch_size(1)
|
self.infer_model.set_batch_size(1)
|
||||||
|
|
||||||
def infer(self):
|
def infer(self):
|
||||||
@ -90,46 +78,56 @@ class HailoAsyncInference:
|
|||||||
|
|
||||||
class HailoDetector(DetectionApi):
|
class HailoDetector(DetectionApi):
|
||||||
type_key = DETECTOR_KEY
|
type_key = DETECTOR_KEY
|
||||||
|
DEFAULT_CACHE_DIR = "/config/model_cache/"
|
||||||
|
|
||||||
def __init__(self, config: HailoDetectorConfig):
|
|
||||||
super().__init__()
|
def __init__(self, detector_config: HailoDetectorConfig):
|
||||||
self.async_inference = HailoAsyncInference(config)
|
super().__init__(detector_config)
|
||||||
|
self.config = detector_config
|
||||||
|
|
||||||
|
# Get the model path
|
||||||
|
model_path = self.check_and_prepare_model()
|
||||||
|
|
||||||
|
# Initialize async inference with the correct model path
|
||||||
|
self.async_inference = HailoAsyncInference(detector_config)
|
||||||
|
self.async_inference.config.model.path = model_path
|
||||||
self.worker_thread = threading.Thread(target=self.async_inference.infer)
|
self.worker_thread = threading.Thread(target=self.async_inference.infer)
|
||||||
self.worker_thread.start()
|
self.worker_thread.start()
|
||||||
|
|
||||||
# Determine device architecture
|
|
||||||
self.device_architecture = get_device_architecture()
|
|
||||||
if self.device_architecture not in ["hailo8", "hailo8l"]:
|
|
||||||
raise RuntimeError(f"Unsupported device architecture: {self.device_architecture}")
|
|
||||||
logger.info(f"Device architecture detected: {self.device_architecture}")
|
|
||||||
|
|
||||||
# Ensure the model is available
|
def check_and_prepare_model(self) -> str:
|
||||||
self.cache_dir = "/config/model_cache/h8l_cache"
|
"""
|
||||||
self.expected_model_filename = f"{config.model.type}.hef"
|
Check if model exists at specified path, download from URL if needed.
|
||||||
self.check_and_prepare_model()
|
Returns the final model path to use.
|
||||||
|
"""
|
||||||
def check_and_prepare_model(self):
|
|
||||||
# Ensure cache directory exists
|
# Ensure cache directory exists
|
||||||
if not os.path.exists(self.cache_dir):
|
if not os.path.exists(self.DEFAULT_CACHE_DIR):
|
||||||
os.makedirs(self.cache_dir)
|
os.makedirs(self.DEFAULT_CACHE_DIR)
|
||||||
|
|
||||||
# Check for the expected model file
|
model_path = self.config.dir # the directory path of the model
|
||||||
model_file_path = os.path.join(self.cache_dir, self.expected_model_filename)
|
model_url = self.config.url # the url of the model
|
||||||
self.async_inference.config.model.path = model_file_path
|
|
||||||
|
|
||||||
if not os.path.isfile(model_file_path):
|
if (model_path and os.path.isfile(model_path)):
|
||||||
if self.async_inference.config.model.path:
|
return model_path
|
||||||
logger.info(
|
|
||||||
f"A model file was not found at {model_file_path}, Downloading one from the provided URL."
|
if (model_url):
|
||||||
)
|
model_filename = os.path.basename(model_url)
|
||||||
urllib.request.urlretrieve(self.async_inference.config.model.path, model_file_path)
|
model_file_path = os.path.join(self.DEFAULT_CACHE_DIR, model_filename)
|
||||||
logger.info(f"A model file was downloaded to {model_file_path}.")
|
if os.path.isfile(model_file_path):
|
||||||
|
return model_file_path
|
||||||
else:
|
else:
|
||||||
raise RuntimeError("Model file path is missing and no URL is provided.")
|
logger.info(f"Downloading model from URL: {model_url}")
|
||||||
else:
|
try:
|
||||||
logger.info(
|
urllib.request.urlretrieve(model_url, model_file_path)
|
||||||
f"A model file already exists at {model_file_path} not downloading one."
|
logger.info(f"Model downloaded successfully to: {model_file_path}")
|
||||||
)
|
return model_file_path
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Failed to download model: {str(e)}")
|
||||||
|
raise RuntimeError(f"Failed to download model from {model_url}")
|
||||||
|
raise RuntimeError("No valid model path or URL provided")
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def detect_raw(self, tensor_input):
|
def detect_raw(self, tensor_input):
|
||||||
"""
|
"""
|
||||||
@ -175,7 +173,7 @@ class HailoDetector(DetectionApi):
|
|||||||
def _process_yolo(self, raw_output, version):
|
def _process_yolo(self, raw_output, version):
|
||||||
detections = []
|
detections = []
|
||||||
for detection in raw_output[1]:
|
for detection in raw_output[1]:
|
||||||
confidence = detection[4] if version == "8" else np.max(detection[5:])
|
confidence = detection[4]
|
||||||
if confidence >= self.async_inference.config.model.score_threshold:
|
if confidence >= self.async_inference.config.model.score_threshold:
|
||||||
x, y, w, h = detection[:4]
|
x, y, w, h = detection[:4]
|
||||||
ymin, xmin, ymax, xmax = y - h / 2, x - w / 2, y + h / 2, x + w / 2
|
ymin, xmin, ymax, xmax = y - h / 2, x - w / 2, y + h / 2, x + w / 2
|
||||||
|
Loading…
Reference in New Issue
Block a user