/*
 * Decompiled with CFR 0.152.
 */
package org.opensearch.forecast.ml;

import com.amazon.randomcutforest.config.ForestMode;
import com.amazon.randomcutforest.config.Precision;
import com.amazon.randomcutforest.config.TransformMethod;
import com.amazon.randomcutforest.parkservices.ForecastDescriptor;
import com.amazon.randomcutforest.parkservices.RCFCaster;
import com.amazon.randomcutforest.parkservices.config.Calibration;
import java.time.Clock;
import java.time.Duration;
import java.time.Instant;
import java.util.ArrayList;
import java.util.List;
import org.apache.commons.lang3.tuple.Pair;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.apache.logging.log4j.message.Message;
import org.apache.logging.log4j.message.ParameterizedMessage;
import org.opensearch.forecast.indices.ForecastIndex;
import org.opensearch.forecast.indices.ForecastIndexManagement;
import org.opensearch.forecast.ml.RCFCasterResult;
import org.opensearch.forecast.model.ForecastResult;
import org.opensearch.forecast.model.Forecaster;
import org.opensearch.threadpool.ThreadPool;
import org.opensearch.timeseries.AnalysisType;
import org.opensearch.timeseries.NodeStateManager;
import org.opensearch.timeseries.feature.FeatureManager;
import org.opensearch.timeseries.feature.SearchFeatureDao;
import org.opensearch.timeseries.ml.ModelColdStart;
import org.opensearch.timeseries.ml.ModelState;
import org.opensearch.timeseries.ml.Sample;
import org.opensearch.timeseries.model.Config;
import org.opensearch.timeseries.util.ModelUtil;
import org.opensearch.timeseries.util.ParseUtils;

/**
 * Cold-start trainer for forecasting: builds an {@link RCFCaster} from a batch of historical
 * {@link Sample}s, replays the samples through it, and converts the per-sample output into
 * {@link ForecastResult} entries.
 *
 * <p>Fields such as {@code numberOfTrees}, {@code rcfSampleSize}, {@code thresholdMinPvalue},
 * {@code numMinSamples}, {@code initialAcceptFraction}, {@code rcfSeed} and
 * {@code resultMappingVersion} are inherited from {@link ModelColdStart}.
 */
public class ForecastColdStart
extends ModelColdStart<RCFCaster, ForecastIndex, ForecastIndexManagement, ForecastResult> {
    private static final Logger logger = LogManager.getLogger(ForecastColdStart.class);

    /**
     * Creates a forecast cold-start trainer by forwarding every argument to the
     * {@link ModelColdStart} constructor.
     *
     * <p>In addition to the forwarded arguments, this constructor fixes the thread pool name to
     * {@code "forecast-threadpool"} and the analysis type to {@link AnalysisType#FORECAST}, and
     * passes the literal {@code 1} for the superclass parameter located between
     * {@code nodeStateManager} and {@code defaultTrainSamples}.
     * NOTE(review): the meaning of that literal is defined by the superclass constructor, which
     * is not visible from this file - confirm there.
     *
     * @param clock forwarded to the superclass
     * @param threadPool forwarded to the superclass
     * @param nodeStateManager forwarded to the superclass
     * @param rcfSampleSize forwarded to the superclass
     * @param numberOfTrees forwarded to the superclass
     * @param numMinSamples forwarded to the superclass
     * @param searchFeatureDao forwarded to the superclass
     * @param thresholdMinPvalue forwarded to the superclass
     * @param featureManager forwarded to the superclass
     * @param modelTtl forwarded to the superclass
     * @param coolDownMinutes forwarded to the superclass
     * @param rcfSeed forwarded to the superclass
     * @param defaultTrainSamples forwarded to the superclass
     * @param maxRoundofColdStart forwarded to the superclass
     * @param resultSchemaVersion forwarded to the superclass
     */
    public ForecastColdStart(Clock clock, ThreadPool threadPool, NodeStateManager nodeStateManager, int rcfSampleSize, int numberOfTrees, int numMinSamples, SearchFeatureDao searchFeatureDao, double thresholdMinPvalue, FeatureManager featureManager, Duration modelTtl, int coolDownMinutes, long rcfSeed, int defaultTrainSamples, int maxRoundofColdStart, int resultSchemaVersion) {
        super(modelTtl, coolDownMinutes, clock, threadPool, numMinSamples, rcfSeed, numberOfTrees, rcfSampleSize, thresholdMinPvalue, nodeStateManager, 1, defaultTrainSamples, searchFeatureDao, featureManager, maxRoundofColdStart, "forecast-threadpool", AnalysisType.FORECAST, resultSchemaVersion);
    }

    /**
     * Trains a new {@link RCFCaster} on the given samples and returns the results produced while
     * replaying them.
     *
     * <p>On success the trained caster is stored in {@code modelState} via
     * {@link ModelState#setModel}. On every {@code null} return path {@code modelState} is left
     * untouched.
     *
     * @param pointSamples samples to train on, in the order they are fed to the model
     * @param modelState state object that receives the trained model and supplies the entity and
     *        model id written into the results
     * @param config configuration of the analysis; it is cast to {@link Forecaster} to read the
     *        forecast horizon, so it must be a {@code Forecaster} instance
     * @param taskId task id passed through to the generated results
     * @return the results generated for all samples, or {@code null} when {@code pointSamples} is
     *         {@code null} or empty, when the first sample has no values, when
     *         {@code processSequentially} throws, or when it returns a different number of
     *         descriptors than samples supplied
     */
    @Override
    protected List<ForecastResult> trainModelFromDataSegments(List<Sample> pointSamples, ModelState<RCFCaster> modelState, Config config, String taskId) {
        if (pointSamples == null || pointSamples.isEmpty()) {
            logger.info("Return early since data points must not be empty.");
            return null;
        }
        double[] firstPoint = pointSamples.get(0).getValueList();
        if (firstPoint == null || firstPoint.length == 0) {
            logger.info("Return early since data points must not be empty.");
            return null;
        }

        int shingleSize = config.getShingleSize();
        int forecastHorizon = ((Forecaster) config).getHorizon();
        // Internal shingling is enabled below, so the model dimension is features x shingle size.
        int dimensions = firstPoint.length * shingleSize;
        // The same decay value is applied to both the forest and the input transform.
        double timeDecay = config.getTimeDecay().doubleValue();

        RCFCaster.Builder casterBuilder = RCFCaster.builder()
            .dimensions(dimensions)
            .numberOfTrees(this.numberOfTrees)
            .shingleSize(shingleSize)
            .sampleSize(this.rcfSampleSize)
            .internalShinglingEnabled(true)
            .precision(Precision.FLOAT_32)
            .anomalyRate(1.0 - this.thresholdMinPvalue)
            .outputAfter(Math.max(shingleSize, this.numMinSamples))
            .calibration(Calibration.MINIMAL)
            .timeDecay(timeDecay)
            .parallelExecutionEnabled(false)
            .boundingBoxCacheFraction(0.0)
            .transformDecay(timeDecay)
            .forecastHorizon(forecastHorizon)
            .initialAcceptFraction(this.initialAcceptFraction)
            .transformMethod(TransformMethod.NORMALIZE)
            .forestMode(ForestMode.STANDARD);
        casterBuilder = ForecastColdStart.applyImputationMethod(config, casterBuilder);
        // A non-positive seed leaves the builder's own seeding in place.
        if (this.rcfSeed > 0L) {
            casterBuilder.randomSeed(this.rcfSeed);
        }
        RCFCaster caster = casterBuilder.build();

        int sampleCount = pointSamples.size();
        List<Pair<Instant, Instant>> sequentialTime = new ArrayList<>(sampleCount);
        // Only the outer array is allocated: each row is the sample's own value array.
        double[][] sequentialData = new double[sampleCount][];
        long[] timestamps = new long[sampleCount];
        for (int i = 0; i < sampleCount; ++i) {
            Sample dataSample = pointSamples.get(i);
            sequentialData[i] = dataSample.getValueList();
            timestamps[i] = dataSample.getDataEndTime().getEpochSecond();
            sequentialTime.add(Pair.of(dataSample.getDataStartTime(), dataSample.getDataEndTime()));
        }

        // The declared element type of the returned list is not visible from this file, so it is
        // held as a wildcard list and each element is cast where it is consumed.
        List<?> descriptors;
        try {
            descriptors = caster.processSequentially(sequentialData, timestamps, x -> true);
        }
        catch (Exception e) {
            logger.error("Error while running processSequentially", e);
            return null;
        }
        // Results are matched to samples by index, so a size mismatch cannot be mapped back.
        if (descriptors.size() != sequentialTime.size()) {
            logger.warn(new ParameterizedMessage("processSequentially returns different size than expected, got [{}], expecting [{}].", descriptors.size(), sequentialTime.size()));
            return null;
        }

        List<ForecastResult> res = new ArrayList<>();
        // One timestamp is shared by all results of this run; it is passed for two consecutive
        // time arguments of toIndexableResults.
        Instant now = Instant.now();
        for (int i = 0; i < descriptors.size(); ++i) {
            Pair<Instant, Instant> time = sequentialTime.get(i);
            ForecastDescriptor descriptor = (ForecastDescriptor) descriptors.get(i);
            double[] dataValue = sequentialData[i];
            RCFCasterResult casterResult = ModelUtil.toResult(caster.getForest(), descriptor, dataValue, false);
            List<ForecastResult> resultI = casterResult.toIndexableResults(config, time.getLeft(), time.getRight(), now, now, ParseUtils.getFeatureData(dataValue, config), modelState.getEntity(), this.resultMappingVersion, modelState.getModelId(), taskId, null);
            res.addAll(resultI);
        }
        // Publish the model only after every sample has been converted successfully.
        modelState.setModel(caster);
        return res;
    }
}

