mirror of
https://github.com/blakeblackshear/frigate.git
synced 2025-03-09 00:16:54 +01:00
Fix
This commit is contained in:
parent
16ffabf51f
commit
11177292de
@ -22,6 +22,7 @@ from frigate.detectors.detector_config import BaseDetectorConfig
|
||||
from pydantic import BaseModel, Field
|
||||
from typing_extensions import Literal
|
||||
from typing import Optional
|
||||
from functools import partial
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@ -48,8 +49,8 @@ class HailoAsyncInference:
|
||||
self.target = VDevice(params)
|
||||
|
||||
# Initialize HEF
|
||||
self.hef = HEF(self.model_path)
|
||||
self.infer_model = self.target.create_infer_model(self.model_path)
|
||||
self.hef = HEF(self.config.model.path)
|
||||
self.infer_model = self.target.create_infer_model(self.config.model.path)
|
||||
self.infer_model.set_batch_size(1)
|
||||
|
||||
def infer(self):
|
||||
@ -58,13 +59,25 @@ class HailoAsyncInference:
|
||||
if batch_data is None:
|
||||
break
|
||||
|
||||
bindings = []
|
||||
for frame in batch_data:
|
||||
binding = self.infer_model.create_bindings()
|
||||
binding.input().set_buffer(frame)
|
||||
bindings.append(binding)
|
||||
with self.infer_model.configure() as configured_model:
|
||||
bindings_list = []
|
||||
for frame in batch_data:
|
||||
# Create empty output buffers
|
||||
output_buffers = {
|
||||
output_info.name: np.empty(
|
||||
self.infer_model.output(output_info.name).shape,
|
||||
dtype=np.float32
|
||||
)
|
||||
for output_info in self.hef.get_output_vstream_infos()
|
||||
}
|
||||
# Create bindings using the configured model
|
||||
binding = configured_model.create_bindings(output_buffers=output_buffers)
|
||||
binding.input().set_buffer(frame)
|
||||
bindings_list.append(binding)
|
||||
|
||||
# Run async inference on the configured model
|
||||
configured_model.run_async(bindings_list, partial(self._callback, batch_data=batch_data))
|
||||
|
||||
self.infer_model.run_async(bindings, self._callback, batch_data)
|
||||
|
||||
def _callback(self, completion_info, bindings_list, batch_data):
|
||||
if completion_info.exception:
|
||||
@ -87,10 +100,11 @@ class HailoDetector(DetectionApi):
|
||||
|
||||
# Get the model path
|
||||
model_path = self.check_and_prepare_model()
|
||||
self.config.model.path = model_path
|
||||
print(self.config.model.path)
|
||||
|
||||
# Initialize async inference with the correct model path
|
||||
self.async_inference = HailoAsyncInference(detector_config)
|
||||
self.async_inference.config.model.path = model_path
|
||||
self.worker_thread = threading.Thread(target=self.async_inference.infer)
|
||||
self.worker_thread.start()
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user