/*
 * Decompiled with CFR 0.152.
 */
package ciir.umass.edu.learning.tree;

import ciir.umass.edu.learning.DataPoint;
import ciir.umass.edu.learning.RANKER_TYPE;
import ciir.umass.edu.learning.RankList;
import ciir.umass.edu.learning.Ranker;
import ciir.umass.edu.learning.RankerFactory;
import ciir.umass.edu.learning.Sampler;
import ciir.umass.edu.learning.tree.Ensemble;
import ciir.umass.edu.learning.tree.FeatureHistogram;
import ciir.umass.edu.learning.tree.LambdaMART;
import ciir.umass.edu.metric.MetricScorer;
import ciir.umass.edu.parsing.ModelLineProducer;
import ciir.umass.edu.utilities.MergeSorter;
import ciir.umass.edu.utilities.RankLibError;
import ciir.umass.edu.utilities.SimpleMath;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;

/**
 * Random Forests ranker.
 *
 * <p>Trains {@link #nBag} tree ensembles, each on a bag of rank lists drawn by
 * {@link Sampler#doSampling} from the training data. Each bag is learned by a ranker created
 * through {@link RankerFactory} for {@link #rType}, which must be a {@link LambdaMART}
 * (or subclass) because {@link #learn()} casts to it. A data point is scored by averaging the
 * outputs of all ensembles.
 *
 * <p>NOTE(review): every hyper-parameter is a mutable {@code static} shared by all instances,
 * and {@link #init()} copies them into {@link LambdaMART} / {@link FeatureHistogram} statics.
 * Training several differently configured rankers concurrently in one JVM is therefore
 * not safe.
 */
public class RFRanker extends Ranker {
    /** Number of bags, i.e. independently trained ensembles. */
    public static int nBag = 300;
    /** Sampling rate passed to {@link Sampler#doSampling} when drawing each bag. */
    public static float subSamplingRate = 1.0f;
    /** Copied into {@code FeatureHistogram.samplingRate} by {@link #init()}. */
    public static float featureSamplingRate = 0.3f;
    /** Ranker type created for each bag; must yield a {@link LambdaMART} instance. */
    public static RANKER_TYPE rType = RANKER_TYPE.MART;
    /** Trees per bag; copied into {@code LambdaMART.nTrees} by {@link #init()}. */
    public static int nTrees = 1;
    /** Leaves per tree; copied into {@code LambdaMART.nTreeLeaves} by {@link #init()}. */
    public static int nTreeLeaves = 100;
    /** Copied into {@code LambdaMART.learningRate} by {@link #init()}. */
    public static float learningRate = 0.1f;
    /** Threshold candidates; copied into {@code LambdaMART.nThreshold} by {@link #init()}. */
    public static int nThreshold = 256;
    /** Copied into {@code LambdaMART.minLeafSupport} by {@link #init()}. */
    public static int minLeafSupport = 1;
    /** One ensemble per bag; filled by {@link #learn()} or {@link #loadFromString(String)}. */
    protected Ensemble[] ensembles = null;

    public RFRanker() {
    }

    public RFRanker(List<RankList> samples, int[] features, MetricScorer scorer) {
        super(samples, features, scorer);
    }

    /**
     * Allocates one ensemble slot per bag and pushes this class's tree hyper-parameters into
     * the static configuration of {@link LambdaMART} and {@link FeatureHistogram}.
     */
    @Override
    public void init() {
        this.PRINT("Initializing... ");
        this.ensembles = new Ensemble[nBag];
        LambdaMART.nTrees = nTrees;
        LambdaMART.nTreeLeaves = nTreeLeaves;
        LambdaMART.learningRate = learningRate;
        LambdaMART.nThreshold = nThreshold;
        LambdaMART.minLeafSupport = minLeafSupport;
        // -1 presumably disables LambdaMART's validation-based early stopping; confirm in LambdaMART.
        LambdaMART.nRoundToStopEarly = -1;
        FeatureHistogram.samplingRate = featureSamplingRate;
        this.PRINTLN("[Done]");
    }

    /**
     * Trains one ensemble per bag.
     *
     * <p>Sums each bag's per-feature impacts and stores every bag's ensemble. It then reports
     * the metric on the training data (and on the validation data if present), followed by
     * the features sorted by their summed impact.
     */
    @Override
    public void learn() {
        RankerFactory rf = new RankerFactory();
        this.PRINTLN("------------------------------------");
        this.PRINTLN("Training starts...");
        this.PRINTLN("------------------------------------");
        this.PRINTLN(new int[]{9, 9, 11}, new String[]{"bag", this.scorer.name() + "-B", this.scorer.name() + "-OOB"});
        this.PRINTLN("------------------------------------");
        double[] impacts = null;
        for (int i = 0; i < nBag; ++i) {
            if (i % LambdaMART.gcCycle == 0) {
                System.gc();
            }
            Sampler sp = new Sampler();
            List<RankList> bag = sp.doSampling(this.samples, subSamplingRate, true);
            LambdaMART r = (LambdaMART) rf.createRanker(rType, bag, this.features, this.scorer);
            // Silence the per-bag learner. Restore the caller's setting in `finally` so a failing
            // bag does not leave logging globally disabled.
            boolean savedVerbose = Ranker.verbose;
            Ranker.verbose = false;
            try {
                r.init();
                r.learn();
            } finally {
                Ranker.verbose = savedVerbose;
            }
            impacts = accumulateImpacts(impacts, r.impacts);
            this.PRINTLN(new int[]{9, 9}, new String[]{"b[" + (i + 1) + "]", SimpleMath.round(r.getScoreOnTrainingData(), 4) + ""});
            this.ensembles[i] = r.getEnsemble();
        }
        this.scoreOnTrainingData = this.scorer.score(this.rank(this.samples));
        this.PRINTLN("------------------------------------");
        this.PRINTLN("Finished sucessfully.");
        this.PRINTLN(this.scorer.name() + " on training data: " + SimpleMath.round(this.scoreOnTrainingData, 4));
        if (this.validationSamples != null) {
            this.bestScoreOnValidationData = this.scorer.score(this.rank(this.validationSamples));
            this.PRINTLN(this.scorer.name() + " on validation data: " + SimpleMath.round(this.bestScoreOnValidationData, 4));
        }
        this.PRINTLN("------------------------------------");
        this.PRINTLN("-- FEATURE IMPACTS");
        // impacts stays null when no bag was trained (nBag == 0) or no bag reported impacts.
        if (impacts != null) {
            int[] ftrsSorted = MergeSorter.sort(impacts, false);
            for (int ftr : ftrsSorted) {
                this.PRINTLN(" Feature " + this.features[ftr] + " reduced error " + impacts[ftr]);
            }
        }
        this.PRINTLN("");
    }

    /**
     * Adds one bag's per-feature impacts element-wise into the running total.
     *
     * @param total      running total, or {@code null} before the first contribution
     * @param bagImpacts impacts reported by the bag's ranker, possibly {@code null}
     * @return the updated total; on the first contribution this is a copy of
     *         {@code bagImpacts}, so the bag ranker's own array is never mutated
     */
    private static double[] accumulateImpacts(double[] total, double[] bagImpacts) {
        if (bagImpacts == null) {
            return total;
        }
        if (total == null) {
            return bagImpacts.clone();
        }
        for (int ftr = 0; ftr < total.length; ++ftr) {
            total[ftr] += bagImpacts[ftr];
        }
        return total;
    }

    /**
     * Scores a data point as the mean of all ensembles' outputs.
     *
     * <p>Yields NaN (0.0 / 0.0) when there are no ensembles.
     */
    @Override
    public double eval(DataPoint dp) {
        double s = 0.0;
        for (Ensemble e : this.ensembles) {
            s += (double) e.eval(dp);
        }
        return s / (double) this.ensembles.length;
    }

    @Override
    public Ranker createNew() {
        return new RFRanker();
    }

    /**
     * Serializes every ensemble, one per line.
     *
     * <p>Iterates the ensembles actually held rather than the static {@link #nBag}, because a
     * model loaded by {@link #loadFromString(String)} may contain a different number of bags.
     * Returns an empty string if the ranker has neither been trained nor loaded.
     */
    @Override
    public String toString() {
        if (this.ensembles == null) {
            return "";
        }
        StringBuilder sb = new StringBuilder();
        for (Ensemble e : this.ensembles) {
            sb.append(e.toString()).append('\n');
        }
        return sb.toString();
    }

    /** Returns the model text: a {@code ##} header with the hyper-parameters, then all ensembles. */
    @Override
    public String model() {
        StringBuilder output = new StringBuilder();
        output.append("## ").append(this.name()).append("\n");
        output.append("## No. of bags = ").append(nBag).append("\n");
        output.append("## Sub-sampling = ").append(subSamplingRate).append("\n");
        output.append("## Feature-sampling = ").append(featureSamplingRate).append("\n");
        output.append("## No. of trees = ").append(nTrees).append("\n");
        output.append("## No. of leaves = ").append(nTreeLeaves).append("\n");
        output.append("## No. of threshold candidates = ").append(nThreshold).append("\n");
        output.append("## Learning rate = ").append(learningRate).append("\n");
        output.append("\n");
        output.append(this.toString());
        return output.toString();
    }

    /**
     * Rebuilds the ranker from model text.
     *
     * <p>Each accumulated chunk ending in {@code </ensemble>} is parsed as one {@link Ensemble}.
     * {@link #features} is rebuilt as the set of feature ids used by any ensemble.
     *
     * @throws RuntimeException from {@code RankLibError.create}, wrapping any parsing failure
     */
    @Override
    public void loadFromString(String fullText) {
        try {
            List<Ensemble> ens = new ArrayList<>();
            ModelLineProducer lineByLine = new ModelLineProducer();
            lineByLine.parse(fullText, (model, maybeEndEns) -> {
                if (maybeEndEns) {
                    String modelAsStr = model.toString();
                    if (modelAsStr.endsWith("</ensemble>")) {
                        ens.add(new Ensemble(modelAsStr));
                        model.setLength(0);
                    }
                }
            });
            HashSet<Integer> uniqueFeatures = new HashSet<>();
            this.ensembles = ens.toArray(new Ensemble[0]);
            for (Ensemble e : this.ensembles) {
                for (int fid : e.getFeatures()) {
                    uniqueFeatures.add(fid);
                }
            }
            int fi = 0;
            this.features = new int[uniqueFeatures.size()];
            for (int f : uniqueFeatures) {
                this.features[fi++] = f;
            }
            // NOTE(review): unconditional stdout print, presumably leftover debug output; consider PRINTLN.
            System.out.println("Other Loading Done");
        }
        catch (Exception ex) {
            throw RankLibError.create("Error in RFRanker::load(): ", ex);
        }
    }

    @Override
    public void printParameters() {
        this.PRINTLN("No. of bags: " + nBag);
        this.PRINTLN("Sub-sampling: " + subSamplingRate);
        this.PRINTLN("Feature-sampling: " + featureSamplingRate);
        this.PRINTLN("No. of trees: " + nTrees);
        this.PRINTLN("No. of leaves: " + nTreeLeaves);
        this.PRINTLN("No. of threshold candidates: " + nThreshold);
        this.PRINTLN("Learning rate: " + learningRate);
    }

    @Override
    public String name() {
        return "Random Forests";
    }

    /** Returns the internal ensemble array (not a copy). */
    public Ensemble[] getEnsembles() {
        return this.ensembles;
    }
}

