/*******************************************************************************
* Copyright (C) 2014 Intel Corporation
*
* This software and the related documents are Intel copyrighted  materials,  and
* your use of  them is  governed by the  express license  under which  they were
* provided to you (License).  Unless the License provides otherwise, you may not
* use, modify, copy, publish, distribute,  disclose or transmit this software or
* the related documents without Intel's prior written permission.
*
* This software and the related documents  are provided as  is,  with no express
* or implied  warranties,  other  than those  that are  expressly stated  in the
* License.
*******************************************************************************/

//@HEADER
// ***************************************************
//
// HPCG: High Performance Conjugate Gradient Benchmark
//
// Contact:
// Michael A. Heroux ( maherou@sandia.gov)
// Jack Dongarra     (dongarra@eecs.utk.edu)
// Piotr Luszczek    (luszczek@eecs.utk.edu)
//
// ***************************************************
//@HEADER

/*!
 @file ComputeSYMGS.cpp

 HPCG routine
 */

#include "ComputeSYMGS.hpp"
#include "ComputeSYMGS_ref.hpp"
#include "UsmUtil.hpp"

#ifndef HPCG_NO_MPI
#include "ExchangeHalo.hpp"
#include <mpi.h>
#include "Geometry.hpp"
#include <cstdlib>
#endif

#include "VeryBasicProfiler.hpp"

#ifdef BASIC_PROFILING
#define BEGIN_PROFILE(n) optData->profiler->begin(n);
#define END_PROFILE(n) optData->profiler->end(n);
#define END_PROFILE_WAIT(n, event) event.wait(); optData->profiler->end(n);
#else
#define BEGIN_PROFILE(n)
#define END_PROFILE(n)
#define END_PROFILE_WAIT(n, event)
#endif


// implementations of SYMGS and SYMGS_MV using oneMKL or custom kernels in
// HPCG repository for the Sparse BLAS operations
//
// SYMGS Formulation Mathematical:
//
// Given A=(L+D+U+B), where B are nonlocal parts, x0 and b as inputs, find x:
// - Halo exchange x0 to get x0_nonloc
// - Solve for x1:  (L+D) * x1 = b - B * x0_nonloc - U * x0
// - Solve for x:   (D+U) * x  = b - B * x0_nonloc - L * x1
//
//
// SYMGS_MV Formulation Mathematical:
//
// Given A=(L+D+U+B) where B are nonlocal parts, x0=0 and b as inputs, find x and y:
// - Solve for x1:  (L+D) * x1 = b
// - Solve for x:   (D+U) * x  = b - L * x1
// - Halo exchange x to get x_nonloc
// - SpMV: y = (L+D+U) * x + B * x_nonloc   (i.e. y = A * x)

// SYMGS Formulation 2:
//
// Given A=(L+D+U+B), where B are nonlocal parts, x0 and b as inputs, find x:
// - Halo exchange x0 to get x0_nonloc
// - Construct RHS1: rhs1 = b - B * x0_nonloc - U * x0_loc
// - Solve for x1:   (L+D) * x1 = rhs1
// - Construct RHS2: rhs2 = b - B * x0_nonloc - L * x1
// - Solve for x:    (D+U) * x  = rhs2
sycl::event run_SYMGS_custom(sycl::queue &main_queue,
                             const SparseMatrix & A,
                             struct optData *optData,
                             custom::sparseMatrix *sparseM,
                             const Vector &r,
                             Vector &x,
                             const std::vector<sycl::event> &deps)
{
    // One symmetric Gauss-Seidel sweep on x with r as the right-hand side, issued as a
    // chain of custom:: sparse operations where each stage waits on the previous
    // stage's event. The scratch vectors optData->dtmp and optData->dtmp2 are
    // overwritten. The event of the final (backward) triangular solve is returned.

    // Stage 0 (MPI builds only): bring the nonlocal entries of x up to date.
#ifndef HPCG_NO_MPI
    BEGIN_PROFILE("SYMGS:halo");
    sycl::event haloReady = ExchangeHalo(A, x, main_queue, deps);
    std::vector<sycl::event> stage1Deps({haloReady});
    END_PROFILE_WAIT("SYMGS:halo", haloReady);
#else
    const std::vector<sycl::event>& stage1Deps = deps;
#endif

    // Stage 1: dtmp  <- r - (U + B) * x0   (rhs1 of the forward sweep)
    //          dtmp2 <- U * x0             (kept to build the backward rhs cheaply)
    BEGIN_PROFILE("SYMGS:trmvRUB");
    auto rhs1Ready = custom::SpTRMV(main_queue, sparseM, custom::uplo::upper_nonlocal,
                                    x.values, r.values, optData->dtmp, optData->dtmp2,
                                    stage1Deps);
    END_PROFILE_WAIT("SYMGS:trmvRUB", rhs1Ready);

    // Stage 2 (fused forward sweep) rewrites both scratch vectors in place:
    //          dtmp2 <- (L + D)^{-1} * dtmp   ( x1 )
    //          dtmp  <- U * x0 + D * x1       ( rhs2, equal to r - B*x0 - L*x1 )
    BEGIN_PROFILE("SYMGS:trsvL");
    auto forwardDone = custom::SpTRSV_FUSED(main_queue, sparseM, custom::uplo::lower_diagonal,
                                            optData->dtmp, optData->dtmp2, {rhs1Ready});
    END_PROFILE_WAIT("SYMGS:trsvL", forwardDone);

    // Stage 3 (backward sweep): x <- (D + U)^{-1} * dtmp
    BEGIN_PROFILE("SYMGS:trsvU");
    sycl::event backwardDone = custom::SpTRSV(main_queue, sparseM, custom::uplo::upper_diagonal,
                                              optData->dtmp, x.values, {forwardDone});
    END_PROFILE_WAIT("SYMGS:trsvU", backwardDone);

    return backwardDone;
} // run_SYMGS_custom


// SYMGS_MV Formulation 2:
//
// Given A=(L+D+U+B) where B are nonlocal parts, x0=0 and b as inputs, find x and y:
// - Solve for x1:   (L+D) * x1 = b
// - Construct RHS2: rhs2 = b - L * x1 = D * x1 (from step 1)
// - Solve for x:    (D+U) * x  = rhs2
// - Halo exchange x to get x_nonloc
// - Construct y: y = (L + D + U) * x + B * x_nonloc
sycl::event run_SYMGS_MV_custom(sycl::queue &main_queue,
                                const SparseMatrix & A,
                                struct optData *optData,
                                custom::sparseMatrix *sparseM,
                                const Vector &r,
                                Vector &x,
                                Vector &y,
                                const std::vector<sycl::event> &deps)
{
    // SYMGS sweep from a zero initial guess, fused with the product y <- A * x of the
    // resulting x. y doubles as scratch for the intermediate x1 / rhs2 before it
    // receives A * x. Returns the event of the final update of y.

    // Forward sweep: y <- (L + D)^{-1} * r   ( x1 )
    BEGIN_PROFILE("SYMGS_MV:trsvL");
    auto forwardDone = custom::SpTRSV(main_queue, sparseM, custom::uplo::lower_diagonal,
                                      r.values, y.values, deps);
    END_PROFILE_WAIT("SYMGS_MV:trsvL", forwardDone);

    // Fused backward sweep: y <- D * y ( rhs2 ), then x <- (D + U)^{-1} * y
    BEGIN_PROFILE("SYMGS_MV:trsvU");
    auto backwardDone = custom::SpTRSV_FUSED(main_queue, sparseM, custom::uplo::upper_diagonal,
                                             y.values, x.values, {forwardDone});
    END_PROFILE_WAIT("SYMGS_MV:trsvU", backwardDone);

    // MPI builds refresh the nonlocal entries of the freshly computed x.
#ifndef HPCG_NO_MPI
    BEGIN_PROFILE("SYMGS_MV:halo");
    sycl::event xHaloReady = ExchangeHalo(A, x, main_queue, {backwardDone});
    END_PROFILE_WAIT("SYMGS_MV:halo", xHaloReady);
#else
    sycl::event xHaloReady = backwardDone;
#endif

    // y already holds rhs2 == (D + U) * x, so adding the remaining terms yields
    // y <- (L + B) * x + y == A * x. r and dtmp are placeholders the call does not use.
    BEGIN_PROFILE("SYMGS_MV:trmvL");
    auto spmvDone = custom::SpTRMV(main_queue, sparseM, custom::uplo::lower_update,
                                   x.values, r.values/*Not used*/, y.values,
                                   optData->dtmp/*Not used*/, {xHaloReady});
    END_PROFILE_WAIT("SYMGS_MV:trmvL", spmvDone);

    return spmvDone;
} // run_SYMGS_MV_custom


/*!
  Routine to compute one step of symmetric Gauss-Seidel (SYMGS) on the SYCL device.

  Assumption about the structure of matrix A (inherited from ComputeSYMGS_ref):
  - Each row 'i' of the matrix has nonzero diagonal value whose address is matrixDiagonal[i]
  - Entries in row 'i' are ordered such that:
       - lower triangular terms are stored before the diagonal element.
       - upper triangular terms are stored after the diagonal element.
       - No other assumptions are made about entry ordering.

  Implementation notes:
  - The sweep itself is done by run_SYMGS_custom on the custom::sparseMatrix stored in
    optData->esbM, where optData is taken from A.optimizationData.
  - With A = L + D + U + B (B = nonlocal part), it performs, in order:
       - a halo exchange of x (MPI builds only),
       - a forward sweep:  (L+D) * x1 = r - (U+B) * x0,
       - a backward sweep: (D+U) * x  = r - B * x0 - L * x1.
  - Work is submitted asynchronously; callers must wait on the returned event (or on a
    later event that depends on it) before reading x on the host.

  @param[in]    A          the known system matrix; A.optimizationData must point to a
                           struct optData whose esbM member points to a custom::sparseMatrix
  @param[in]    r          the input vector (right-hand side)
  @param[inout] x          On entry, x should contain relevant values, on exit x contains the
                           result of one symmetric GS sweep with r as the RHS.
  @param[in]    main_queue SYCL queue used for all submissions
  @param[inout] ierr       incremented by one if a synchronous exception is caught
  @param[in]    deps       events that must complete before x may be read

  @return the event of the last submitted operation; a default-constructed sycl::event
          if an exception was caught (in which case ierr was incremented)

  @warning Early versions of this kernel (Version 1.1 and earlier) had the r and x arguments in reverse order, and out of sync with other kernels.

  @see ComputeSYMGS_ref
*/

sycl::event ComputeSYMGS( const SparseMatrix & A, const Vector & r, Vector & x, sycl::queue & main_queue,
                          int& ierr, const std::vector<sycl::event> & deps)
{
    try {
        // Both members are opaque pointers; recover their concrete types with named casts.
        struct optData *optData = static_cast<struct optData *>(A.optimizationData);
        custom::sparseMatrix *sparseM = static_cast<custom::sparseMatrix *>(optData->esbM);
        return run_SYMGS_custom(main_queue, A, optData, sparseM, r, x, deps);
    }
    catch (sycl::exception const &e) {
        std::cout << "\t\tCaught synchronous SYCL exception in SYMGS:\n" << e.what() << std::endl;
        ierr += 1;
        return sycl::event();
    }
    catch (std::exception const &e) {
        std::cout << "\t\tCaught std exception in SYMGS:\n" << e.what() << std::endl;
        ierr += 1;
        return sycl::event();
    }
    // Every path above returns; other exception types propagate to the caller.
}


/*!
  ComputeSYMGS_MV: one symmetric Gauss-Seidel sweep on x with r as the right-hand side,
  fused with the product y = A * x of the updated x.

  For SYCL with number of smoothing steps = 1, this is the main function
  implementation.

  The work is done by run_SYMGS_MV_custom on the custom::sparseMatrix stored in
  optData->esbM, where optData is taken from A.optimizationData. Work is submitted
  asynchronously; wait on the returned event before reading x or y on the host.

  @param[in]    A          the known system matrix; A.optimizationData must point to a
                           struct optData whose esbM member points to a custom::sparseMatrix
  @param[in]    r          the input vector (right-hand side)
  @param[out]   x          on exit, the result of one SYMGS sweep. The SYMGS_MV formulation
                           assumes a zero initial guess (NOTE(review): confirm the incoming
                           values of x are never read by the custom kernels)
  @param[out]   y          on exit, A * x; also used as scratch during the sweep
  @param[in]    main_queue SYCL queue used for all submissions
  @param[inout] ierr       incremented by one if a synchronous exception is caught
  @param[in]    deps       events that must complete before r may be read

  @return the event of the last submitted operation; a default-constructed sycl::event
          if an exception was caught (in which case ierr was incremented)
*/
sycl::event ComputeSYMGS_MV( const SparseMatrix & A, const Vector & r, Vector & x, Vector & y,
                             sycl::queue & main_queue, int &ierr, const std::vector<sycl::event> & deps)
{
    try {
        // Both members are opaque pointers; recover their concrete types with named casts.
        struct optData *optData = static_cast<struct optData *>(A.optimizationData);
        custom::sparseMatrix *sparseM = static_cast<custom::sparseMatrix *>(optData->esbM);
        return run_SYMGS_MV_custom(main_queue, A, optData, sparseM, r, x, y, deps);
    }
    catch (sycl::exception const &e) {
        std::cout << "\t\tCaught synchronous SYCL exception in SYMGS_MV:\n" << e.what() << std::endl;
        ierr += 1;
        return sycl::event();
    }
    catch (std::exception const &e) {
        std::cout << "\t\tCaught std exception in SYMGS_MV:\n" << e.what() << std::endl;
        ierr += 1;
        return sycl::event();
    }
    // Every path above returns; other exception types propagate to the caller.
}
