This commit is contained in:
OmriAx 2025-02-19 19:16:27 +02:00
parent 16ffabf51f
commit 11177292de

View File

@ -22,6 +22,7 @@ from frigate.detectors.detector_config import BaseDetectorConfig
from pydantic import BaseModel, Field
from typing_extensions import Literal
from typing import Optional
from functools import partial
logger = logging.getLogger(__name__)
@ -48,8 +49,8 @@ class HailoAsyncInference:
self.target = VDevice(params)
# Initialize HEF
self.hef = HEF(self.model_path)
self.infer_model = self.target.create_infer_model(self.model_path)
self.hef = HEF(self.config.model.path)
self.infer_model = self.target.create_infer_model(self.config.model.path)
self.infer_model.set_batch_size(1)
def infer(self):
@ -58,13 +59,25 @@ class HailoAsyncInference:
if batch_data is None:
break
bindings = []
for frame in batch_data:
binding = self.infer_model.create_bindings()
binding.input().set_buffer(frame)
bindings.append(binding)
with self.infer_model.configure() as configured_model:
bindings_list = []
for frame in batch_data:
# Create empty output buffers
output_buffers = {
output_info.name: np.empty(
self.infer_model.output(output_info.name).shape,
dtype=np.float32
)
for output_info in self.hef.get_output_vstream_infos()
}
# Create bindings using the configured model
binding = configured_model.create_bindings(output_buffers=output_buffers)
binding.input().set_buffer(frame)
bindings_list.append(binding)
# Run async inference on the configured model
configured_model.run_async(bindings_list, partial(self._callback, batch_data=batch_data))
self.infer_model.run_async(bindings, self._callback, batch_data)
def _callback(self, completion_info, bindings_list, batch_data):
if completion_info.exception:
@ -87,10 +100,11 @@ class HailoDetector(DetectionApi):
# Get the model path
model_path = self.check_and_prepare_model()
self.config.model.path = model_path
print(self.config.model.path)
# Initialize async inference with the correct model path
self.async_inference = HailoAsyncInference(detector_config)
self.async_inference.config.model.path = model_path
self.worker_thread = threading.Thread(target=self.async_inference.infer)
self.worker_thread.start()