* Fix environment vars reading

* fix yaml returning none

* Assume rocm model is onnx despite file extension
This commit is contained in:
Nicolas Mowen 2024-10-29 14:34:07 -06:00 committed by GitHub
parent 73da3d9b20
commit 357ce0382e
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 7 additions and 5 deletions

View File

@ -23,7 +23,7 @@ EnvString = Annotated[str, AfterValidator(validate_env_string)]
def validate_env_vars(v: dict[str, str], info: ValidationInfo) -> dict[str, str]: def validate_env_vars(v: dict[str, str], info: ValidationInfo) -> dict[str, str]:
if isinstance(info.context, dict) and info.context.get("install", False): if isinstance(info.context, dict) and info.context.get("install", False):
for k, v in v: for k, v in v.items():
os.environ[k] = v os.environ[k] = v
return v return v

View File

@ -98,9 +98,7 @@ class ROCmDetector(DetectionApi):
else: else:
logger.info(f"AMD/ROCm: loading model from {path}") logger.info(f"AMD/ROCm: loading model from {path}")
if path.endswith(".onnx"): if (
self.model = migraphx.parse_onnx(path)
elif (
path.endswith(".tf") path.endswith(".tf")
or path.endswith(".tf2") or path.endswith(".tf2")
or path.endswith(".tflite") or path.endswith(".tflite")
@ -108,7 +106,7 @@ class ROCmDetector(DetectionApi):
# untested # untested
self.model = migraphx.parse_tf(path) self.model = migraphx.parse_tf(path)
else: else:
raise Exception(f"AMD/ROCm: unknown model format {path}") self.model = migraphx.parse_onnx(path)
logger.info("AMD/ROCm: compiling the model") logger.info("AMD/ROCm: compiling the model")

View File

@ -29,6 +29,10 @@ def migrate_frigate_config(config_file: str):
with open(config_file, "r") as f: with open(config_file, "r") as f:
config: dict[str, dict[str, any]] = yaml.load(f) config: dict[str, dict[str, any]] = yaml.load(f)
if config is None:
logger.error(f"Failed to load config at {config_file}")
return
previous_version = str(config.get("version", "0.13")) previous_version = str(config.get("version", "0.13"))
if previous_version == CURRENT_CONFIG_VERSION: if previous_version == CURRENT_CONFIG_VERSION: