/*******************************************************************************
* Copyright (C) 2023 Intel Corporation
*
* This software and the related documents are Intel copyrighted  materials,  and
* your use of  them is  governed by the  express license  under which  they were
* provided to you (License).  Unless the License provides otherwise, you may not
* use, modify, copy, publish, distribute,  disclose or transmit this software or
* the related documents without Intel's prior written permission.
*
* This software and the related documents  are provided as  is,  with no express
* or implied  warranties,  other  than those  that are  expressly stated  in the
* License.
*******************************************************************************/

//@HEADER
// ***************************************************
//
// HPCG: High Performance Conjugate Gradient Benchmark
//
// Contact:
// Michael A. Heroux ( maherou@sandia.gov)
// Jack Dongarra     (dongarra@eecs.utk.edu)
// Piotr Luszczek    (luszczek@eecs.utk.edu)
//
// ***************************************************
//@HEADER

#pragma once

// Selects how trmv_lbmv_unroll_impl builds its per-lane mask from the loaded
// column indices (lanes set in the mask are skipped by the gather):
//   maskL - mask is set where index >= offset or index < 0
//   maskB - mask is set where index < nrows
enum LBMask {
    maskL,
    maskB
};

template <local_int_t BLOCK_SIZE, local_int_t uroll, LBMask lbmask, local_int_t s=0>
static inline __attribute__((always_inline))
void trmv_lbmv_unroll_impl(std::array<esimd::simd<local_int_t, BLOCK_SIZE>, uroll> &indices,
                            std::array<esimd::simd<double, BLOCK_SIZE>, uroll> &vals,
                            std::array<esimd::simd_mask<BLOCK_SIZE>, uroll> &mask,
                            std::array<esimd::simd<double, BLOCK_SIZE>, uroll> &x_vec,
                            esimd::simd<double, BLOCK_SIZE> &y_vec,
                            const double *values,
                            const local_int_t *colind,
                            const double *x,
                            const local_int_t nrows,
                            const esimd::simd<local_int_t, BLOCK_SIZE> &offset)
{
    if constexpr (s < uroll) {
        esimd::simd<double, BLOCK_SIZE> zero_vec(0);
        indices[s] = esimd_lsc_block_load<local_int_t, local_int_t, BLOCK_SIZE, st, uc>(colind, s * BLOCK_SIZE);
        if constexpr (lbmask == maskL) {
            mask[s] = (indices[s] >= offset) || (indices[s] < 0);
        }
        else if constexpr (lbmask == maskB) {
            mask[s] = indices[s] < nrows;
        }
        vals[s] = esimd_lsc_block_load<double, local_int_t, BLOCK_SIZE, st, uc>(values, s * BLOCK_SIZE);
        x_vec[s] = esimd_lsc_gather<double, local_int_t, BLOCK_SIZE, ca, ca>(x, indices[s], !mask[s], zero_vec);
        y_vec += x_vec[s] * vals[s];
        // Instantiate next iteration of unroll
        trmv_lbmv_unroll_impl<BLOCK_SIZE, uroll, lbmask, s + 1>(indices, vals, mask, x_vec, y_vec, values, colind, x, nrows, offset);
    }
}

// One step of a compile-time unrolled multiply-accumulate that feeds two
// accumulators at once.
//
// Step `s` (the function recurses with s + 1 until s == uroll):
//   * loads BLOCK_SIZE column indices and values located s * BLOCK_SIZE
//     elements past `colind` / `values`;
//   * mask[s]  marks lanes with index <= offset (skipped by the gather),
//     mask1[s] marks lanes with index >= nrows;
//   * gathers x for the lanes not set in mask[s] and adds x_vec[s] * vals[s]
//     to y_vec;
//   * then zeroes the lanes of x_vec[s] flagged by mask1[s] and adds the
//     product to z_vec, so z_vec receives only the index < nrows part of what
//     went into y_vec.
// indices / vals / mask / mask1 / x_vec are caller-provided scratch with one
// slot per unroll step.
template <local_int_t BLOCK_SIZE, local_int_t uroll, local_int_t s=0>
static inline __attribute__((always_inline))
void trmv_ubmv_unroll_impl(std::array<esimd::simd<local_int_t, BLOCK_SIZE>, uroll> &indices,
                           std::array<esimd::simd<double, BLOCK_SIZE>, uroll> &vals,
                           std::array<esimd::simd_mask<BLOCK_SIZE>, uroll> &mask,
                           std::array<esimd::simd_mask<BLOCK_SIZE>, uroll> &mask1,
                           std::array<esimd::simd<double, BLOCK_SIZE>, uroll> &x_vec,
                           esimd::simd<double, BLOCK_SIZE> &y_vec,
                           esimd::simd<double, BLOCK_SIZE> &z_vec,
                           const double *values,
                           const local_int_t *colind,
                           const double *x,
                           const local_int_t nrows,
                           const esimd::simd<local_int_t, BLOCK_SIZE> &offset)
{
    if constexpr (s >= uroll) {
        // Every step has been emitted: end of the recursion.
        return;
    }
    else {
        // Element offset of this step's block inside values / colind.
        constexpr local_int_t elem_shift = s * BLOCK_SIZE;
        esimd::simd<double, BLOCK_SIZE> zeros(0);

        indices[s] = esimd_lsc_block_load<local_int_t, local_int_t, BLOCK_SIZE, st, uc>(colind, elem_shift);
        mask[s] = indices[s] <= offset;
        mask1[s] = indices[s] >= nrows;

        vals[s] = esimd_lsc_block_load<double, local_int_t, BLOCK_SIZE, st, uc>(values, elem_shift);
        x_vec[s] = esimd_lsc_gather<double, local_int_t, BLOCK_SIZE, ca, ca>(x, indices[s], !mask[s], zeros);

        // First accumulator: every gathered lane contributes.
        y_vec += x_vec[s] * vals[s];

        // Second accumulator: drop the lanes whose index is >= nrows first.
        x_vec[s].merge(zeros, mask1[s]);
        z_vec += x_vec[s] * vals[s];

        // Emit the next unroll step.
        trmv_ubmv_unroll_impl<BLOCK_SIZE, uroll, s + 1>(
            indices, vals, mask, mask1, x_vec, y_vec, z_vec, values, colind, x, nrows, offset);
    }
}

// One step of a compile-time unrolled multiply-accumulate over all entries
// with a non-negative column index.
//
// Step `s` (the function recurses with s + 1 until s == uroll) loads
// BLOCK_SIZE column indices and values located s * BLOCK_SIZE elements past
// `colind` / `values`, gathers x for the lanes whose index is >= 0 and adds
// x_vec[s] * vals[s] to y_vec.
// indices / vals / x_vec are caller-provided scratch with one slot per step.
template <local_int_t BLOCK_SIZE, local_int_t uroll, local_int_t s=0>
static inline __attribute__((always_inline))
void mv_unroll_impl(std::array<esimd::simd<local_int_t, BLOCK_SIZE>, uroll> &indices,
                    std::array<esimd::simd<double, BLOCK_SIZE>, uroll> &vals,
                    std::array<esimd::simd<double, BLOCK_SIZE>, uroll> &x_vec,
                    esimd::simd<double, BLOCK_SIZE> &y_vec,
                    const double *values,
                    const local_int_t *colind,
                    const double *x)
{
    if constexpr (s >= uroll) {
        // Every step has been emitted: end of the recursion.
        return;
    }
    else {
        // Element offset of this step's block inside values / colind.
        constexpr local_int_t elem_shift = s * BLOCK_SIZE;
        esimd::simd<double, BLOCK_SIZE> zeros(0);

        indices[s] = esimd_lsc_block_load<local_int_t, local_int_t, BLOCK_SIZE, st, uc>(colind, elem_shift);
        // Only lanes with a non-negative column index take part in the gather.
        auto active = indices[s] >= 0;

        vals[s] = esimd_lsc_block_load<double, local_int_t, BLOCK_SIZE, st, uc>(values, elem_shift);
        x_vec[s] = esimd_lsc_gather<double, local_int_t, BLOCK_SIZE, ca, ca>(x, indices[s], active, zeros);
        y_vec += x_vec[s] * vals[s];

        // Emit the next unroll step.
        mv_unroll_impl<BLOCK_SIZE, uroll, s + 1>(indices, vals, x_vec, y_vec, values, colind, x);
    }
}

// Selects which entries trsv_unroll_impl keeps:
//   ESBTRSVL - column index < offset and >= 0
//   ESBTRSVU - column index > offset and < nrows
enum ESBTRSVType {
    ESBTRSVL,
    ESBTRSVU
};

template <local_int_t BLOCK_SIZE, local_int_t uroll, ESBTRSVType trsv, local_int_t s=0>
static inline __attribute__((always_inline))
void trsv_unroll_impl(std::array<esimd::simd<local_int_t, BLOCK_SIZE>, uroll> &cols,
                      std::array<esimd::simd<double, BLOCK_SIZE>, uroll> &vals,
                      std::array<esimd::simd_mask<BLOCK_SIZE>, uroll> &uplomask,
                      std::array<esimd::simd<double, BLOCK_SIZE>, uroll> &y_vec,
                      esimd::simd<double, BLOCK_SIZE> &t_vec,
                      const double *values,
                      const local_int_t *colind,
                      double *y,
                      const local_int_t nrows,
                      const esimd::simd<local_int_t, BLOCK_SIZE> &offset)
{
    if constexpr (s < uroll) {

        esimd::simd<double, BLOCK_SIZE> zero_vec(0);
        cols[s] = esimd_lsc_block_load<local_int_t, local_int_t, BLOCK_SIZE, st, uc>(colind, s * BLOCK_SIZE);

        if constexpr (trsv == ESBTRSVL) {
            uplomask[s] = (cols[s] < offset) && (cols[s] >= 0); // lower + not fill-in
        }
        else {
            uplomask[s] = (cols[s] > offset) && (cols[s] < nrows); // upper + local + not fill-in (covered by upper check already)
        }

        y_vec[s] = esimd_lsc_gather<double, local_int_t, BLOCK_SIZE, ca, ca>(y, cols[s], uplomask[s], zero_vec);
        vals[s] = esimd_lsc_block_load<double, local_int_t, BLOCK_SIZE, st, uc>(values, s * BLOCK_SIZE);
        vals[s].merge(zero_vec, !uplomask[s]);

        t_vec += y_vec[s] * vals[s];
        // Instantiate next iteration of unroll
        trsv_unroll_impl<BLOCK_SIZE, uroll, trsv, s + 1>(cols, vals, uplomask, y_vec, t_vec, values, colind, y, nrows, offset);
    }
}

//
// Unrolls
//
// Runs `uroll` consecutive trmv_lbmv_unroll_impl steps starting at `values` /
// `colind`, accumulating into y_vec. This wrapper only provides the per-step
// scratch arrays required by the recursive implementation.
template <local_int_t BLOCK_SIZE, local_int_t uroll, LBMask lbmask>
static inline void trmv_lbmv_unroll(esimd::simd<double, BLOCK_SIZE> &y_vec,
                                     const double *values,
                                     const local_int_t *colind,
                                     const double *x,
                                     const local_int_t nrows,
                                     const esimd::simd<local_int_t, BLOCK_SIZE> &offset) {
    using IndexLanes = esimd::simd<local_int_t, BLOCK_SIZE>;
    using ValueLanes = esimd::simd<double, BLOCK_SIZE>;
    using LaneMask = esimd::simd_mask<BLOCK_SIZE>;

    // One slot per unroll step; every slot is written before it is read.
    std::array<IndexLanes, uroll> col_scratch;
    std::array<ValueLanes, uroll> val_scratch;
    std::array<LaneMask, uroll> mask_scratch;
    std::array<ValueLanes, uroll> x_scratch;

    trmv_lbmv_unroll_impl<BLOCK_SIZE, uroll, lbmask>(
        col_scratch, val_scratch, mask_scratch, x_scratch, y_vec, values, colind, x, nrows, offset);
}

// Runs `uroll` consecutive trmv_ubmv_unroll_impl steps starting at `values` /
// `colind`, accumulating into y_vec and z_vec. This wrapper only provides the
// per-step scratch arrays required by the recursive implementation.
template <local_int_t BLOCK_SIZE, local_int_t uroll>
static inline void trmv_ubmv_unroll(esimd::simd<double, BLOCK_SIZE> &y_vec,
                                    esimd::simd<double, BLOCK_SIZE> &z_vec,
                                    const double *values,
                                    const local_int_t *colind,
                                    const double *x,
                                    const local_int_t nrows,
                                    const esimd::simd<local_int_t, BLOCK_SIZE> &offset)
{
    using IndexLanes = esimd::simd<local_int_t, BLOCK_SIZE>;
    using ValueLanes = esimd::simd<double, BLOCK_SIZE>;
    using LaneMask = esimd::simd_mask<BLOCK_SIZE>;

    // One slot per unroll step; every slot is written before it is read.
    std::array<IndexLanes, uroll> col_scratch;
    std::array<ValueLanes, uroll> val_scratch;
    std::array<LaneMask, uroll> skip_scratch;
    std::array<LaneMask, uroll> nonlocal_scratch;
    std::array<ValueLanes, uroll> x_scratch;

    trmv_ubmv_unroll_impl<BLOCK_SIZE, uroll>(
        col_scratch, val_scratch, skip_scratch, nonlocal_scratch, x_scratch,
        y_vec, z_vec, values, colind, x, nrows, offset);
}

// Runs `uroll` consecutive mv_unroll_impl steps starting at `values` /
// `colind`, accumulating into y_vec. This wrapper only provides the per-step
// scratch arrays required by the recursive implementation.
template <local_int_t BLOCK_SIZE, local_int_t uroll>
static inline void mv_unroll(esimd::simd<double, BLOCK_SIZE> &y_vec,
                             const double *values,
                             const local_int_t *colind,
                             const double *x)
{
    using IndexLanes = esimd::simd<local_int_t, BLOCK_SIZE>;
    using ValueLanes = esimd::simd<double, BLOCK_SIZE>;

    // One slot per unroll step; every slot is written before it is read.
    std::array<IndexLanes, uroll> col_scratch;
    std::array<ValueLanes, uroll> val_scratch;
    std::array<ValueLanes, uroll> x_scratch;

    mv_unroll_impl<BLOCK_SIZE, uroll>(col_scratch, val_scratch, x_scratch, y_vec, values, colind, x);
}

// Runs `uroll` consecutive trsv_unroll_impl steps starting at `values` /
// `colind`, accumulating into t_vec. This wrapper only provides the per-step
// scratch arrays required by the recursive implementation.
template <local_int_t BLOCK_SIZE, local_int_t uroll, ESBTRSVType trsv>
static inline void trsv_unroll(esimd::simd<double, BLOCK_SIZE> &t_vec,
                               const double *values,
                               const local_int_t *colind,
                               double *y,
                               const local_int_t nrows,
                               const esimd::simd<local_int_t, BLOCK_SIZE> &offset)
{
    using IndexLanes = esimd::simd<local_int_t, BLOCK_SIZE>;
    using ValueLanes = esimd::simd<double, BLOCK_SIZE>;
    using LaneMask = esimd::simd_mask<BLOCK_SIZE>;

    // One slot per unroll step; every slot is written before it is read.
    std::array<IndexLanes, uroll> col_scratch;
    std::array<ValueLanes, uroll> val_scratch;
    std::array<LaneMask, uroll> select_scratch;
    std::array<ValueLanes, uroll> y_scratch;

    trsv_unroll_impl<BLOCK_SIZE, uroll, trsv>(
        col_scratch, val_scratch, select_scratch, y_scratch, t_vec, values, colind, y, nrows, offset);
}

//
// Handle generic case
//
template <local_int_t BLOCK_SIZE, LBMask lbmask>
static inline void trmv_lbmv_unroll_generic(const local_int_t st_vec,
                                             const local_int_t en_vec,
                                             esimd::simd<double, BLOCK_SIZE> &y_vec,
                                             const double *values,
                                             const local_int_t *colind,
                                             const double *x,
                                             const local_int_t nrows,
                                             const esimd::simd<local_int_t, BLOCK_SIZE> &offset)
{

    local_int_t k = en_vec - st_vec;
    local_int_t k16 = (k / 16) * 16;
    local_int_t k8 = (k / 8) * 8;
    local_int_t k4 = (k / 4) * 4;
    local_int_t k2 = (k / 2) * 2;
    const local_int_t en_vec16 = st_vec + k16;
    const local_int_t en_vec8 = st_vec + k8;
    const local_int_t en_vec4 = st_vec + k4;
    const local_int_t en_vec2 = st_vec + k2;

    local_int_t j = st_vec;
    if constexpr (BLOCK_SIZE == 32) {
        _Pragma("unroll")
        for (; j < en_vec16; j += 16)
            trmv_lbmv_unroll<BLOCK_SIZE, 16, lbmask>(y_vec, values + j * BLOCK_SIZE, colind + j * BLOCK_SIZE, x, nrows, offset);
    }

    _Pragma("unroll")
    for (; j < en_vec8; j += 8)
        trmv_lbmv_unroll<BLOCK_SIZE, 8, lbmask>(y_vec, values + j * BLOCK_SIZE, colind + j * BLOCK_SIZE, x, nrows, offset);

    _Pragma("unroll")
    for (; j < en_vec4; j += 4)
        trmv_lbmv_unroll<BLOCK_SIZE, 4, lbmask>(y_vec, values + j * BLOCK_SIZE, colind + j * BLOCK_SIZE, x, nrows, offset);

    _Pragma("unroll")
    for (; j < en_vec2; j += 2)
        trmv_lbmv_unroll<BLOCK_SIZE, 2, lbmask>(y_vec, values + j * BLOCK_SIZE, colind + j * BLOCK_SIZE, x, nrows, offset);

    for (; j < en_vec; j += 1)
        trmv_lbmv_unroll<BLOCK_SIZE, 1, lbmask>(y_vec, values + j * BLOCK_SIZE, colind + j * BLOCK_SIZE, x, nrows, offset);
}

template <local_int_t BLOCK_SIZE, ESBTRSVType trsv>
static inline void trsv_unroll_generic(const local_int_t st_vec,
                                       const local_int_t en_vec,
                                       esimd::simd<double, BLOCK_SIZE> &t_vec,
                                       const double *values,
                                       const local_int_t *colind,
                                       double *y,
                                       const local_int_t nrows,
                                       const esimd::simd<local_int_t, BLOCK_SIZE> &offset)
{

    local_int_t k = en_vec - st_vec;
    local_int_t k16 = (k / 16) * 16;
    local_int_t k8 = (k / 8) * 8;
    local_int_t k4 = (k / 4) * 4;
    local_int_t k2 = (k / 2) * 2;
    const local_int_t en_vec16 = st_vec + k16;
    const local_int_t en_vec8 = st_vec + k8;
    const local_int_t en_vec4 = st_vec + k4;
    const local_int_t en_vec2 = st_vec + k2;

    local_int_t j = st_vec;

	if constexpr (BLOCK_SIZE == 32) {
    	_Pragma("unroll")
   		for (; j < en_vec16; j += 16)
        	trsv_unroll<BLOCK_SIZE, 16, trsv>(t_vec, values + j * BLOCK_SIZE, colind + j * BLOCK_SIZE, y, nrows, offset);
	}

    _Pragma("unroll")
    for (; j < en_vec8; j += 8)
        trsv_unroll<BLOCK_SIZE, 8, trsv>(t_vec, values + j * BLOCK_SIZE, colind + j * BLOCK_SIZE, y, nrows, offset);

    _Pragma("unroll")
    for (; j < en_vec4; j += 4)
        trsv_unroll<BLOCK_SIZE, 4, trsv>(t_vec, values + j * BLOCK_SIZE, colind + j * BLOCK_SIZE, y, nrows, offset);

    _Pragma("unroll")
    for (; j < en_vec2; j += 2)
        trsv_unroll<BLOCK_SIZE, 2, trsv>(t_vec, values + j * BLOCK_SIZE, colind + j * BLOCK_SIZE, y, nrows, offset);

    for (; j < en_vec; j += 1)
        trsv_unroll<BLOCK_SIZE, 1, trsv>(t_vec, values + j * BLOCK_SIZE, colind + j * BLOCK_SIZE, y, nrows, offset);
}

//
// Handle generic case
//
template <local_int_t BLOCK_SIZE>
static inline void trmv_ubmv_unroll_generic(
    const local_int_t st_vec,
    const local_int_t en_vec,
    esimd::simd<double, BLOCK_SIZE> &y_vec,
    esimd::simd<double, BLOCK_SIZE> &z_vec,
    const double *values,
    const local_int_t *colind,
    const double *x,
    const local_int_t nrows,
    const esimd::simd<local_int_t, BLOCK_SIZE> &offset)
{

    local_int_t k = en_vec - st_vec;
    local_int_t k16 = (k / 16) * 16;
    local_int_t k8 = (k / 8) * 8;
    local_int_t k4 = (k / 4) * 4;
    local_int_t k2 = (k / 2) * 2;
    const local_int_t en_vec16 = st_vec + k16;
    const local_int_t en_vec8 = st_vec + k8;
    const local_int_t en_vec4 = st_vec + k4;
    const local_int_t en_vec2 = st_vec + k2;

    local_int_t j = st_vec;

	if constexpr (BLOCK_SIZE == 32) {
    	_Pragma("unroll")
   		for (; j < en_vec16; j += 16)
        	trmv_ubmv_unroll<BLOCK_SIZE, 16>(y_vec, z_vec, values + j * BLOCK_SIZE, colind + j * BLOCK_SIZE, x, nrows, offset);
	}

    _Pragma("unroll")
    for (; j < en_vec8; j += 8)
        trmv_ubmv_unroll<BLOCK_SIZE, 8>(y_vec, z_vec, values + j * BLOCK_SIZE, colind + j * BLOCK_SIZE, x, nrows, offset);

    _Pragma("unroll")
    for (; j < en_vec4; j += 4)
        trmv_ubmv_unroll<BLOCK_SIZE, 4>(y_vec, z_vec, values + j * BLOCK_SIZE, colind + j * BLOCK_SIZE, x, nrows, offset);

    _Pragma("unroll")
    for (; j < en_vec2; j += 2)
        trmv_ubmv_unroll<BLOCK_SIZE, 2>(y_vec, z_vec, values + j * BLOCK_SIZE, colind + j * BLOCK_SIZE, x, nrows, offset);

    for (; j < en_vec; j += 1)
        trmv_ubmv_unroll<BLOCK_SIZE, 1>(y_vec, z_vec, values + j * BLOCK_SIZE, colind + j * BLOCK_SIZE, x, nrows, offset);
}

// Handle generic case
template <local_int_t BLOCK_SIZE>
static inline void mv_unroll_generic(const local_int_t st_vec,
                                     const local_int_t en_vec,
                                     esimd::simd<double, BLOCK_SIZE> &y_vec,
                                     const double *values,
                                     const local_int_t *colind,
                                     const double *x)
{

    local_int_t k = en_vec - st_vec;
    local_int_t k16 = (k / 16) * 16;
    local_int_t k8 = (k / 8) * 8;
    local_int_t k4 = (k / 4) * 4;
    local_int_t k2 = (k / 2) * 2;
    const local_int_t en_vec16 = st_vec + k16;
    const local_int_t en_vec8 = st_vec + k8;
    const local_int_t en_vec4 = st_vec + k4;
    const local_int_t en_vec2 = st_vec + k2;

    local_int_t j = st_vec;

    _Pragma("unroll")
    for (; j < en_vec16; j += 16)
        mv_unroll<BLOCK_SIZE, 16>(y_vec, values + j * BLOCK_SIZE, colind + j * BLOCK_SIZE, x);

    _Pragma("unroll")
    for (; j < en_vec8; j += 8)
        mv_unroll<BLOCK_SIZE, 8>(y_vec, values + j * BLOCK_SIZE, colind + j * BLOCK_SIZE, x);

    _Pragma("unroll")
    for (; j < en_vec4; j += 4)
        mv_unroll<BLOCK_SIZE, 4>(y_vec, values + j * BLOCK_SIZE, colind + j * BLOCK_SIZE, x);

    _Pragma("unroll")
    for (; j < en_vec2; j += 2)
        mv_unroll<BLOCK_SIZE, 2>(y_vec, values + j * BLOCK_SIZE, colind + j * BLOCK_SIZE, x);

    for (; j < en_vec; j += 1)
        mv_unroll<BLOCK_SIZE, 1>(y_vec, values + j * BLOCK_SIZE, colind + j * BLOCK_SIZE, x);
}

// Vector counts (en_vec - st_vec) for which each *_unroll_dispatch function
// instantiates a fully unrolled kernel; the dispatchers test the entries in
// the order listed here and fall back to the *_unroll_generic routines when
// none matches.
constexpr local_int_t trmv_lbmv_dispatch_cases[] = {
     1,  2,  3,  4,  5,  6,  7,  8,  9,
    10, 11, 12, 13, 14, 15, 16, 17, 18,
    19, 20, 21, 22, 23, 24, 25, 26, 27
};
constexpr local_int_t trmv_ubmv_dispatch_cases[] = {
     1,  2,  3,  4,  5,  6,  7,  8,  9,
    10, 11, 12, 13, 14, 15, 16, 17, 18,
    19, 20, 21, 22, 23, 24, 25, 26, 27
};
constexpr local_int_t mv_dispatch_cases[] = {
     1,  2,  3,  4,  5,  6,  7,  8,  9,
    10, 11, 12, 13, 14, 15, 16, 17, 18,
    19, 20, 21, 22, 23, 24, 25, 26, 27
};
// Alternative case lists, currently disabled:
//constexpr local_int_t trmv_lbmv_dispatch_cases[] = {1,3,4,5,6,7,9,10,12,13,14,15,18,19,21,27};
//constexpr local_int_t trmv_ubmv_dispatch_cases[] = {27,26,25,24,23,22,21,20,19,18,17,16,15,14,13,12,11,10,9,8,7,6,5,4,3,2,1};
constexpr local_int_t trsv_dispatch_cases[] = {
     1,  2,  3,  4,  5,  6,  7,  8,  9,
    10, 11, 12, 13, 14, 15, 16, 17, 18,
    19, 20, 21, 22, 23, 24, 25, 26, 27
};

//
// Dispatchers
//
template <local_int_t BLOCK_SIZE, LBMask lbmask, local_int_t N = std::size(trmv_lbmv_dispatch_cases), local_int_t i = 0>
static inline __attribute__((always_inline))
void trmv_lbmv_unroll_dispatch(const local_int_t st_vec,
                                const local_int_t en_vec,
                                esimd::simd<double, BLOCK_SIZE> &y_vec,
                                const double *values,
                                const local_int_t *colind,
                                const double *x,
                                const local_int_t nrows,
                                const esimd::simd<local_int_t, BLOCK_SIZE> &offset)
{
    if constexpr (i < N) {
        static_assert(i < std::size(trmv_lbmv_dispatch_cases));

        constexpr local_int_t nvec = trmv_lbmv_dispatch_cases[i];
        // Instantiate unrolls for size i
        if (en_vec - st_vec == nvec)
            trmv_lbmv_unroll<BLOCK_SIZE, nvec, lbmask>(y_vec, values + st_vec * BLOCK_SIZE, colind + st_vec * BLOCK_SIZE, x, nrows, offset);
        else
            trmv_lbmv_unroll_dispatch<BLOCK_SIZE, lbmask, N, i + 1>(st_vec, en_vec, y_vec, values, colind, x, nrows, offset);
    }
    else if constexpr (i == N) {
        // Dispatch to generic case
        trmv_lbmv_unroll_generic<BLOCK_SIZE, lbmask>(st_vec, en_vec, y_vec, values, colind, x, nrows, offset);
    }
}

template <local_int_t BLOCK_SIZE, local_int_t N = std::size(trmv_ubmv_dispatch_cases), local_int_t i = 0>
static inline __attribute__((always_inline))
void trmv_ubmv_unroll_dispatch(const local_int_t st_vec,
                               const local_int_t en_vec,
                               esimd::simd<double, BLOCK_SIZE> &y_vec,
                               esimd::simd<double, BLOCK_SIZE> &z_vec,
                               const double *values,
                               const local_int_t *colind,
                               const double *x,
                               const local_int_t nrows,
                               const local_int_t *blockptr_st, // Not used
                               local_int_t block, // Not used
                               const esimd::simd<local_int_t, BLOCK_SIZE> &offset)
{
    if constexpr (i < N) {
        static_assert(i < std::size(trmv_ubmv_dispatch_cases));

        constexpr local_int_t nvec = trmv_ubmv_dispatch_cases[i];
        // Instantiate unrolls for size nvec
        if (en_vec - st_vec == nvec)
            trmv_ubmv_unroll<BLOCK_SIZE, nvec>(y_vec, z_vec, values + st_vec * BLOCK_SIZE, colind + st_vec * BLOCK_SIZE, x, nrows, offset);
        else
            trmv_ubmv_unroll_dispatch<BLOCK_SIZE, N, i + 1>(st_vec, en_vec, y_vec, z_vec, values, colind, x, nrows, blockptr_st, block, offset);
    }
    else if constexpr (i == N) {
        // Dispatch to generic case
        trmv_ubmv_unroll_generic<BLOCK_SIZE>(st_vec, en_vec, y_vec, z_vec, values, colind, x, nrows, offset);
    }
}

template <local_int_t BLOCK_SIZE, local_int_t N = std::size(mv_dispatch_cases), local_int_t i = 0>
static inline __attribute__((always_inline))
void mv_unroll_dispatch(const local_int_t st_vec,
                        const local_int_t en_vec,
                        esimd::simd<double, BLOCK_SIZE> &y_vec,
                        const double *values,
                        const local_int_t *colind,
                        const double *x)
{
    if constexpr (i < N) {
        static_assert(i < std::size(mv_dispatch_cases));

        constexpr local_int_t nvec = mv_dispatch_cases[i];
        // Instantiate unrolls for size i
        if (en_vec - st_vec == nvec)
            mv_unroll<BLOCK_SIZE, nvec>(y_vec, values + st_vec * BLOCK_SIZE, colind + st_vec * BLOCK_SIZE, x);
        else
            mv_unroll_dispatch<BLOCK_SIZE, N, i + 1>(st_vec, en_vec, y_vec, values, colind, x);
    }
    else if constexpr (i == N) {
        // Dispatch to generic case
        mv_unroll_generic<BLOCK_SIZE>(st_vec, en_vec, y_vec, values, colind, x);
    }
}


template <local_int_t BLOCK_SIZE, ESBTRSVType trsv, local_int_t N = std::size(trsv_dispatch_cases), local_int_t i = 0>
static inline __attribute__((always_inline))
void trsv_unroll_dispatch(const local_int_t st_vec,
                          const local_int_t en_vec,
                          esimd::simd<double, BLOCK_SIZE> &t_vec,
                          const double *values,
                          const local_int_t *colind,
                          double *y,
                          const local_int_t nrows,
                          const esimd::simd<local_int_t, BLOCK_SIZE> &offset)
{
    if constexpr (i < N) {
        static_assert(i < std::size(trsv_dispatch_cases));

        constexpr local_int_t nvec = trsv_dispatch_cases[i];
        // Instantiate unrolls for size i
        if (en_vec - st_vec == nvec)
            trsv_unroll<BLOCK_SIZE, nvec, trsv>(t_vec, values + st_vec * BLOCK_SIZE, colind + st_vec * BLOCK_SIZE, y, nrows, offset);
        else
            trsv_unroll_dispatch<BLOCK_SIZE, trsv, N, i + 1>(st_vec, en_vec, t_vec, values, colind, y, nrows, offset);
    }
    else if constexpr (i == N) {
        // Dispatch to generic case
        trsv_unroll_generic<BLOCK_SIZE, trsv>(st_vec, en_vec, t_vec, values, colind, y, nrows, offset);
    }
}
