import logging
import os
import threading
import time
from pathlib import Path
from typing import Callable, List

import requests

from frigate.comms.inter_process import InterProcessRequestor
from frigate.const import UPDATE_MODEL_STATE
from frigate.types import ModelStatusTypesEnum

logger = logging.getLogger(__name__)


class FileLock:
    """Cross-process mutual exclusion backed by a ``<path>.lock`` file.

    ``acquire`` repeatedly tries to create the lock file in exclusive-create
    mode (``open(..., "x")`` fails with ``FileExistsError`` if it already
    exists) and ``release`` deletes it. The lock may also be used as a context
    manager: ``with FileLock(path): ...``.
    """

    def __init__(self, path):
        """Prepare a lock guarding ``path``; the lock file is ``f"{path}.lock"``.

        Any existing lock file is removed (best-effort) so that a file left
        behind by an earlier run cannot make ``acquire`` block forever.

        NOTE(review): this also deletes a lock that another live process may
        currently hold, so two processes constructing ``FileLock`` for the same
        path at the same time can both end up inside the critical section.
        """
        self.path = path
        self.lock_file = f"{path}.lock"

        # We have not acquired the lock yet, so a lock file here is treated as
        # stale. Failure to remove it is not fatal: acquire() will just wait.
        try:
            os.remove(self.lock_file)
        except OSError:
            pass

    def acquire(self, timeout=None, poll_interval=0.1):
        """Block until the lock file has been created by this call.

        Args:
            timeout: Maximum number of seconds to wait, or ``None`` (default)
                to wait indefinitely.
            poll_interval: Seconds to sleep between attempts (default 0.1).

        Raises:
            TimeoutError: If ``timeout`` elapses before the lock is obtained.
        """
        parent_dir = os.path.dirname(self.lock_file)
        # dirname() returns "" for a bare filename, and os.makedirs("") raises,
        # so only create the directory when there is one to create.
        if parent_dir:
            os.makedirs(parent_dir, exist_ok=True)

        deadline = None if timeout is None else time.monotonic() + timeout
        while True:
            try:
                # Mode "x" is an atomic create-if-absent: exactly one caller wins.
                with open(self.lock_file, "x"):
                    return
            except FileExistsError:
                if deadline is not None and time.monotonic() >= deadline:
                    raise TimeoutError(
                        f"Timed out waiting for lock file {self.lock_file}"
                    )
                time.sleep(poll_interval)

    def release(self):
        """Delete the lock file; a no-op if it is already gone."""
        try:
            os.remove(self.lock_file)
        except FileNotFoundError:
            pass

    def __enter__(self):
        """Acquire the lock (waiting indefinitely) and return ``self``."""
        self.acquire()
        return self

    def __exit__(self, exc_type, exc, tb):
        """Release the lock; exceptions from the ``with`` body propagate."""
        self.release()
        return False


class ModelDownloader:
    """Fetch a model's files in a background thread and publish per-file state.

    Each file is published as ``f"{model_name}-{file_name}"`` via
    ``UPDATE_MODEL_STATE`` messages on an ``InterProcessRequestor``.
    """

    def __init__(
        self,
        model_name: str,
        download_path: str,
        file_names: List[str],
        download_func: Callable[[str], None],
        silent: bool = False,
    ):
        """Store download configuration; no downloading happens here.

        Args:
            model_name: Prefix used in the published state key for each file.
            download_path: Directory in which each file in ``file_names`` lives.
            file_names: File names (relative to ``download_path``) to ensure.
            download_func: Called with the destination path of each file that
                does not exist yet; expected to create the file at that path.
            silent: Stored on the instance; not read by this class itself.
        """
        self.model_name = model_name
        self.download_path = download_path
        self.file_names = file_names
        self.download_func = download_func
        self.silent = silent
        self.requestor = InterProcessRequestor()
        self.download_thread = None
        # Set by the download thread once it has finished (successfully or not).
        self.download_complete = threading.Event()

    def ensure_model_files(self):
        """Mark all files as downloading and start the download thread.

        Returns immediately; use ``wait_for_download`` to block until done.
        """
        self.mark_files_state(
            self.requestor,
            self.model_name,
            self.file_names,
            ModelStatusTypesEnum.downloading,
        )
        self.download_thread = threading.Thread(
            target=self._download_models,
            name=f"_download_model_{self.model_name}",
            daemon=True,
        )
        self.download_thread.start()

    def _download_models(self):
        """Thread body: download each missing file, then mark it downloaded.

        The requestor is stopped and ``download_complete`` is set even if a
        download raises, so callers of ``wait_for_download`` are never left
        blocked forever. The exception itself still propagates out of the
        thread.
        """
        try:
            for file_name in self.file_names:
                path = os.path.join(self.download_path, file_name)
                lock = FileLock(path)

                if not os.path.exists(path):
                    lock.acquire()
                    try:
                        # Re-check under the lock: another holder may have
                        # finished downloading this file while we waited.
                        if not os.path.exists(path):
                            self.download_func(path)
                    finally:
                        lock.release()

                self.requestor.send_data(
                    UPDATE_MODEL_STATE,
                    {
                        "model": f"{self.model_name}-{file_name}",
                        "state": ModelStatusTypesEnum.downloaded,
                    },
                )
        finally:
            self.requestor.stop()
            self.download_complete.set()

    @staticmethod
    def download_from_url(url: str, save_path: str, silent: bool = False):
        """Stream ``url`` into ``save_path`` via a sibling ``.part`` file.

        The body is written to ``<save_path>.part`` and moved onto
        ``save_path`` only after the transfer completes, so ``save_path`` never
        holds a truncated file. On failure the error is logged, the partial
        ``.part`` file is removed (best-effort) and the exception is re-raised.

        NOTE(review): ``requests.get`` is called without a timeout, so a
        stalled server can block this call indefinitely.

        Args:
            url: Source URL (redirects are followed).
            save_path: Final destination path; parent directories are created.
            silent: When True, skip the start/finish info log lines.

        Raises:
            Exception: Whatever the HTTP request, ``raise_for_status`` or file
                I/O raised.
        """
        temporary_filename = Path(save_path).with_name(
            os.path.basename(save_path) + ".part"
        )
        temporary_filename.parent.mkdir(parents=True, exist_ok=True)

        if not silent:
            logger.info(f"Downloading model file from: {url}")

        try:
            with requests.get(url, stream=True, allow_redirects=True) as r:
                r.raise_for_status()
                with open(temporary_filename, "wb") as f:
                    for chunk in r.iter_content(chunk_size=8192):
                        f.write(chunk)

            # replace() overwrites an existing target on every platform,
            # whereas rename() fails on Windows if save_path already exists.
            temporary_filename.replace(save_path)
        except Exception as e:
            logger.error(f"Error downloading model: {str(e)}")
            # Don't leave a truncated ".part" file behind; cleanup failures
            # must not mask the original error.
            try:
                temporary_filename.unlink()
            except OSError:
                pass
            raise

        if not silent:
            logger.info(f"Downloading complete: {url}")

    @staticmethod
    def mark_files_state(
        requestor: InterProcessRequestor,
        model_name: str,
        files: list[str],
        state: ModelStatusTypesEnum,
    ) -> None:
        """Send one ``UPDATE_MODEL_STATE`` message with ``state`` for each file."""
        for file_name in files:
            requestor.send_data(
                UPDATE_MODEL_STATE,
                {
                    "model": f"{model_name}-{file_name}",
                    "state": state,
                },
            )

    def wait_for_download(self, timeout=None):
        """Block until the download thread has finished.

        Args:
            timeout: Maximum seconds to wait, or ``None`` (default) to wait
                indefinitely.

        Returns:
            True if the download thread finished, False if ``timeout`` elapsed.
        """
        return self.download_complete.wait(timeout)