/*
 * The MIT License (MIT)
 *
 * Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to deal
 * in the Software without restriction, including without limitation the rights
 * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
 * copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in
 * all copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.  IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
 * THE SOFTWARE.
 */

#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include <pybind11/numpy.h>
#include <migraphx/program.hpp>
#include <migraphx/instruction_ref.hpp>
#include <migraphx/operation.hpp>
#include <migraphx/quantization.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/ref/target.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/tf.hpp>
#include <migraphx/onnx.hpp>
#include <migraphx/load_save.hpp>
#include <migraphx/register_target.hpp>
#include <migraphx/json.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/op/common.hpp>

#ifdef HAVE_GPU
#include <migraphx/gpu/hip.hpp>
#endif

using half   = half_float::half;
namespace py = pybind11;

#ifdef __clang__
#define MIGRAPHX_PUSH_UNUSED_WARNING \
    _Pragma("clang diagnostic push") \
        _Pragma("clang diagnostic ignored \"-Wused-but-marked-unused\"")
#define MIGRAPHX_POP_WARNING _Pragma("clang diagnostic pop")
#else
#define MIGRAPHX_PUSH_UNUSED_WARNING
#define MIGRAPHX_POP_WARNING
#endif
#define MIGRAPHX_PYBIND11_MODULE(...) \
    MIGRAPHX_PUSH_UNUSED_WARNING      \
    PYBIND11_MODULE(__VA_ARGS__)      \
    MIGRAPHX_POP_WARNING

#define MIGRAPHX_PYTHON_GENERATE_SHAPE_ENUM(x, t) .value(#x, migraphx::shape::type_t::x)
namespace migraphx {

migraphx::value to_value(py::kwargs kwargs);
migraphx::value to_value(py::list lst);

template <class T, class F>
void visit_py(T x, F f)
{
    if(py::isinstance<py::kwargs>(x))
    {
        f(to_value(x.template cast<py::kwargs>()));
    }
    else if(py::isinstance<py::list>(x))
    {
        f(to_value(x.template cast<py::list>()));
    }
    else if(py::isinstance<py::bool_>(x))
    {
        f(x.template cast<bool>());
    }
    else if(py::isinstance<py::int_>(x) or py::hasattr(x, "__index__"))
    {
        f(x.template cast<int>());
    }
    else if(py::isinstance<py::float_>(x))
    {
        f(x.template cast<float>());
    }
    else if(py::isinstance<py::str>(x))
    {
        f(x.template cast<std::string>());
    }
    else if(py::isinstance<migraphx::shape::dynamic_dimension>(x))
    {
        f(migraphx::to_value(x.template cast<migraphx::shape::dynamic_dimension>()));
    }
    else
    {
        MIGRAPHX_THROW("VISIT_PY: Unsupported data type!");
    }
}

migraphx::value to_value(py::list lst)
{
    migraphx::value v = migraphx::value::array{};
    for(auto val : lst)
    {
        visit_py(val, [&](auto py_val) { v.push_back(py_val); });
    }

    return v;
}

migraphx::value to_value(py::kwargs kwargs)
{
    migraphx::value v = migraphx::value::object{};

    for(auto arg : kwargs)
    {
        auto&& key = py::str(arg.first);
        auto&& val = arg.second;
        visit_py(val, [&](auto py_val) { v[key] = py_val; });
    }
    return v;
}
} // namespace migraphx

namespace pybind11 {
namespace detail {

template <>
struct npy_format_descriptor<half>
{
    static std::string format()
    {
        // following: https://docs.python.org/3/library/struct.html#format-characters
        return "e";
    }
    static constexpr auto name() { return _("half"); }
};

} // namespace detail
} // namespace pybind11

template <class F>
void visit_type(const migraphx::shape& s, F f)
{
    s.visit_type(f);
}

template <class T, class F>
void visit(const migraphx::raw_data<T>& x, F f)
{
    x.visit(f);
}

template <class F>
void visit_types(F f)
{
    migraphx::shape::visit_types(f);
}

template <class T>
py::buffer_info to_buffer_info(T& x)
{
    migraphx::shape s = x.get_shape();
    assert(s.type() != migraphx::shape::tuple_type);
    if(s.dynamic())
        MIGRAPHX_THROW("MIGRAPHX PYTHON: dynamic shape argument passed to to_buffer_info");
    auto strides = s.strides();
    std::transform(
        strides.begin(), strides.end(), strides.begin(), [&](auto i) { return i * s.type_size(); });
    py::buffer_info b;
    visit_type(s, [&](auto as) {
        // migraphx use int8_t data to store bool type, we need to
        // explicitly specify the data type as bool for python
        if(s.type() == migraphx::shape::bool_type)
        {
            b = py::buffer_info(x.data(),
                                as.size(),
                                py::format_descriptor<bool>::format(),
                                s.ndim(),
                                s.lens(),
                                strides);
        }
        else
        {
            b = py::buffer_info(x.data(),
                                as.size(),
                                py::format_descriptor<decltype(as())>::format(),
                                s.ndim(),
                                s.lens(),
                                strides);
        }
    });
    return b;
}

migraphx::shape to_shape(const py::buffer_info& info)
{
    migraphx::shape::type_t t;
    std::size_t n = 0;
    visit_types([&](auto as) {
        if(info.format == py::format_descriptor<decltype(as())>::format() or
           (info.format == "l" and py::format_descriptor<decltype(as())>::format() == "q") or
           (info.format == "L" and py::format_descriptor<decltype(as())>::format() == "Q"))
        {
            t = as.type_enum();
            n = sizeof(as());
        }
        else if(info.format == "?" and py::format_descriptor<decltype(as())>::format() == "b")
        {
            t = migraphx::shape::bool_type;
            n = sizeof(bool);
        }
    });

    if(n == 0)
    {
        MIGRAPHX_THROW("MIGRAPHX PYTHON: Unsupported data type " + info.format);
    }

    auto strides = info.strides;
    std::transform(strides.begin(), strides.end(), strides.begin(), [&](auto i) -> std::size_t {
        return n > 0 ? i / n : 0;
    });

    // scalar support
    if(info.shape.empty())
    {
        return migraphx::shape{t};
    }
    else
    {
        return migraphx::shape{t, info.shape, strides};
    }
}

MIGRAPHX_PYBIND11_MODULE(migraphx, m)
{
    py::class_<migraphx::shape> shape_cls(m, "shape");
    shape_cls
        .def(py::init([](py::kwargs kwargs) {
            auto v = migraphx::to_value(kwargs);
            auto t = migraphx::shape::parse_type(v.get("type", "float"));
            if(v.contains("dyn_dims"))
            {
                auto dyn_dims =
                    migraphx::from_value<std::vector<migraphx::shape::dynamic_dimension>>(
                        v.at("dyn_dims"));
                return migraphx::shape(t, dyn_dims);
            }
            auto lens = v.get<std::size_t>("lens", {1});
            if(v.contains("strides"))
                return migraphx::shape(t, lens, v.at("strides").to_vector<std::size_t>());
            else
                return migraphx::shape(t, lens);
        }))
        .def("type", &migraphx::shape::type)
        .def("lens", &migraphx::shape::lens)
        .def("strides", &migraphx::shape::strides)
        .def("ndim", &migraphx::shape::ndim)
        .def("elements", &migraphx::shape::elements)
        .def("bytes", &migraphx::shape::bytes)
        .def("type_string", &migraphx::shape::type_string)
        .def("type_size", &migraphx::shape::type_size)
        .def("dyn_dims", &migraphx::shape::dyn_dims)
        .def("packed", &migraphx::shape::packed)
        .def("transposed", &migraphx::shape::transposed)
        .def("broadcasted", &migraphx::shape::broadcasted)
        .def("standard", &migraphx::shape::standard)
        .def("scalar", &migraphx::shape::scalar)
        .def("dynamic", &migraphx::shape::dynamic)
        .def("__eq__", std::equal_to<migraphx::shape>{})
        .def("__ne__", std::not_equal_to<migraphx::shape>{})
        .def("__repr__", [](const migraphx::shape& s) { return migraphx::to_string(s); });

    py::enum_<migraphx::shape::type_t>(shape_cls, "type_t")
        MIGRAPHX_SHAPE_VISIT_TYPES(MIGRAPHX_PYTHON_GENERATE_SHAPE_ENUM);

    py::class_<migraphx::shape::dynamic_dimension>(shape_cls, "dynamic_dimension")
        .def(py::init<>())
        .def(py::init<std::size_t, std::size_t>())
        .def(py::init<std::size_t, std::size_t, std::set<std::size_t>>())
        .def_readwrite("min", &migraphx::shape::dynamic_dimension::min)
        .def_readwrite("max", &migraphx::shape::dynamic_dimension::max)
        .def_readwrite("optimals", &migraphx::shape::dynamic_dimension::optimals)
        .def("is_fixed", &migraphx::shape::dynamic_dimension::is_fixed);

    py::class_<migraphx::argument>(m, "argument", py::buffer_protocol())
        .def_buffer([](migraphx::argument& x) -> py::buffer_info { return to_buffer_info(x); })
        .def(py::init([](py::buffer b) {
            py::buffer_info info = b.request();
            return migraphx::argument(to_shape(info), info.ptr);
        }))
        .def("get_shape", &migraphx::argument::get_shape)
        .def("data_ptr",
             [](migraphx::argument& x) { return reinterpret_cast<std::uintptr_t>(x.data()); })
        .def("tolist",
             [](migraphx::argument& x) {
                 py::list l{x.get_shape().elements()};
                 visit(x, [&](auto data) { l = py::cast(data.to_vector()); });
                 return l;
             })
        .def("__eq__", std::equal_to<migraphx::argument>{})
        .def("__ne__", std::not_equal_to<migraphx::argument>{})
        .def("__repr__", [](const migraphx::argument& x) { return migraphx::to_string(x); });

    py::class_<migraphx::target>(m, "target");

    py::class_<migraphx::instruction_ref>(m, "instruction_ref")
        .def("shape", [](migraphx::instruction_ref i) { return i->get_shape(); })
        .def("op", [](migraphx::instruction_ref i) { return i->get_operator(); });

    py::class_<migraphx::module, std::unique_ptr<migraphx::module, py::nodelete>>(m, "module")
        .def("print", [](const migraphx::module& mm) { std::cout << mm << std::endl; })
        .def(
            "add_instruction",
            [](migraphx::module& mm,
               const migraphx::operation& op,
               std::vector<migraphx::instruction_ref>& args,
               std::vector<migraphx::module*>& mod_args) {
                return mm.add_instruction(op, args, mod_args);
            },
            py::arg("op"),
            py::arg("args"),
            py::arg("mod_args") = std::vector<migraphx::module*>{})
        .def(
            "add_literal",
            [](migraphx::module& mm, py::buffer data) {
                py::buffer_info info = data.request();
                auto literal_shape   = to_shape(info);
                return mm.add_literal(literal_shape, reinterpret_cast<char*>(info.ptr));
            },
            py::arg("data"))
        .def(
            "add_parameter",
            [](migraphx::module& mm, const std::string& name, const migraphx::shape shape) {
                return mm.add_parameter(name, shape);
            },
            py::arg("name"),
            py::arg("shape"))
        .def(
            "add_return",
            [](migraphx::module& mm, std::vector<migraphx::instruction_ref>& args) {
                return mm.add_return(args);
            },
            py::arg("args"))
        .def("__repr__", [](const migraphx::module& mm) { return migraphx::to_string(mm); });

    py::class_<migraphx::program>(m, "program")
        .def(py::init([]() { return migraphx::program(); }))
        .def("get_parameter_names", &migraphx::program::get_parameter_names)
        .def("get_parameter_shapes", &migraphx::program::get_parameter_shapes)
        .def("get_output_shapes", &migraphx::program::get_output_shapes)
        .def("is_compiled", &migraphx::program::is_compiled)
        .def(
            "compile",
            [](migraphx::program& p,
               const migraphx::target& t,
               bool offload_copy,
               bool fast_math,
               bool exhaustive_tune) {
                migraphx::compile_options options;
                options.offload_copy    = offload_copy;
                options.fast_math       = fast_math;
                options.exhaustive_tune = exhaustive_tune;
                p.compile(t, options);
            },
            py::arg("t"),
            py::arg("offload_copy")    = true,
            py::arg("fast_math")       = true,
            py::arg("exhaustive_tune") = false)
        .def("get_main_module", [](const migraphx::program& p) { return p.get_main_module(); })
        .def(
            "create_module",
            [](migraphx::program& p, const std::string& name) { return p.create_module(name); },
            py::arg("name"))
        .def("run",
             [](migraphx::program& p, py::dict params) {
                 migraphx::parameter_map pm;
                 for(auto x : params)
                 {
                     std::string key      = x.first.cast<std::string>();
                     py::buffer b         = x.second.cast<py::buffer>();
                     py::buffer_info info = b.request();
                     pm[key]              = migraphx::argument(to_shape(info), info.ptr);
                 }
                 return p.eval(pm);
             })
        .def("run_async",
             [](migraphx::program& p,
                py::dict params,
                std::uintptr_t stream,
                std::string stream_name) {
                 migraphx::parameter_map pm;
                 for(auto x : params)
                 {
                     std::string key      = x.first.cast<std::string>();
                     py::buffer b         = x.second.cast<py::buffer>();
                     py::buffer_info info = b.request();
                     pm[key]              = migraphx::argument(to_shape(info), info.ptr);
                 }
                 migraphx::execution_environment exec_env{
                     migraphx::any_ptr(reinterpret_cast<void*>(stream), stream_name), true};
                 return p.eval(pm, exec_env);
             })
        .def("sort", &migraphx::program::sort)
        .def("print", [](const migraphx::program& p) { std::cout << p << std::endl; })
        .def("__eq__", std::equal_to<migraphx::program>{})
        .def("__ne__", std::not_equal_to<migraphx::program>{})
        .def("__repr__", [](const migraphx::program& p) { return migraphx::to_string(p); });

    py::class_<migraphx::operation> op(m, "op");
    op.def(py::init([](const std::string& name, py::kwargs kwargs) {
          migraphx::value v = migraphx::value::object{};
          if(kwargs)
          {
              v = migraphx::to_value(kwargs);
          }
          return migraphx::make_op(name, v);
      }))
        .def("name", &migraphx::operation::name);

    py::enum_<migraphx::op::pooling_mode>(op, "pooling_mode")
        .value("average", migraphx::op::pooling_mode::average)
        .value("max", migraphx::op::pooling_mode::max)
        .value("lpnorm", migraphx::op::pooling_mode::lpnorm);

    py::enum_<migraphx::op::rnn_direction>(op, "rnn_direction")
        .value("forward", migraphx::op::rnn_direction::forward)
        .value("reverse", migraphx::op::rnn_direction::reverse)
        .value("bidirectional", migraphx::op::rnn_direction::bidirectional);

    m.def(
        "argument_from_pointer",
        [](const migraphx::shape shape, const int64_t address) {
            return migraphx::argument(shape, reinterpret_cast<void*>(address));
        },
        py::arg("shape"),
        py::arg("address"));

    m.def(
        "parse_tf",
        [](const std::string& filename,
           bool is_nhwc,
           unsigned int batch_size,
           std::unordered_map<std::string, std::vector<std::size_t>> map_input_dims,
           std::vector<std::string> output_names) {
            return migraphx::parse_tf(
                filename, migraphx::tf_options{is_nhwc, batch_size, map_input_dims, output_names});
        },
        "Parse tf protobuf (default format is nhwc)",
        py::arg("filename"),
        py::arg("is_nhwc")        = true,
        py::arg("batch_size")     = 1,
        py::arg("map_input_dims") = std::unordered_map<std::string, std::vector<std::size_t>>(),
        py::arg("output_names")   = std::vector<std::string>());

    m.def(
        "parse_onnx",
        [](const std::string& filename,
           unsigned int default_dim_value,
           migraphx::shape::dynamic_dimension default_dyn_dim_value,
           std::unordered_map<std::string, std::vector<std::size_t>> map_input_dims,
           std::unordered_map<std::string, std::vector<migraphx::shape::dynamic_dimension>>
               map_dyn_input_dims,
           bool skip_unknown_operators,
           bool print_program_on_error,
           int64_t max_loop_iterations) {
            migraphx::onnx_options options;
            options.default_dim_value      = default_dim_value;
            options.default_dyn_dim_value  = default_dyn_dim_value;
            options.map_input_dims         = map_input_dims;
            options.map_dyn_input_dims     = map_dyn_input_dims;
            options.skip_unknown_operators = skip_unknown_operators;
            options.print_program_on_error = print_program_on_error;
            options.max_loop_iterations    = max_loop_iterations;
            return migraphx::parse_onnx(filename, options);
        },
        "Parse onnx file",
        py::arg("filename"),
        py::arg("default_dim_value")     = 0,
        py::arg("default_dyn_dim_value") = migraphx::shape::dynamic_dimension{1, 1},
        py::arg("map_input_dims") = std::unordered_map<std::string, std::vector<std::size_t>>(),
        py::arg("map_dyn_input_dims") =
            std::unordered_map<std::string, std::vector<migraphx::shape::dynamic_dimension>>(),
        py::arg("skip_unknown_operators") = false,
        py::arg("print_program_on_error") = false,
        py::arg("max_loop_iterations")    = 10);

    m.def(
        "parse_onnx_buffer",
        [](const std::string& onnx_buffer,
           unsigned int default_dim_value,
           migraphx::shape::dynamic_dimension default_dyn_dim_value,
           std::unordered_map<std::string, std::vector<std::size_t>> map_input_dims,
           std::unordered_map<std::string, std::vector<migraphx::shape::dynamic_dimension>>
               map_dyn_input_dims,
           bool skip_unknown_operators,
           bool print_program_on_error) {
            migraphx::onnx_options options;
            options.default_dim_value      = default_dim_value;
            options.default_dyn_dim_value  = default_dyn_dim_value;
            options.map_input_dims         = map_input_dims;
            options.map_dyn_input_dims     = map_dyn_input_dims;
            options.skip_unknown_operators = skip_unknown_operators;
            options.print_program_on_error = print_program_on_error;
            return migraphx::parse_onnx_buffer(onnx_buffer, options);
        },
        "Parse onnx file",
        py::arg("filename"),
        py::arg("default_dim_value")     = 0,
        py::arg("default_dyn_dim_value") = migraphx::shape::dynamic_dimension{1, 1},
        py::arg("map_input_dims") = std::unordered_map<std::string, std::vector<std::size_t>>(),
        py::arg("map_dyn_input_dims") =
            std::unordered_map<std::string, std::vector<migraphx::shape::dynamic_dimension>>(),
        py::arg("skip_unknown_operators") = false,
        py::arg("print_program_on_error") = false);

    m.def(
        "load",
        [](const std::string& name, const std::string& format) {
            migraphx::file_options options;
            options.format = format;
            return migraphx::load(name, options);
        },
        "Load MIGraphX program",
        py::arg("filename"),
        py::arg("format") = "msgpack");

    m.def(
        "save",
        [](const migraphx::program& p, const std::string& name, const std::string& format) {
            migraphx::file_options options;
            options.format = format;
            return migraphx::save(p, name, options);
        },
        "Save MIGraphX program",
        py::arg("p"),
        py::arg("filename"),
        py::arg("format") = "msgpack");

    m.def("get_target", &migraphx::make_target);
    m.def("create_argument", [](const migraphx::shape& s, const std::vector<double>& values) {
        if(values.size() != s.elements())
            MIGRAPHX_THROW("Values and shape elements do not match");
        migraphx::argument a{s};
        a.fill(values.begin(), values.end());
        return a;
    });
    m.def("generate_argument", &migraphx::generate_argument, py::arg("s"), py::arg("seed") = 0);
    m.def("fill_argument", &migraphx::fill_argument, py::arg("s"), py::arg("value"));
    m.def("quantize_fp16",
          &migraphx::quantize_fp16,
          py::arg("prog"),
          py::arg("ins_names") = std::vector<std::string>{"all"});
    m.def("quantize_int8",
          &migraphx::quantize_int8,
          py::arg("prog"),
          py::arg("t"),
          py::arg("calibration") = std::vector<migraphx::parameter_map>{},
          py::arg("ins_names")   = std::vector<std::string>{"dot", "convolution"});

#ifdef HAVE_GPU
    m.def("allocate_gpu", &migraphx::gpu::allocate_gpu, py::arg("s"), py::arg("host") = false);
    m.def("to_gpu", &migraphx::gpu::to_gpu, py::arg("arg"), py::arg("host") = false);
    m.def("from_gpu", &migraphx::gpu::from_gpu);
    m.def("gpu_sync", [] { migraphx::gpu::gpu_sync(); });
#endif

#ifdef VERSION_INFO
    m.attr("__version__") = VERSION_INFO;
#else
    m.attr("__version__") = "dev";
#endif
}