/*******************************************************************************
* Copyright (C) 2020 Intel Corporation
*
* This software and the related documents are Intel copyrighted  materials,  and
* your use of  them is  governed by the  express license  under which  they were
* provided to you (License).  Unless the License provides otherwise, you may not
* use, modify, copy, publish, distribute,  disclose or transmit this software or
* the related documents without Intel's prior written permission.
*
* This software and the related documents  are provided as  is,  with no express
* or implied  warranties,  other  than those  that are  expressly stated  in the
* License.
*******************************************************************************/

/*
*
*  Content:
*       This example demonstrates usage of DPC++ buffer-based API for oneMKL RNG
*       discrete distributions with oneapi::mkl::rng::default_engine
*       random number generator
*
*       Discrete distribution list:
*           oneapi::mkl::rng::bernoulli        (available for CPU and GPU)
*           oneapi::mkl::rng::binomial         (available for CPU and GPU)
*           oneapi::mkl::rng::geometric        (available for CPU and GPU)
*           oneapi::mkl::rng::hypergeometric   (available for CPU and GPU)
*           oneapi::mkl::rng::multinomial      (available for CPU and GPU)
*           oneapi::mkl::rng::negative_binomial(available for CPU and GPU)
*           oneapi::mkl::rng::poisson          (available for CPU and GPU)
*           oneapi::mkl::rng::poisson_v        (available for CPU and GPU)
*           oneapi::mkl::rng::uniform_bits     (available for CPU and GPU)
*           oneapi::mkl::rng::bits             (available for CPU and GPU)

*       The supported data types for random numbers are:
*           int32_t
*           uint32_t
*           uint64_t (for uniform_bits)
*
*
*******************************************************************************/

// stl includes
#include <iostream>
#include <vector>
#include <string>

#include <sycl/sycl.hpp>
#include "oneapi/mkl.hpp"

// local includes
#include "../include/common_for_rng_examples.hpp"

// example parameters:
//   n       - count passed to oneapi::mkl::rng::generate(); also the default size of
//             the output buffer
//   n_print - count forwarded to print_output() when results are displayed
constexpr std::size_t n{ 2000 };
constexpr std::size_t n_print{ 10 };

// Common driver for every example in this file.
//
// Builds a default_engine on `queue`, generates into a buffer of `buf_size` elements of
// IntType, hands the buffer to print_output() together with n_print and, when
// is_validation_needed is true, forwards the results to check_statistics().
//
// Template parameters:
//   IntType              - element type of the output buffer
//   Distribution         - type of the distribution object
//   is_validation_needed - when false the check_statistics() call is compiled out
// Parameters:
//   queue      - SYCL queue used to construct the engine and waited on after generation
//   distr      - distribution passed to oneapi::mkl::rng::generate()
//   distr_name - name inserted into the printed header line
//   buf_size   - number of elements in the output buffer (defaults to n)
// Returns:
//   false when generation raised a sycl::exception or oneapi::mkl::exception;
//   otherwise the check_statistics() result, or true when validation is disabled.
template <typename IntType, typename Distribution, bool is_validation_needed = true>
bool perform_generation_validation(sycl::queue& queue, const Distribution& distr,
                                   const std::string& distr_name, std::size_t buf_size = n) {
    // default random number engine bound to the given queue
    oneapi::mkl::rng::default_engine rng_engine(queue);

    // storage for the generated numbers
    sycl::buffer<IntType> output(buf_size);

    try {
        // the generation count is always n, independently of buf_size
        oneapi::mkl::rng::generate(distr, rng_engine, n, output);
        queue.wait_and_throw();
    }
    catch (sycl::exception const& sycl_err) {
        std::cout << "\t\tSYCL exception during oneapi::mkl::rng::generate() call\n";
        std::cout << sycl_err.what() << std::endl;
        std::cout << "Error code: " << sycl_err.code().value() << std::endl;
        return false;
    }
    catch (oneapi::mkl::exception const& mkl_err) {
        std::cout << "\toneMKL exception during oneapi::mkl::rng::generate() call\n";
        std::cout << mkl_err.what() << std::endl;
        return false;
    }

    // show the generated numbers
    std::cout << "\n\t\tOutput of generator with the " << distr_name << " distribution:"
              << std::endl;
    print_output(output, n_print);

    if constexpr (!is_validation_needed) {
        // nothing to verify for this distribution
        return true;
    }
    else {
        // read the results back on the host and check them against the distribution
        sycl::host_accessor host_view(output, sycl::read_only);
        return check_statistics(host_view.get_pointer(), buf_size, distr);
    }
}

// Runs the bernoulli distribution example on `queue`.
// Returns the result of perform_generation_validation().
template <typename IntType>
bool run_bernoulli_example(sycl::queue& queue) {
    // success probability
    const float success_probability = 0.5f;
    const oneapi::mkl::rng::bernoulli<IntType> bernoulli_distr(success_probability);

    return perform_generation_validation<IntType>(queue, bernoulli_distr, "bernoulli");
}

// Runs the binomial distribution example on `queue`.
// Returns the result of perform_generation_validation().
template <typename IntType>
bool run_binomial_example(sycl::queue& queue) {
    // number of independent trials
    const std::int32_t trial_count = 10;
    // probability
    const double probability = 0.5;
    const oneapi::mkl::rng::binomial<IntType> binomial_distr(trial_count, probability);

    return perform_generation_validation<IntType>(queue, binomial_distr, "binomial");
}

// Runs the geometric distribution example on `queue`.
// Returns the result of perform_generation_validation().
template <typename IntType>
bool run_geometric_example(sycl::queue& queue) {
    // probability of a trial
    const float trial_probability = 0.5f;
    const oneapi::mkl::rng::geometric<IntType> geometric_distr(trial_probability);

    return perform_generation_validation<IntType>(queue, geometric_distr, "geometric");
}

// Runs the hypergeometric distribution example on `queue`.
// Returns the result of perform_generation_validation().
template <typename IntType>
bool run_hypergeometric_example(sycl::queue& queue) {
    const std::int32_t lot_size = 1;     // lot size
    const std::int32_t sample_size = 1;  // size of sampling without replacement
    const std::int32_t marked_count = 1; // number of marked elements

    const oneapi::mkl::rng::hypergeometric<IntType> hypergeometric_distr(lot_size, sample_size,
                                                                        marked_count);

    return perform_generation_validation<IntType>(queue, hypergeometric_distr, "hypergeometric");
}

// Runs the multinomial distribution example on `queue`.
// The output buffer is sized n * (number of outcomes).
// Returns the result of perform_generation_validation().
template <typename IntType>
bool run_multinomial_example(sycl::queue& queue) {
    // number of independent trials
    const std::int32_t trial_count = 10;
    // number of possible outcomes; every outcome gets the same probability
    constexpr std::size_t outcome_count = 165;
    std::vector<double> probabilities(outcome_count, 1.0 / static_cast<double>(outcome_count));

    oneapi::mkl::rng::multinomial<IntType> multinomial_distr(
        trial_count, sycl::span{ probabilities.data(), probabilities.size() });

    const std::size_t total_elements = n * probabilities.size();
    return perform_generation_validation<IntType>(queue, multinomial_distr, "multinomial",
                                                  total_elements);
}

// Runs the negative binomial distribution example on `queue`.
// Returns the result of perform_generation_validation().
template <typename IntType>
bool run_negative_binomial_example(sycl::queue& queue) {
    const double first_param = 1.0;   // first distribution parameter
    const double second_param = 0.75; // second distribution parameter

    const oneapi::mkl::rng::negative_binomial<IntType> neg_binomial_distr(first_param,
                                                                         second_param);

    return perform_generation_validation<IntType>(queue, neg_binomial_distr, "negative binomial");
}

// Runs the poisson distribution example on `queue`.
// Returns the result of perform_generation_validation().
template <typename IntType>
bool run_poisson_example(sycl::queue& queue) {
    // distribution parameter
    const double rate = 1.0;
    const oneapi::mkl::rng::poisson<IntType> poisson_distr(rate);

    return perform_generation_validation<IntType>(queue, poisson_distr, "poisson");
}

// Runs the poisson_v distribution example on `queue`.
// Returns the result of perform_generation_validation().
template <typename IntType>
bool run_poisson_v_example(sycl::queue& queue) {
    // n distribution parameters, all set to 1.0
    std::vector<double> rates(n, 1.0);

    oneapi::mkl::rng::poisson_v<IntType> poisson_v_distr(
        sycl::span{ rates.data(), rates.size() });

    return perform_generation_validation<IntType>(queue, poisson_v_distr, "poisson_v");
}

// Runs the uniform_bits distribution example on `queue`; statistics validation is
// switched off for this distribution.
// Returns the result of perform_generation_validation().
template <typename IntType>
bool run_uniform_bits_example(sycl::queue& queue) {
    // uniformly distributed bits delivered in 32/64-bit chunks
    using distribution_type = oneapi::mkl::rng::uniform_bits<IntType>;
    distribution_type uniform_bits_distr;

    constexpr bool validate = false;
    return perform_generation_validation<IntType, distribution_type, validate>(
        queue, uniform_bits_distr, "uniform bits");
}

// Runs the bits distribution example on `queue`; statistics validation is switched
// off for this distribution.
// Returns the result of perform_generation_validation().
template <typename IntType>
bool run_bits_example(sycl::queue& queue) {
    // bits of the underlying engine's integer recurrence
    using distribution_type = oneapi::mkl::rng::bits<IntType>;
    distribution_type bits_distr;

    constexpr bool validate = false;
    return perform_generation_validation<IntType, distribution_type, validate>(
        queue, bits_distr, "bits");
}

// Writes the example banner to std::cout: a blank line, a framed description listing
// the RNG distribution APIs and the supported integer types, then another blank line.
// Each line is terminated with std::endl.
void print_example_banner() {
    static constexpr const char* frame =
        "########################################################################";
    static constexpr const char* banner_lines[] = {
        "",
        frame,
        "# Generate random numbers with discrete rng distributions:",
        "# ",
        "# Using apis:",
        "#   oneapi::mkl::rng::bernoulli",
        "#   oneapi::mkl::rng::binomial",
        "#   oneapi::mkl::rng::geometric",
        "#   oneapi::mkl::rng::hypergeometric",
        "#   oneapi::mkl::rng::multinomial",
        "#   oneapi::mkl::rng::negative_binomial",
        "#   oneapi::mkl::rng::poisson",
        "#   oneapi::mkl::rng::poisson_v",
        "#   oneapi::mkl::rng::uniform_bits",
        "#   oneapi::mkl::rng::bits",
        "# ",
        "# Supported types:",
        "#   int32_t uint32_t uint64_t",
        "# ",
        frame,
        "",
    };

    for (const char* line : banner_lines) {
        std::cout << line << std::endl;
    }
}

//
// Main entry point for example.
//
// Dispatches to appropriate device types as set at build time with flag:
// -DSYCL_DEVICES_cpu -- only runs SYCL CPU implementation
// -DSYCL_DEVICES_gpu -- only runs SYCL GPU implementation
// -DSYCL_DEVICES_all (default) -- runs on all: cpu and gpu devices
//

int main(int argc, char** argv) {
    print_example_banner();

    std::list<my_sycl_device_types> list_of_devices;
    set_list_of_devices(list_of_devices);

    for (auto it = list_of_devices.begin(); it != list_of_devices.end(); ++it) {
        sycl::device my_dev;
        bool my_dev_is_found = false;
        get_sycl_device(my_dev, my_dev_is_found, *it);

        if (my_dev_is_found) {
            std::cout << "Running tests on " << sycl_device_names[*it] << ".\n";
            sycl::queue queue(my_dev, exception_handler);
            if (!run_bernoulli_example<std::int32_t>(queue) ||
                !run_geometric_example<std::int32_t>(queue) ||
                !run_poisson_example<std::int32_t>(queue) ||
                !run_uniform_bits_example<std::uint32_t>(queue) ||
                !run_bits_example<std::uint32_t>(queue)) {
                std::cout << "FAILED" << std::endl;
                return 1;
            }
            if (isDoubleSupported(my_dev)) { // Double precision is used inside binomial and
                // hypergeometric implementations
                if (!run_binomial_example<std::int32_t>(queue) ||
                    !run_hypergeometric_example<std::int32_t>(queue) ||
                    !run_negative_binomial_example<std::int32_t>(queue) ||
                    !run_multinomial_example<std::int32_t>(queue) ||
                    !run_poisson_v_example<std::int32_t>(queue)) {
                    std::cout << "FAILED" << std::endl;
                    return 1;
                }
            }
        }
        else {
#ifdef FAIL_ON_MISSING_DEVICES
            std::cout << "No " << sycl_device_names[*it]
                      << " devices found; Fail on missing devices is enabled.\n";
            std::cout << "FAILED" << std::endl;
            return 1;
#else
            std::cout << "No " << sycl_device_names[*it] << " devices found; skipping "
                      << sycl_device_names[*it] << " tests.\n";
#endif
        }
    }
    std::cout << "PASSED" << std::endl;
    return 0;
}
