/*******************************************************************************
* Copyright (C) 2024 Intel Corporation
*
* This software and the related documents are Intel copyrighted  materials,  and
* your use of  them is  governed by the  express license  under which  they were
* provided to you (License).  Unless the License provides otherwise, you may not
* use, modify, copy, publish, distribute,  disclose or transmit this software or
* the related documents without Intel's prior written permission.
*
* This software and the related documents  are provided as  is,  with no express
* or implied  warranties,  other  than those  that are  expressly stated  in the
* License.
*******************************************************************************/

/*
*
*  Content:
*       This example demonstrates use of oneAPI Math Kernel Library (oneMKL)
*       SYCL USM API oneapi::mkl::sparse::trsm to perform
*       sparse triangular solve on a SYCL device (CPU, GPU) with a dense matrix
*       right-hand side (i.e. multiple right-hand side vectors). This example
*       uses a sparse matrix in CSR format. The system solved for X is
*
*       op(A) * X = alpha * op(B)
*
*       where op() is defined by one of
*           oneapi::mkl::transpose::{nontrans,trans,conjtrans}
*
*       NOTE: currently, only op() = nontrans is supported for TRSM
*
*       The supported floating point data types for trsm are:
*           float
*           double
*           std::complex<float>
*           std::complex<double>
*
*       The supported matrix formats for trsm are:
*           CSR
*           COO (currently only on CPU device)
*
*******************************************************************************/

// stl includes
#include <algorithm>
#include <cstdlib>
#include <cstring>
#include <iostream>
#include <iterator>
#include <limits>
#include <list>
#include <vector>

#include "mkl.h"
#include "oneapi/mkl.hpp"
#include <sycl/sycl.hpp>

// local includes
#include "common_for_examples.hpp"
#include "./include/common_for_sparse_examples.hpp"

//
// Converts a (row, col) position inside a dense matrix with leading dimension
// `ld` into a linear offset. Row-major storage keeps each row contiguous;
// every other layout value is treated as column-major (each column contiguous).
//
template <typename intType>
static inline intType get_dense_matrix_index(const oneapi::mkl::layout layout_val,
                                             const intType row,
                                             const intType col,
                                             const intType ld)
{
    if (layout_val == oneapi::mkl::layout::row_major) {
        return row * ld + col;
    }
    return col * ld + row;
}

//
// Main example for Sparse Triangular Solve (sparse::trsm) with multiple
// right-hand sides.
//
// Builds a sparse CSR matrix A (via generate_sparse_matrix), a dense
// right-hand side B (every entry 2) and a dense solution X, then solves
//
//     op(uplo(A)) * X = alpha * op(B)
//
// on the device and compares X against a host-side reference solve.
//
// Real data types use a lower-triangular, non-unit-diagonal, zero-based A with
// column-major B/X; complex data types use an upper-triangular, unit-diagonal,
// one-based A with row-major B/X, so both code paths are exercised.
//
// Returns 0 when the device result matches the reference, 1 otherwise
// (including when an exception was caught).
//
template <typename dataType, typename intType>
int run_sparse_blas_example(sycl::queue &q)
{
    bool good = true;

    //
    // handle for sparse matrix
    //
    oneapi::mkl::sparse::matrix_handle_t csrA = nullptr;

    //
    // USM allocations to be released by cleanup_arrays at the end of this
    // function, on success as well as after an exception.
    //
    std::vector<intType *> int_ptr_vec;
    std::vector<dataType *> data_ptr_vec;

    try {

        // Initialize data for Sparse Triangular Solve with multiple RHS
        auto diag_val       = oneapi::mkl::diag::nonunit;
        auto index_base_val = oneapi::mkl::index_base::zero;
        auto uplo_val       = oneapi::mkl::uplo::lower;
        auto layout_val     = oneapi::mkl::layout::col_major;
        if constexpr (is_complex<dataType>()) {
            diag_val       = oneapi::mkl::diag::unit;
            index_base_val = oneapi::mkl::index_base::one;
            uplo_val       = oneapi::mkl::uplo::upper;
            layout_val     = oneapi::mkl::layout::row_major;
        }
        intType int_index = (index_base_val == oneapi::mkl::index_base::zero ? 0 : 1);

        auto opA        = oneapi::mkl::transpose::nontrans;
        auto opB        = oneapi::mkl::transpose::nontrans;

        dataType alpha;
        if constexpr (!is_complex<dataType>()) {
            alpha = set_fp_value(dataType(2.0), dataType(0.0));
        }
        else {
            alpha = set_fp_value(dataType(2.0), dataType(-1.0));
        }

        // Matrix data size
        intType size  = 4;
        const std::int64_t  nrows = size * size * size;
        const std::int64_t  ncols = nrows;
        intType columns = 5;
        intType ldb    = (layout_val == oneapi::mkl::layout::row_major) ? columns     : nrows;
        intType ldx    = (layout_val == oneapi::mkl::layout::row_major) ? columns     : nrows;
        intType b_size = (layout_val == oneapi::mkl::layout::row_major) ? nrows * ldb : ldb * columns;
        intType x_size = (layout_val == oneapi::mkl::layout::row_major) ? nrows * ldx : ldx * columns;

        // Input matrix in CSR format
        intType *ia, *ja;
        dataType *a, *dnsB, *dnsX, *dnsX_ref;
        const std::int64_t sizea   = 27 * nrows;
        const std::int64_t sizeja  = 27 * nrows;
        const std::int64_t sizeia  = nrows + 1;

        ia        = sycl::malloc_shared<intType>(sizeia, q);
        ja        = sycl::malloc_shared<intType>(sizeja, q);
        a         = sycl::malloc_shared<dataType>(sizea, q);
        dnsB      = sycl::malloc_shared<dataType>(b_size, q);
        dnsX      = sycl::malloc_shared<dataType>(x_size, q);
        dnsX_ref  = sycl::malloc_shared<dataType>(x_size, q);

        // Register every allocation that succeeded *before* checking for
        // failures: if only some of them failed, the throw below must not leak
        // the ones that did succeed (cleanup_arrays releases them at the end).
        for (intType *ptr : {ia, ja}) {
            if (ptr) int_ptr_vec.push_back(ptr);
        }
        for (dataType *ptr : {a, dnsB, dnsX, dnsX_ref}) {
            if (ptr) data_ptr_vec.push_back(ptr);
        }

        if (!ia || !ja || !a || !dnsB || !dnsX || !dnsX_ref) {
            std::string errorMessage =
                "Failed to allocate USM shared memory arrays \n"
                " for CSR A matrix: ia(" + std::to_string((sizeia)*sizeof(intType)) + " bytes)\n"
                "                   ja(" + std::to_string((sizeja)*sizeof(intType)) + " bytes)\n"
                "                   a(" + std::to_string((sizea)*sizeof(dataType)) + " bytes)\n"
                " and vectors:      dnsB(" + std::to_string((b_size)*sizeof(dataType)) + " bytes)\n"
                "                   dnsX(" + std::to_string((x_size)*sizeof(dataType)) + " bytes)\n"
                "                   dnsX_ref(" + std::to_string((x_size)*sizeof(dataType)) + " bytes)";

            throw std::runtime_error(errorMessage);
        }

        generate_sparse_matrix<dataType, intType>(size, ia, ja, a, int_index);

        const std::int64_t nnz = ia[nrows] - int_index;

        // Init matrices B, X, and X_ref
        for (intType i = 0; i < b_size; ++i) {
            dnsB[i] = set_fp_value(dataType(2.0), dataType(0.0));
        }
        for (intType i = 0; i < x_size; ++i) {
            dnsX[i]     = set_fp_value(dataType(0.0), dataType(0.0));
            dnsX_ref[i] = set_fp_value(dataType(0.0), dataType(0.0));
        }

        //
        // Execute Triangular Solve
        //
        std::cout << "\n\t\tsparse::trsm parameters: solve for X: op(uplo(A)) * X = alpha * op(B)\n";
        std::cout << "\t\t\tlayout   = " << layout_val << std::endl;
        std::cout << "\t\t\topA      = " << opA << std::endl;
        std::cout << "\t\t\topB      = " << opB << std::endl;
        std::cout << "\t\t\tindexing = " << index_base_val << std::endl;
        std::cout << "\t\t\tuplo     = " << uplo_val << std::endl;
        std::cout << "\t\t\tdiag     = " << diag_val << std::endl;
        std::cout << "\t\t\talpha    = " << alpha << std::endl;
        std::cout << "\t\t\tnrows    = " << nrows << std::endl;
        std::cout << "\t\t\tncols    = " << ncols << std::endl;
        std::cout << "\t\t\tnnz      = " << nnz << std::endl;
        std::cout << "\t\t\tcolumns  = " << columns << std::endl;
        std::cout << "\t\t\tldb      = " << ldb << std::endl;
        std::cout << "\t\t\tldx      = " << ldx << std::endl;

        oneapi::mkl::sparse::init_matrix_handle(&csrA);

        auto ev_set = oneapi::mkl::sparse::set_csr_data(q, csrA, nrows, nrows, nnz,
                                                        index_base_val, ia, ja, a, {});

        auto ev_opt = oneapi::mkl::sparse::optimize_trsm(q, layout_val, uplo_val, opA,
                                                         diag_val, csrA, columns, {ev_set});

        // add oneapi::mkl::sparse::trsm to execution queue
        auto ev_trsm = oneapi::mkl::sparse::trsm(q, layout_val, opA, opB, uplo_val, diag_val,
                                                 alpha, csrA, dnsB, columns, ldb, dnsX, ldx, {ev_opt});

        auto ev_release = oneapi::mkl::sparse::release_matrix_handle(q, &csrA, {ev_trsm});

        ev_release.wait();

        //
        // Post Processing
        //

        // Host reference: solve op(uplo(A)) * dnsX_ref = alpha * op(dnsB) with
        // op() == nontrans by forward substitution (lower, rows 0..nrows-1) or
        // backward substitution (upper, rows nrows-1..0). Entries of A outside
        // the selected triangle are ignored.
        const bool is_lower = (uplo_val == oneapi::mkl::uplo::lower);
        std::vector<dataType> tmp_vec(columns); // fully overwritten for every row
        for (std::int64_t step = 0; step < nrows; ++step) {
            const intType row = static_cast<intType>(is_lower ? step : nrows - 1 - step);

            for (intType col = 0; col < columns; ++col) {
                const intType b_idx = get_dense_matrix_index(layout_val, row, col, ldb);
                tmp_vec[col] = alpha * dnsB[b_idx];
            }
            dataType fp_diag_val = set_fp_value(dataType(0.0), dataType(0.0));

            for (intType i = ia[row] - int_index; i < ia[row + 1] - int_index; ++i) {
                const intType colA = ja[i] - int_index;
                // strictly-lower entries for uplo::lower, strictly-upper for uplo::upper
                const bool in_strict_triangle = is_lower ? (colA < row) : (colA > row);
                if (in_strict_triangle) {
                    for (intType col = 0; col < columns; ++col) {
                        const intType x_idx = get_dense_matrix_index(layout_val, colA, col, ldx);
                        tmp_vec[col] -= a[i] * dnsX_ref[x_idx];
                    }
                }
                else if (colA == row) {
                    fp_diag_val = a[i];
                }
            }

            for (intType col = 0; col < columns; ++col) {
                const intType x_out_idx = get_dense_matrix_index(layout_val, row, col, ldx);
                dnsX_ref[x_out_idx] = (diag_val == oneapi::mkl::diag::unit) ? tmp_vec[col] : tmp_vec[col] / fp_diag_val;
            }
        }

        // check for correctness of X and X_ref
        for (intType i = 0; i < x_size; i++) {
            good &= check_result(dnsX[i], dnsX_ref[i], /*scale factor */ nrows * 2, i);
        }

        std::cout << "\n\t\t sparse::trsm example " << (good ? "passed" : "failed")
                  << "\n\tFinished" << std::endl;


        q.wait_and_throw();

    }
    catch (sycl::exception const &e) {
        std::cout << "\t\tCaught synchronous SYCL exception:\n" << e.what() << std::endl;
        good = false;
    }
    catch (std::exception const &e) {
        std::cout << "\t\tCaught std exception:\n" << e.what() << std::endl;
        good = false;
    }

    q.wait();

    // backup cleaning of matrix handle and others for if exceptions happened
    oneapi::mkl::sparse::release_matrix_handle(q, &csrA).wait();

    cleanup_arrays<dataType, intType>(data_ptr_vec, int_ptr_vec, q);

    q.wait();

    return good ? 0 : 1;
}

//
// Prints a description of the example setup, the APIs used and the supported
// floating point type precisions. The notation (A sparse, B the dense
// right-hand side, X the dense solution) matches the "solve for X" parameter
// line printed by run_sparse_blas_example.
//
void print_example_banner()
{
    // '\n' instead of std::endl per line: one flush at the end is enough.
    std::cout << '\n'
              << "###############################################################"
                 "#########\n"
              << "# Sparse Triangular Solve with multiple RHS Example: \n"
              << "# \n"
              << "# X = alpha * op(A)^(-1) * op(B)\n"
              << "# \n"
              << "# where A is a sparse matrix in CSR format, B and X are "
                 "dense matrices\n"
              << "# \n"
              << "# Using apis:\n"
              << "#   sparse::trsm\n"
              << "# \n"
              << "# Supported floating point type precisions:\n"
              << "#   float\n"
              << "#   double\n"
              << "#   std::complex<float>\n"
              << "#   std::complex<double>\n"
              << "# \n"
              << "###############################################################"
                 "#########\n"
              << std::endl;
}

//
// Main entry point for example.
//
// Dispatches to appropriate device types as set at build time with flag:
// -DSYCL_DEVICES_cpu -- only runs SYCL CPU implementation
// -DSYCL_DEVICES_gpu -- only runs SYCL GPU implementation
// -DSYCL_DEVICES_all (default) -- runs on all: cpu and gpu devices
//
//  For each device selected and each supported data type,
//  run_sparse_blas_example is run with all supported data types,
//  if any fail, we move on to the next device.
//

int main(int argc, char **argv)
{

    print_example_banner();

    std::list<my_sycl_device_types> list_of_devices;
    set_list_of_devices(list_of_devices);

    int status = 0;
    for (auto it = list_of_devices.begin(); it != list_of_devices.end(); ++it) {
        try {
            sycl::device my_dev;
            bool my_dev_is_found = false;
            get_sycl_device(my_dev, my_dev_is_found, *it);

            if (my_dev_is_found) {
                std::cout << "Running tests on " << sycl_device_names[*it] << ".\n";

                // Catch asynchronous exceptions
                auto exception_handler = [](sycl::exception_list exceptions) {
                    for (std::exception_ptr const &e : exceptions) {
                        try {
                            std::rethrow_exception(e);
                        }
                        catch (sycl::exception const &e) {
                            std::cout << "Caught asynchronous SYCL exception: \n"
                                << e.what() << std::endl;
                        }
                    }
                };

                sycl::queue q(my_dev, exception_handler);

                std::cout << "\tRunning with single precision real data type:" << std::endl;
                status |= run_sparse_blas_example<float, std::int32_t>(q);

                if (my_dev.get_info<sycl::info::device::double_fp_config>().size() != 0) {
                    std::cout << "\tRunning with double precision real data type:" << std::endl;
                    status |= run_sparse_blas_example<double, std::int32_t>(q);
                }

                std::cout << "\tRunning with single precision complex data type:" << std::endl;
                status |= run_sparse_blas_example<std::complex<float>, std::int32_t>(q);

                if (my_dev.get_info<sycl::info::device::double_fp_config>().size() != 0) {
                    std::cout << "\tRunning with double precision complex data type:" << std::endl;
                    status |= run_sparse_blas_example<std::complex<double>, std::int32_t>(q);
                }

            }
            else {
#ifdef FAIL_ON_MISSING_DEVICES
                std::cout << "No " << sycl_device_names[*it]
                    << " devices found; Fail on missing devices "
                    "is enabled.\n";
                return 1;
#else
                std::cout << "No " << sycl_device_names[*it] << " devices found; skipping "
                    << sycl_device_names[*it] << " tests.\n";
#endif
            }
        }
        catch (sycl::exception const &e) {
            std::cout << "\t\tCaught SYCL exception at driver level: \n" << e.what() << std::endl;
            continue; // stop with device, but move on to other devices
        }
        catch (std::exception const &e) {
            std::cout << "\t\tCaught std exception at driver level: \n" << e.what() << std::endl;
            continue; // stop with device, but move on to other devices
        }


    } // for loop over devices

    mkl_free_buffers();
    return status;
}
