/*
 * Decompiled with CFR 0.152.
 */
package org.tribuo.classification.evaluation;

import java.util.logging.Logger;
import org.tribuo.Output;
import org.tribuo.classification.Classifiable;
import org.tribuo.classification.evaluation.ConfusionMatrix;
import org.tribuo.evaluation.metrics.EvaluationMetric;
import org.tribuo.evaluation.metrics.MetricTarget;

/**
 * Static functions for computing classification metrics from a {@link ConfusionMatrix}.
 * <p>
 * Most metrics are expressed as a function of the four confusion counts
 * (true positives, false positives, true negatives, false negatives) and are
 * dispatched through the private {@link ConfusionFunction} interface, which
 * handles per-label, macro-averaged and micro-averaged targets.
 * <p>
 * This class is not instantiable.
 */
public final class ConfusionMetrics {
    private static final Logger logger = Logger.getLogger(ConfusionMetrics.class.getName());

    /** Not instantiable: this class only exposes static methods. */
    private ConfusionMetrics() {
    }

    /**
     * Computes the accuracy for the supplied target, which is either a single label
     * or an averaging mode.
     *
     * @param target the metric target; its output target is used when present,
     *               otherwise its average target is used
     * @param cm     the confusion matrix to read counts from
     * @param <T>    the classification output type
     * @return the accuracy, or {@link Double#NaN} when it is ill-defined
     * @throws IllegalStateException if the target carries neither an output nor an average
     */
    public static <T extends Classifiable<T>> double accuracy(MetricTarget<T> target, ConfusionMatrix<T> cm) {
        if (target.getOutputTarget().isPresent()) {
            return accuracy(target.getOutputTarget().get(), cm);
        }
        // Same failure mode as ConfusionFunction.compute(MetricTarget, ConfusionMatrix),
        // rather than an unchecked Optional.get().
        EvaluationMetric.Average average = target.getAverageTarget()
                .orElseThrow(() -> new IllegalStateException("MetricTarget with no actual target"));
        return accuracy(average, cm);
    }

    /**
     * Computes the accuracy for a single label as {@code cm.tp(label) / cm.support(label)}.
     * <p>
     * Logs a warning and returns {@link Double#NaN} when the label's support is zero.
     *
     * @param label the label to evaluate
     * @param cm    the confusion matrix to read counts from
     * @param <T>   the classification output type
     * @return the per-label accuracy, or {@link Double#NaN} if the support is zero
     */
    public static <T extends Classifiable<T>> double accuracy(T label, ConfusionMatrix<T> cm) {
        double support = cm.support(label);
        if (support == 0.0) {
            logger.warning("No predictions for " + label + ": accuracy ill-defined");
            return Double.NaN;
        }
        return cm.tp(label) / support;
    }

    /**
     * Computes the averaged accuracy.
     * <p>
     * For {@link EvaluationMetric.Average#MICRO} this is {@code cm.tp() / cm.support()}.
     * Any other average is treated as a macro average: the unweighted mean of the
     * per-label accuracies over the matrix's domain. Because the per-label accuracy
     * is {@link Double#NaN} for a label with zero support, a single such label makes
     * the macro average {@link Double#NaN} as well.
     *
     * @param average the averaging mode
     * @param cm      the confusion matrix to read counts from
     * @param <T>     the classification output type
     * @return the averaged accuracy, or {@link Double#NaN} when the total support is
     *         zero (micro) or the domain is empty (macro)
     */
    public static <T extends Classifiable<T>> double accuracy(EvaluationMetric.Average average, ConfusionMatrix<T> cm) {
        if (average.equals(EvaluationMetric.Average.MICRO)) {
            double support = cm.support();
            if (support == 0.0) {
                logger.warning("No predictions: accuracy ill-defined");
                return Double.NaN;
            }
            return cm.tp() / support;
        }
        int numLabels = cm.getDomain().size();
        if (numLabels == 0) {
            logger.warning("Empty domain: accuracy ill-defined");
            return Double.NaN;
        }
        double total = 0.0;
        for (T output : cm.getDomain().getDomain()) {
            total += accuracy(output, cm);
        }
        return total / (double) numLabels;
    }

    /**
     * Computes the balanced error rate: one minus the unweighted mean of the
     * per-label recalls over the matrix's domain.
     *
     * @param cm  the confusion matrix to read counts from
     * @param <T> the classification output type
     * @return the balanced error rate, or {@link Double#NaN} if the domain is empty
     */
    public static <T extends Classifiable<T>> double balancedErrorRate(ConfusionMatrix<T> cm) {
        int numLabels = cm.getDomain().size();
        if (numLabels == 0) {
            logger.warning("Empty domain: balanced error rate ill-defined");
            return Double.NaN;
        }
        double recallSum = 0.0;
        for (T output : cm.getDomain().getDomain()) {
            recallSum += recall(new MetricTarget<>(output), cm);
        }
        return 1.0 - recallSum / (double) numLabels;
    }

    /**
     * Applies a confusion-count function to the supplied target and matrix.
     *
     * @param fxn the function of (tp, fp, tn, fn) to evaluate
     * @param tgt the metric target (a label or an averaging mode)
     * @param cm  the confusion matrix to read counts from
     * @param <T> the classification output type
     * @return the value of {@code fxn} for the target
     */
    private static <T extends Classifiable<T>> double compute(ConfusionFunction<T> fxn, MetricTarget<T> tgt, ConfusionMatrix<T> cm) {
        return fxn.compute(tgt, cm);
    }

    /**
     * Returns the true positive count for the supplied target.
     *
     * @param tgt the metric target (a label or an averaging mode)
     * @param cm  the confusion matrix to read counts from
     * @param <T> the classification output type
     * @return the true positive count; see {@link ConfusionFunction} for averaging behaviour
     */
    public static <T extends Classifiable<T>> double tp(MetricTarget<T> tgt, ConfusionMatrix<T> cm) {
        return compute(ConfusionMetrics::tp, tgt, cm);
    }

    /**
     * Returns the false positive count for the supplied target.
     *
     * @param tgt the metric target (a label or an averaging mode)
     * @param cm  the confusion matrix to read counts from
     * @param <T> the classification output type
     * @return the false positive count; see {@link ConfusionFunction} for averaging behaviour
     */
    public static <T extends Classifiable<T>> double fp(MetricTarget<T> tgt, ConfusionMatrix<T> cm) {
        return compute(ConfusionMetrics::fp, tgt, cm);
    }

    /**
     * Returns the true negative count for the supplied target.
     *
     * @param tgt the metric target (a label or an averaging mode)
     * @param cm  the confusion matrix to read counts from
     * @param <T> the classification output type
     * @return the true negative count; see {@link ConfusionFunction} for averaging behaviour
     */
    public static <T extends Classifiable<T>> double tn(MetricTarget<T> tgt, ConfusionMatrix<T> cm) {
        return compute(ConfusionMetrics::tn, tgt, cm);
    }

    /**
     * Returns the false negative count for the supplied target.
     *
     * @param tgt the metric target (a label or an averaging mode)
     * @param cm  the confusion matrix to read counts from
     * @param <T> the classification output type
     * @return the false negative count; see {@link ConfusionFunction} for averaging behaviour
     */
    public static <T extends Classifiable<T>> double fn(MetricTarget<T> tgt, ConfusionMatrix<T> cm) {
        return compute(ConfusionMetrics::fn, tgt, cm);
    }

    /** Selector returning the {@code tp} argument; used as a {@link ConfusionFunction} via method reference. */
    private static double tp(double tp, double fp, double tn, double fn) {
        return tp;
    }

    /** Selector returning the {@code fp} argument; used as a {@link ConfusionFunction} via method reference. */
    private static double fp(double tp, double fp, double tn, double fn) {
        return fp;
    }

    /** Selector returning the {@code tn} argument; used as a {@link ConfusionFunction} via method reference. */
    private static double tn(double tp, double fp, double tn, double fn) {
        return tn;
    }

    /** Selector returning the {@code fn} argument; used as a {@link ConfusionFunction} via method reference. */
    private static double fn(double tp, double fp, double tn, double fn) {
        return fn;
    }

    /**
     * Computes the precision for the supplied target.
     *
     * @param tgt the metric target (a label or an averaging mode)
     * @param cm  the confusion matrix to read counts from
     * @param <T> the classification output type
     * @return the precision; see {@link #precision(double, double, double, double)}
     */
    public static <T extends Classifiable<T>> double precision(MetricTarget<T> tgt, ConfusionMatrix<T> cm) {
        return compute(ConfusionMetrics::precision, tgt, cm);
    }

    /**
     * Computes precision as {@code tp / (tp + fp)}.
     *
     * @param tp the true positive count
     * @param fp the false positive count
     * @param tn the true negative count (unused)
     * @param fn the false negative count (unused)
     * @return the precision, or {@code 0.0} when {@code tp + fp} is zero
     */
    public static double precision(double tp, double fp, double tn, double fn) {
        double denom = tp + fp;
        return denom == 0.0 ? 0.0 : tp / denom;
    }

    /**
     * Computes the recall for the supplied target.
     *
     * @param tgt the metric target (a label or an averaging mode)
     * @param cm  the confusion matrix to read counts from
     * @param <T> the classification output type
     * @return the recall; see {@link #recall(double, double, double, double)}
     */
    public static <T extends Classifiable<T>> double recall(MetricTarget<T> tgt, ConfusionMatrix<T> cm) {
        return compute(ConfusionMetrics::recall, tgt, cm);
    }

    /**
     * Computes recall as {@code tp / (tp + fn)}.
     *
     * @param tp the true positive count
     * @param fp the false positive count (unused)
     * @param tn the true negative count (unused)
     * @param fn the false negative count
     * @return the recall, or {@code 0.0} when {@code tp + fn} is zero
     */
    public static double recall(double tp, double fp, double tn, double fn) {
        double denom = tp + fn;
        return denom == 0.0 ? 0.0 : tp / denom;
    }

    /**
     * Computes the F1 score for the supplied target.
     *
     * @param tgt the metric target (a label or an averaging mode)
     * @param cm  the confusion matrix to read counts from
     * @param <T> the classification output type
     * @return the F1 score; see {@link #f1(double, double, double, double)}
     */
    public static <T extends Classifiable<T>> double f1(MetricTarget<T> tgt, ConfusionMatrix<T> cm) {
        return compute(ConfusionMetrics::f1, tgt, cm);
    }

    /**
     * Computes the F1 score, i.e. the F-beta score with {@code beta = 1}.
     *
     * @param tp the true positive count
     * @param fp the false positive count
     * @param tn the true negative count
     * @param fn the false negative count
     * @return the F1 score
     */
    public static double f1(double tp, double fp, double tn, double fn) {
        return fscore(1.0, tp, fp, tn, fn);
    }

    /**
     * Computes the F-beta score
     * {@code (1 + beta^2) * p * r / (beta^2 * p + r)}, where {@code p} and {@code r}
     * are the precision and recall of the supplied counts.
     *
     * @param beta the beta weighting
     * @param tp   the true positive count
     * @param fp   the false positive count
     * @param tn   the true negative count
     * @param fn   the false negative count
     * @return the F-beta score, or {@code 0.0} when the denominator is zero
     */
    public static double fscore(double beta, double tp, double fp, double tn, double fn) {
        double bsq = beta * beta;
        double p = precision(tp, fp, tn, fn);
        double r = recall(tp, fp, tn, fn);
        double denom = bsq * p + r;
        return denom == 0.0 ? 0.0 : (1.0 + bsq) * p * r / denom;
    }

    /**
     * Computes the F-beta score for the supplied target.
     *
     * @param tgt  the metric target (a label or an averaging mode)
     * @param cm   the confusion matrix to read counts from
     * @param beta the beta weighting
     * @param <T>  the classification output type
     * @return the F-beta score; see {@link #fscore(double, double, double, double, double)}
     */
    public static <T extends Classifiable<T>> double fscore(MetricTarget<T> tgt, ConfusionMatrix<T> cm, double beta) {
        ConfusionFunction<T> fxn = (tp, fp, tn, fn) -> fscore(beta, tp, fp, tn, fn);
        return compute(fxn, tgt, cm);
    }

    /**
     * A function of the four confusion counts, plus default methods that evaluate it
     * for a single label, a macro average or a micro average.
     *
     * @param <T> the classification output type
     */
    @FunctionalInterface
    private interface ConfusionFunction<T extends Classifiable<T>> {
        /**
         * Evaluates the function on raw confusion counts.
         *
         * @param tp the true positive count
         * @param fp the false positive count
         * @param tn the true negative count
         * @param fn the false negative count
         * @return the function value
         */
        double compute(double tp, double fp, double tn, double fn);

        /**
         * Evaluates the function for a metric target, dispatching on whether it
         * carries an output label or an averaging mode.
         *
         * @param tgt the metric target
         * @param cm  the confusion matrix to read counts from
         * @return the function value for the target
         * @throws IllegalStateException if the target carries neither an output nor an average
         */
        default double compute(MetricTarget<T> tgt, ConfusionMatrix<T> cm) {
            if (tgt.getOutputTarget().isPresent()) {
                return compute(tgt.getOutputTarget().get(), cm);
            }
            if (tgt.getAverageTarget().isPresent()) {
                return compute(tgt.getAverageTarget().get(), cm);
            }
            throw new IllegalStateException("MetricTarget with no actual target");
        }

        /**
         * Evaluates the function on the confusion counts of a single label.
         *
         * @param label the label to evaluate
         * @param cm    the confusion matrix to read counts from
         * @return the function value for the label
         */
        default double compute(T label, ConfusionMatrix<T> cm) {
            return compute(cm.tp(label), cm.fp(label), cm.tn(label), cm.fn(label));
        }

        /**
         * Evaluates the function under the given averaging mode.
         * <p>
         * {@code MACRO} is the unweighted mean of the per-label values over the
         * matrix's domain; {@code MICRO} applies the function to the matrix-wide
         * counts {@code cm.tp()}, {@code cm.fp()}, {@code cm.tn()}, {@code cm.fn()}.
         *
         * @param average the averaging mode
         * @param cm      the confusion matrix to read counts from
         * @return the averaged value, or {@link Double#NaN} when the domain is empty
         *         (macro) or the total support is zero (micro)
         * @throws IllegalArgumentException if the averaging mode is neither MACRO nor MICRO
         */
        default double compute(EvaluationMetric.Average average, ConfusionMatrix<T> cm) {
            switch (average) {
                case MACRO: {
                    int numLabels = cm.getDomain().size();
                    if (numLabels == 0) {
                        logger.warning("Empty domain: macro-average ill-defined.");
                        return Double.NaN;
                    }
                    double total = 0.0;
                    for (T output : cm.getDomain().getDomain()) {
                        total += compute(output, cm);
                    }
                    return total / (double) numLabels;
                }
                case MICRO: {
                    if (cm.support() == 0.0) {
                        logger.warning("No predictions: micro-average ill-defined.");
                        return Double.NaN;
                    }
                    return compute(cm.tp(), cm.fp(), cm.tn(), cm.fn());
                }
                default:
                    throw new IllegalArgumentException("Unsupported average type: " + average.name());
            }
        }
    }
}

