/*******************************************************************************
* Copyright (C) 2024 Intel Corporation
*
* This software and the related documents are Intel copyrighted  materials,  and
* your use of  them is  governed by the  express license  under which  they were
* provided to you (License).  Unless the License provides otherwise, you may not
* use, modify, copy, publish, distribute,  disclose or transmit this software or
* the related documents without Intel's prior written permission.
*
* This software and the related documents  are provided as  is,  with no express
* or implied  warranties,  other  than those  that are  expressly stated  in the
* License.
*******************************************************************************/

/*
*
*  Content:
*       This example demonstrates use of oneAPI Math Kernel Library (oneMKL)
*       DPCPP sycl::buffer API oneapi::mkl::sparse::omatadd to perform general
*       sparse matrix-sparse matrix addition on a SYCL device (CPU, GPU).
*       This example uses matrices in CSR format.
*
*           C = alpha * op(A) + beta * op(B)
*
*       where op() is defined by one of
*           oneapi::mkl::transpose::{nontrans,trans,conjtrans}
*
*       The supported floating point data types for omatadd matrix data are:
*           float
*           double
*           std::complex<float>
*           std::complex<double>
*
*       The supported matrix format for omatadd is:
*           CSR + CSR = CSR
*
*******************************************************************************/

// stl includes
#include <algorithm>
#include <cstdint>
#include <cstdlib>
#include <cstring>
#include <iostream>
#include <iterator>
#include <limits>
#include <list>
#include <memory>
#include <vector>

#include "mkl.h"
#include "oneapi/mkl.hpp"
#include <sycl/sycl.hpp>

// local includes
#include "common_for_examples.hpp"
#include "./include/common_for_sparse_examples.hpp"

//
// Main example for sparse matrix-sparse matrix addition:
//
//   C = alpha * op(A) + beta * op(B)
//
// A and B are built on the host in CSR format. C is then produced in CSR
// format through the staged omatadd workflow
// (buffer_size -> analyze -> get_nnz -> omatadd), sorted, and its first two
// rows are printed.
//
// Template parameters:
//   dataType -- value type of the matrix entries
//   intType  -- index type of the CSR row pointer / column index arrays
//
// Returns 0 on success, or 1 if a SYCL or std exception was caught.
//
template <typename dataType, typename intType>
int run_sparse_blas_example(sycl::queue &q)
{
    bool good = true;

    // handles for the sparse matrices
    oneapi::mkl::sparse::matrix_handle_t csrA = nullptr;
    oneapi::mkl::sparse::matrix_handle_t csrB = nullptr;
    oneapi::mkl::sparse::matrix_handle_t csrC = nullptr;

    // Workspace and descriptor for omatadd. The workspace buffer is owned by a
    // unique_ptr so it cannot leak on any path. It is still reset() explicitly
    // so the (blocking) buffer destruction happens at a deliberate point.
    std::unique_ptr<sycl::buffer<std::uint8_t, 1>> p_tempWorkspace;
    oneapi::mkl::sparse::omatadd_descr_t descr = nullptr;

    try {

        // Initialize data for Sparse Matrix - Sparse Matrix Addition
        auto opA = oneapi::mkl::transpose::nontrans;
        auto opB = oneapi::mkl::transpose::nontrans;

        auto a_index = oneapi::mkl::index_base::zero;
        auto b_index = oneapi::mkl::index_base::zero;
        auto c_index = oneapi::mkl::index_base::zero;

        dataType alpha = set_fp_value(dataType(2.0), dataType(0.0));
        dataType beta  = set_fp_value(dataType(1.0), dataType(0.0));

        //
        // set up dimensions of matrix addition
        //
        intType size = 4;

        // trans and conjtrans both swap dimensions, so compare "is transposed"
        // flags rather than the raw enum values when deriving shapes.
        const bool a_is_trans = (opA != oneapi::mkl::transpose::nontrans);
        const bool b_is_trans = (opB != oneapi::mkl::transpose::nontrans);

        const std::int64_t a_nrows = size * size * size;
        const std::int64_t a_ncols = a_nrows;
              std::int64_t a_nnz   = 27 * a_nrows; // upper bound; reset once A is generated
        // op(B) must have the same shape as op(A), so B is stored transposed
        // relative to A exactly when one op transposes and the other does not.
        const std::int64_t b_nrows = (a_is_trans == b_is_trans) ? a_nrows : a_ncols;
        const std::int64_t b_ncols = (a_is_trans == b_is_trans) ? a_ncols : a_nrows;
              std::int64_t b_nnz   = 0; // b_nnz is unknown at this point; reset once random matrix is generated
        const std::int64_t c_nrows = a_is_trans ? a_ncols : a_nrows;
        const std::int64_t c_ncols = a_is_trans ? a_nrows : a_ncols;
        // c_nnz is unknown at this point

        //
        // setup A data locally in CSR format
        //
        std::vector<intType, mkl_allocator<intType, 64>> ia;
        std::vector<intType, mkl_allocator<intType, 64>> ja;
        std::vector<dataType, mkl_allocator<dataType, 64>> a;

        ia.resize(a_nrows + 1);
        ja.resize(27 * a_nrows);
        a.resize(27 * a_nrows);

        intType a_ind = a_index == oneapi::mkl::index_base::zero ? 0 : 1;
        generate_sparse_matrix<dataType, intType>(size, ia, ja, a, a_ind);
        a_nnz = ia[a_nrows] - a_ind;

        //
        // setup B data locally in CSR format
        //
        std::vector<intType, mkl_allocator<intType, 64>> ib;
        std::vector<intType, mkl_allocator<intType, 64>> jb;
        std::vector<dataType, mkl_allocator<dataType, 64>> b;

        intType b_ind = b_index == oneapi::mkl::index_base::zero ? 0 : 1;
        generate_random_sparse_matrix<dataType, intType>(b_nrows, b_ncols, 0.1 /*density_val*/, ib, jb, b, b_ind);
        b_nnz = ib[b_nrows] - b_ind;

        //
        // Execute Matrix Addition
        //

        std::cout << "\n\t\tsparse::omatadd parameters:\n";
        std::cout << "\t\t\topA = " << opA << std::endl;
        std::cout << "\t\t\topB = " << opB << std::endl;

        std::cout << "\t\t\tA_nrows = A_ncols = " << a_nrows << std::endl;
        std::cout << "\t\t\tB_nrows = B_ncols = " << b_nrows << std::endl;
        std::cout << "\t\t\tC_nrows = C_ncols = " << c_nrows << std::endl;

        std::cout << "\t\t\tA_index = " << a_index << std::endl;
        std::cout << "\t\t\tB_index = " << b_index << std::endl;
        std::cout << "\t\t\tC_index = " << c_index << std::endl;

        sycl::buffer<intType, 1> a_rowptr(ia.data(), ia.data() + a_nrows + 1);
        sycl::buffer<intType, 1> a_colind(ja.data(), ja.data() + a_nnz);
        sycl::buffer<dataType, 1> a_val(a.data(), a.data() + a_nnz);
        sycl::buffer<intType, 1> b_rowptr(ib.data(), ib.data() + b_nrows + 1);
        sycl::buffer<intType, 1> b_colind(jb.data(), jb.data() + b_nnz);
        sycl::buffer<dataType, 1> b_val(b.data(), b.data() + b_nnz);

        oneapi::mkl::sparse::init_matrix_handle(&csrA);
        oneapi::mkl::sparse::init_matrix_handle(&csrB);
        oneapi::mkl::sparse::init_matrix_handle(&csrC);

        oneapi::mkl::sparse::set_csr_data(q, csrA, a_nrows, a_ncols, a_nnz, a_index, a_rowptr, a_colind, a_val);
        oneapi::mkl::sparse::set_csr_data(q, csrB, b_nrows, b_ncols, b_nnz, b_index, b_rowptr, b_colind, b_val);

        // Step 1.1: Attach 0-sized placeholder arrays to the C handle; the real
        // arrays are set in step 4.1 once the non-zero count of C is known.
        sycl::buffer<intType, 1> c_rowptr_dummy(0);
        sycl::buffer<intType, 1> c_colind_dummy(0);
        sycl::buffer<dataType, 1> c_vals_dummy(0);
        oneapi::mkl::sparse::set_csr_data(q, csrC, c_nrows, c_ncols, 0, c_index,
                                          c_rowptr_dummy, c_colind_dummy, c_vals_dummy);

        // Step 1.2: Initialize the omatadd descriptor and set algorithm
        oneapi::mkl::sparse::omatadd_alg alg = oneapi::mkl::sparse::omatadd_alg::default_alg;
        oneapi::mkl::sparse::init_omatadd_descr(q, &descr);

        // Step 2.1: Query for size of temporary workspace
        std::int64_t sizeTempWorkspace = 0;
        oneapi::mkl::sparse::omatadd_buffer_size(q, opA, opB, csrA, csrB, csrC, alg, descr, sizeTempWorkspace);

        // Step 2.2: Allocate temporary workspace. Allocation failure surfaces as
        // an exception (e.g. std::bad_alloc) handled by the catch clauses below;
        // operator new never returns nullptr, so no null check is needed.
        p_tempWorkspace = std::make_unique<sycl::buffer<std::uint8_t, 1>>(sizeTempWorkspace);

        // Step 3.1: Analyze sparsity patterns of A and B to count non-zeros in C
        oneapi::mkl::sparse::omatadd_analyze(q, opA, opB, csrA, csrB, csrC, alg, descr, p_tempWorkspace.get());

        // Step 3.2: Get non-zero count of output C matrix
        std::int64_t c_nnz = 0;
        oneapi::mkl::sparse::omatadd_get_nnz(q, opA, opB, csrA, csrB, csrC, alg, descr, c_nnz);

        // Step 4.1: Allocate final c matrix arrays
        sycl::buffer<intType, 1> c_rowptr(c_nrows + 1);
        sycl::buffer<intType, 1> c_colind(c_nnz);
        sycl::buffer<dataType, 1> c_vals(c_nnz);
        oneapi::mkl::sparse::set_csr_data(q, csrC, c_nrows, c_ncols, c_nnz, c_index, c_rowptr, c_colind, c_vals);

        // Step 4.2: Finalize by performing matrix addition into C matrix handle
        oneapi::mkl::sparse::omatadd(q, opA, opB, alpha, csrA, beta, csrB, csrC, alg, descr);

        // Step 4.3: Release omatadd descriptor and temporary workspace
        oneapi::mkl::sparse::release_omatadd_descr(q, descr); descr = nullptr;
        // Destroying a buffer is a blocking action: it waits for all kernels
        // using accessors on the workspace before proceeding. This can be
        // moved later if more asynchronous parallelism is desired.
        p_tempWorkspace.reset();

        // Step 5 (Optional): Sort C matrix output if desired
        oneapi::mkl::sparse::sort_matrix(q, csrC);

        // Clean up. The handles are nulled explicitly so the backup cleanup
        // below cannot release them a second time if a later step throws.
        oneapi::mkl::sparse::release_matrix_handle(q, &csrA); csrA = nullptr;
        oneapi::mkl::sparse::release_matrix_handle(q, &csrB); csrB = nullptr;
        oneapi::mkl::sparse::release_matrix_handle(q, &csrC); csrC = nullptr;

        //
        // Post Processing
        //

        // Print part of C matrix solution
        {
            auto ic       = c_rowptr.get_host_access(sycl::read_only);
            auto jc       = c_colind.get_host_access(sycl::read_only);
            auto c        = c_vals.get_host_access(sycl::read_only);
            intType c_ind = c_index == oneapi::mkl::index_base::zero ? 0 : 1;
            std::cout << "C matrix [first two rows]:" << std::endl;
            // Row numbers are shifted into the c_index base. Column indices are
            // printed as stored, since jc already holds c_index-based values.
            for (intType row = 0; row < std::min(static_cast<std::int64_t>(2), c_nrows); ++row) {
                for (intType j = ic[row] - c_ind; j < ic[row + 1] - c_ind; ++j) {
                    intType col = jc[j];
                    dataType val  = c[j];
                    std::cout << "C(" << row + c_ind << ", " << col << ") = " << val << std::endl;
                }
            }
        }

        q.wait_and_throw();
    }
    catch (sycl::exception const &e) {
        std::cout << "\t\tCaught synchronous SYCL exception:\n" << e.what() << std::endl;
        good = false;
    }
    catch (std::exception const &e) {
        std::cout << "\t\tCaught std exception:\n" << e.what() << std::endl;
        good = false;
    }

    q.wait();

    // Backup cleanup of the descriptor, workspace and matrix handles, in case
    // an exception interrupted the normal path above.
    if (descr) { oneapi::mkl::sparse::release_omatadd_descr(q, descr); descr = nullptr; }
    p_tempWorkspace.reset();
    if (csrA) oneapi::mkl::sparse::release_matrix_handle(q, &csrA);
    if (csrB) oneapi::mkl::sparse::release_matrix_handle(q, &csrB);
    if (csrC) oneapi::mkl::sparse::release_matrix_handle(q, &csrC);

    q.wait();

    return good ? 0 : 1;
}

//
// Description of example setup, apis used and supported floating point type
// precisions
//
void print_example_banner()
{

    std::cout << "" << std::endl;
    std::cout << "###############################################################"
                 "#########"
              << std::endl;
    std::cout << "# Sparse Matrix-Sparse Matrix Addition Example: " << std::endl;
    std::cout << "# " << std::endl;
    std::cout << "#    C = alpha * op(A) + beta * op(B)" << std::endl;
    std::cout << "# " << std::endl;
    std::cout << "# where A and B are sparse matrices in CSR format, and C is the\n"
                 "# output of sparse matrix addition in CSR format"
              << std::endl;
    std::cout << "# " << std::endl;
    std::cout << "# Using apis:" << std::endl;
    std::cout << "#   sparse::omatadd_buffer_size" << std::endl;
    std::cout << "#   sparse::omatadd_analyze" << std::endl;
    std::cout << "#   sparse::omatadd_get_nnz" << std::endl;
    std::cout << "#   sparse::omatadd" << std::endl;
    std::cout << "#   sparse::init_omatadd_descr" << std::endl;
    std::cout << "#   sparse::release_omatadd_descr" << std::endl;
    std::cout << "# " << std::endl;
    std::cout << "#   sparse::init_matrix_handle" << std::endl;
    std::cout << "#   sparse::set_csr_data" << std::endl;
    std::cout << "#   sparse::release_matrix_handle" << std::endl;
    std::cout << "# " << std::endl;
    std::cout << "# Supported floating point type precisions:" << std::endl;
    std::cout << "#   float" << std::endl;
    std::cout << "#   double" << std::endl;
    std::cout << "#   std::complex<float>" << std::endl;
    std::cout << "#   std::complex<double>" << std::endl;
    std::cout << "# " << std::endl;
    std::cout << "###############################################################"
                 "#########"
              << std::endl;
    std::cout << std::endl;
}

//
// Main entry point for example.
//
// Dispatches to appropriate device types as set at build time with flag:
// -DSYCL_DEVICES_cpu -- only runs SYCL CPU implementation
// -DSYCL_DEVICES_gpu -- only runs SYCL GPU implementation
// -DSYCL_DEVICES_all (default) -- runs on all: cpu and gpu devices
//
//  For each device selected and each supported data type,
//  run_sparse_blas_example() is run with all supported data types,
//  if any fail, we move on to the next device.
//

int main(int argc, char **argv)
{
    print_example_banner();

    std::list<my_sycl_device_types> list_of_devices;
    set_list_of_devices(list_of_devices);

    int status = 0;
    for (auto it = list_of_devices.begin(); it != list_of_devices.end(); ++it) {
        try {
            sycl::device my_dev;
            bool my_dev_is_found = false;
            get_sycl_device(my_dev, my_dev_is_found, *it);

            if (my_dev_is_found) {
                std::cout << "Running tests on " << sycl_device_names[*it] << ".\n";

                // Catch asynchronous exceptions
                auto exception_handler = [](sycl::exception_list exceptions) {
                    for (std::exception_ptr const &e : exceptions) {
                        try {
                            std::rethrow_exception(e);
                        }
                        catch (sycl::exception const &e) {
                            std::cout << "Caught asynchronous SYCL exception: \n"
                                << e.what() << std::endl;
                        }
                    }
                };

                sycl::queue q(my_dev, exception_handler);

                std::cout << "\tRunning with single precision real data type:" << std::endl;
                status |= run_sparse_blas_example<float, std::int32_t>(q);

                if (my_dev.get_info<sycl::info::device::double_fp_config>().size() != 0) {
                    std::cout << "\tRunning with double precision real data type:" << std::endl;
                    status |= run_sparse_blas_example<double, std::int32_t>(q);
                }

                std::cout << "\tRunning with single precision complex data type:" << std::endl;
                status |= run_sparse_blas_example<std::complex<float>, std::int32_t>(q);

                if (my_dev.get_info<sycl::info::device::double_fp_config>().size() != 0) {
                    std::cout << "\tRunning with double precision complex data type:" << std::endl;
                    status |= run_sparse_blas_example<std::complex<double>, std::int32_t>(q);
                }

            }
            else {
#ifdef FAIL_ON_MISSING_DEVICES
                std::cout << "No " << sycl_device_names[*it]
                    << " devices found; Fail on missing devices "
                    "is enabled.\n";
                return 1;
#else
                std::cout << "No " << sycl_device_names[*it] << " devices found; skipping "
                    << sycl_device_names[*it] << " tests.\n";
#endif
            }
        }
        catch (sycl::exception const &e) {
            std::cout << "\t\tCaught SYCL exception at driver level: \n" << e.what() << std::endl;
            continue; // stop with device, but move on to other devices
        }
        catch (std::exception const &e) {
            std::cout << "\t\tCaught std exception at driver level: \n" << e.what() << std::endl;
            continue; // stop with device, but move on to other devices
        }

    } // for device

    mkl_free_buffers();
    return status;
}
