/*
 * Decompiled with CFR 0.152.
 */
package org.opensearch.neuralsearch.processor;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.stream.Collectors;
import lombok.Generated;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.apache.lucene.search.FieldDoc;
import org.apache.lucene.search.ScoreDoc;
import org.apache.lucene.search.Sort;
import org.apache.lucene.search.TopDocs;
import org.apache.lucene.search.TopFieldDocs;
import org.apache.lucene.search.grouping.CollapseTopFieldDocs;
import org.opensearch.action.search.SearchPhaseContext;
import org.opensearch.common.lucene.search.TopDocsAndMaxScore;
import org.opensearch.neuralsearch.processor.CompoundTopDocs;
import org.opensearch.neuralsearch.processor.NormalizationProcessorWorkflowExecuteRequest;
import org.opensearch.neuralsearch.processor.NormalizationProcessorWorkflowUtil;
import org.opensearch.neuralsearch.processor.NormalizeScoresDTO;
import org.opensearch.neuralsearch.processor.SearchShard;
import org.opensearch.neuralsearch.processor.collapse.CollapseDTO;
import org.opensearch.neuralsearch.processor.collapse.CollapseExecutor;
import org.opensearch.neuralsearch.processor.combination.CombineScoresDto;
import org.opensearch.neuralsearch.processor.combination.ScoreCombiner;
import org.opensearch.neuralsearch.processor.explain.CombinedExplanationDetails;
import org.opensearch.neuralsearch.processor.explain.DocIdAtSearchShard;
import org.opensearch.neuralsearch.processor.explain.ExplainableTechnique;
import org.opensearch.neuralsearch.processor.explain.ExplanationDetails;
import org.opensearch.neuralsearch.processor.explain.ExplanationPayload;
import org.opensearch.neuralsearch.processor.normalization.ScoreNormalizer;
import org.opensearch.neuralsearch.search.util.HybridSearchSortUtil;
import org.opensearch.search.SearchHit;
import org.opensearch.search.SearchHits;
import org.opensearch.search.fetch.FetchSearchResult;
import org.opensearch.search.pipeline.PipelineProcessingContext;
import org.opensearch.search.query.QuerySearchResult;

/**
 * Orchestrates hybrid-query post-processing: score normalization, score combination, and writing
 * the combined scores/order back into the original query results and, when a fetch result is
 * present, into the fetched search hits.
 */
public class NormalizationProcessorWorkflow {
    @Generated
    private static final Logger log = LogManager.getLogger(NormalizationProcessorWorkflow.class);
    private final ScoreNormalizer scoreNormalizer;
    private final ScoreCombiner scoreCombiner;

    /**
     * Runs normalization and combination for the given request and updates the request's query
     * results (via {@code QuerySearchResult#topDocs}) and, if present, its fetch result (via
     * {@code FetchSearchResult#hits}).
     *
     * @param request query results, optional fetch result, techniques and search phase context
     * @throws IllegalStateException if query results and top docs are inconsistent, or the fetch
     *         result cannot be reconciled with the query results
     * @throws IllegalArgumentException if the pagination offset exceeds the number of combined results
     */
    public void execute(final NormalizationProcessorWorkflowExecuteRequest request) {
        final List<QuerySearchResult> querySearchResults = request.getQuerySearchResults();
        final Optional<FetchSearchResult> fetchSearchResultOptional = request.getFetchSearchResultOptional();
        // Doc ids in their original (pre-combination) order; fetched hits are matched to doc ids by
        // position against this list.
        final List<Integer> unprocessedDocIds = unprocessedDocIds(querySearchResults);

        log.debug("Pre-process query results");
        final List<CompoundTopDocs> queryTopDocs = getQueryTopDocs(querySearchResults);

        // Explanations are collected from queryTopDocs before normalization/combination run.
        explain(request, queryTopDocs);

        final NormalizeScoresDTO normalizeScoresDTO = NormalizeScoresDTO.builder()
            .queryTopDocs(queryTopDocs)
            .normalizationTechnique(request.getNormalizationTechnique())
            .build();
        log.debug("Do score normalization");
        scoreNormalizer.normalizeScores(normalizeScoresDTO);

        final CombineScoresDto combineScoresDTO = CombineScoresDto.builder()
            .queryTopDocs(queryTopDocs)
            .scoreCombinationTechnique(request.getCombinationTechnique())
            .querySearchResults(querySearchResults)
            .sort(HybridSearchSortUtil.evaluateSortCriteria(querySearchResults, queryTopDocs))
            .fromValueForSingleShard(getFromValueIfSingleShard(request))
            .isSingleShard(getIsSingleShard(request))
            .build();
        log.debug("Do score combination");
        scoreCombiner.combineScores(combineScoresDTO);

        log.debug("Post-process query results after score normalization and combination");
        updateOriginalQueryResults(combineScoresDTO, fetchSearchResultOptional.isPresent());
        updateOriginalFetchResults(
            querySearchResults,
            fetchSearchResultOptional,
            unprocessedDocIds,
            combineScoresDTO.getFromValueForSingleShard()
        );
    }

    /**
     * Returns {@code true} when the search targets exactly one shard or a fetch result is already
     * available to this workflow.
     */
    private boolean getIsSingleShard(final NormalizationProcessorWorkflowExecuteRequest request) {
        final SearchPhaseContext searchPhaseContext = request.getSearchPhaseContext();
        // Use the getter, consistent with execute(), instead of reading the request's field directly.
        return searchPhaseContext.getNumShards() == 1 || request.getFetchSearchResultOptional().isPresent();
    }

    /**
     * Returns the request's {@code from} offset in the single-shard case, or {@code -1} otherwise.
     * A {@code from} of {@code -1} on the request is reported as {@code 0}.
     */
    private int getFromValueIfSingleShard(final NormalizationProcessorWorkflowExecuteRequest request) {
        if (!getIsSingleShard(request)) {
            return -1;
        }
        final int from = request.getSearchPhaseContext().getRequest().source().from();
        return from == -1 ? 0 : from;
    }

    /**
     * If the request asks for explanations and carries a pipeline processing context, pairs each
     * document's combination explanation with its normalization explanation (grouped by shard) and
     * stores the result in the context under the {@code "explanation_response"} attribute.
     */
    private void explain(final NormalizationProcessorWorkflowExecuteRequest request, final List<CompoundTopDocs> queryTopDocs) {
        if (!request.isExplain() || Objects.isNull(request.getPipelineProcessingContext())) {
            return;
        }
        final Sort sortForQuery = HybridSearchSortUtil.evaluateSortCriteria(request.getQuerySearchResults(), queryTopDocs);
        // NOTE(review): this cast assumes the configured normalization technique is explainable.
        final Map<DocIdAtSearchShard, ExplanationDetails> normalizationExplain = scoreNormalizer.explain(
            queryTopDocs,
            (ExplainableTechnique) request.getNormalizationTechnique()
        );
        final Map<SearchShard, List<ExplanationDetails>> combinationExplain = scoreCombiner.explain(
            queryTopDocs,
            request.getCombinationTechnique(),
            sortForQuery
        );

        final Map<SearchShard, List<CombinedExplanationDetails>> combinedExplanations = new HashMap<>();
        for (Map.Entry<SearchShard, List<ExplanationDetails>> entry : combinationExplain.entrySet()) {
            final SearchShard searchShard = entry.getKey();
            final List<CombinedExplanationDetails> combinedDetailsList = new ArrayList<>();
            for (ExplanationDetails explainDetail : entry.getValue()) {
                final DocIdAtSearchShard docIdAtSearchShard = new DocIdAtSearchShard(explainDetail.getDocId(), searchShard);
                combinedDetailsList.add(
                    CombinedExplanationDetails.builder()
                        .normalizationExplanations(normalizationExplain.get(docIdAtSearchShard))
                        .combinationExplanations(explainDetail)
                        .build()
                );
            }
            combinedExplanations.put(searchShard, combinedDetailsList);
        }

        final ExplanationPayload explanationPayload = ExplanationPayload.builder()
            .explainPayload(Map.of(ExplanationPayload.PayloadType.NORMALIZATION_PROCESSOR, combinedExplanations))
            .build();
        final PipelineProcessingContext pipelineProcessingContext = request.getPipelineProcessingContext();
        pipelineProcessingContext.setAttribute("explanation_response", explanationPayload);
    }

    /**
     * Wraps each query result's top docs in a {@link CompoundTopDocs}.
     *
     * @throws IllegalStateException if any query result has null top docs
     */
    private List<CompoundTopDocs> getQueryTopDocs(final List<QuerySearchResult> querySearchResults) {
        final List<CompoundTopDocs> queryTopDocs = querySearchResults.stream()
            .filter(searchResult -> Objects.nonNull(searchResult.topDocs()))
            .map(CompoundTopDocs::new)
            .collect(Collectors.toList());
        validateTopDocsSize(querySearchResults.size(), queryTopDocs.size());
        return queryTopDocs;
    }

    /**
     * Shared consistency check: there must be exactly one compound top docs per query result.
     *
     * @throws IllegalStateException if the two sizes differ
     */
    private static void validateTopDocsSize(final int querySearchResultsSize, final int queryTopDocsSize) {
        if (querySearchResultsSize != queryTopDocsSize) {
            throw new IllegalStateException(
                String.format(
                    Locale.ROOT,
                    "query results were not formatted correctly by the hybrid query; sizes of querySearchResults [%d] and queryTopDocs [%d] must match",
                    querySearchResultsSize,
                    queryTopDocsSize
                )
            );
        }
    }

    /**
     * Writes the combined top docs back into each query result. Delegates to
     * {@link CollapseExecutor} when any result's first top docs is a {@link CollapseTopFieldDocs}.
     *
     * @param combineScoresDTO combination input/output, including the query results and sort
     * @param isFetchPhaseExecuted whether a fetch result is present; if so each query result's
     *        {@code from} is set to the single-shard offset
     * @throws IllegalArgumentException if the first result's {@code from} exceeds the total number
     *         of combined score docs
     */
    private void updateOriginalQueryResults(final CombineScoresDto combineScoresDTO, final boolean isFetchPhaseExecuted) {
        final List<QuerySearchResult> querySearchResults = combineScoresDTO.getQuerySearchResults();
        final List<CompoundTopDocs> queryTopDocs = getCompoundTopDocs(combineScoresDTO, querySearchResults);
        if (querySearchResults.isEmpty()) {
            // Nothing to update, and no first result to read the pagination offset from.
            return;
        }
        final Sort sort = combineScoresDTO.getSort();
        final int collapseIndex = indexOfFirstCollapseResult(queryTopDocs);

        int totalScoreDocsCount = 0;
        if (collapseIndex >= 0) {
            final CollapseDTO collapseDTO = new CollapseDTO(
                queryTopDocs,
                querySearchResults,
                sort,
                collapseIndex,
                isFetchPhaseExecuted,
                combineScoresDTO
            );
            totalScoreDocsCount = new CollapseExecutor().executeCollapse(collapseDTO);
        } else {
            for (int shardIndex = 0; shardIndex < querySearchResults.size(); shardIndex++) {
                final QuerySearchResult querySearchResult = querySearchResults.get(shardIndex);
                final CompoundTopDocs updatedTopDocs = queryTopDocs.get(shardIndex);
                totalScoreDocsCount += updatedTopDocs.getScoreDocs().size();
                final TopDocsAndMaxScore updatedTopDocsAndMaxScore = new TopDocsAndMaxScore(
                    buildTopDocs(updatedTopDocs, sort),
                    NormalizationProcessorWorkflowUtil.maxScoreForShard(updatedTopDocs, sort != null)
                );
                if (isFetchPhaseExecuted) {
                    querySearchResult.from(combineScoresDTO.getFromValueForSingleShard());
                }
                querySearchResult.topDocs(updatedTopDocsAndMaxScore, querySearchResult.sortValueFormats());
            }
        }

        if (querySearchResults.get(0).from() > totalScoreDocsCount) {
            throw new IllegalArgumentException("Reached end of search result, increase pagination_depth value to see more results");
        }
    }

    /**
     * Returns the index of the first entry whose first {@link TopDocs} is a
     * {@link CollapseTopFieldDocs}, or {@code -1} if there is none.
     */
    private static int indexOfFirstCollapseResult(final List<CompoundTopDocs> queryTopDocs) {
        for (int index = 0; index < queryTopDocs.size(); index++) {
            final List<TopDocs> topDocsList = queryTopDocs.get(index).getTopDocs();
            if (!topDocsList.isEmpty() && topDocsList.getFirst() instanceof CollapseTopFieldDocs) {
                return index;
            }
        }
        return -1;
    }

    /**
     * Returns the DTO's compound top docs after checking they align one-to-one with the query results.
     *
     * @throws IllegalStateException if the sizes differ
     */
    private List<CompoundTopDocs> getCompoundTopDocs(final CombineScoresDto combineScoresDTO, final List<QuerySearchResult> querySearchResults) {
        final List<CompoundTopDocs> queryTopDocs = combineScoresDTO.getQueryTopDocs();
        validateTopDocsSize(querySearchResults.size(), queryTopDocs.size());
        return queryTopDocs;
    }

    /**
     * Builds Lucene {@link TopDocs} from the combined results. When a sort is present a
     * {@link TopFieldDocs} is produced; this assumes every score doc is then a {@link FieldDoc}.
     */
    private TopDocs buildTopDocs(final CompoundTopDocs updatedTopDocs, final Sort sort) {
        if (sort != null) {
            return new TopFieldDocs(updatedTopDocs.getTotalHits(), updatedTopDocs.getScoreDocs().toArray(new FieldDoc[0]), sort.getSort());
        }
        return new TopDocs(updatedTopDocs.getTotalHits(), updatedTopDocs.getScoreDocs().toArray(new ScoreDoc[0]));
    }

    /**
     * Reorders the fetched hits to match the combined order in the first query result. It skips
     * the first {@code fromValueForSingleShard} entries and assigns each hit its combined score.
     * Does nothing when no fetch result is present.
     *
     * @param querySearchResults query results, already updated with the combined top docs
     * @param fetchSearchResultOptional fetch result to rewrite, if any
     * @param docIds doc ids of the first query result in their pre-combination order; the i-th
     *        fetched hit is taken to belong to the i-th doc id
     * @param fromValueForSingleShard pagination offset applied to the combined score docs
     * @throws IllegalStateException if fetched hits cannot be reconciled with the query results
     */
    private void updateOriginalFetchResults(
        final List<QuerySearchResult> querySearchResults,
        final Optional<FetchSearchResult> fetchSearchResultOptional,
        final List<Integer> docIds,
        final int fromValueForSingleShard
    ) {
        if (fetchSearchResultOptional.isEmpty()) {
            return;
        }
        final FetchSearchResult fetchSearchResult = fetchSearchResultOptional.get();
        final boolean requestCache = isRequestCacheEnabled(querySearchResults);
        final SearchHit[] searchHitArray = getSearchHits(docIds, fetchSearchResult, requestCache);
        if (Objects.isNull(querySearchResults) || querySearchResults.isEmpty()) {
            // No query result to take the combined order from; leave the (validated) hits as they are.
            return;
        }

        final Map<Integer, SearchHit> docIdToSearchHit = new HashMap<>();
        for (int i = 0; i < searchHitArray.length; i++) {
            docIdToSearchHit.put(docIds.get(i), searchHitArray[i]);
        }

        final QuerySearchResult querySearchResult = querySearchResults.get(0);
        final ScoreDoc[] scoreDocs = querySearchResult.topDocs().topDocs.scoreDocs;
        // Entries before the pagination offset are dropped; clamp so an offset past the end yields no hits.
        final int trimmedLengthOfSearchHits = Math.max(0, scoreDocs.length - fromValueForSingleShard);
        final SearchHit[] updatedSearchHitArray = new SearchHit[trimmedLengthOfSearchHits];
        for (int i = 0; i < trimmedLengthOfSearchHits; i++) {
            final ScoreDoc scoreDoc = scoreDocs[i + fromValueForSingleShard];
            final SearchHit searchHit = docIdToSearchHit.get(scoreDoc.doc);
            if (Objects.isNull(searchHit)) {
                // With the request cache fewer hits than docs are tolerated, so a lookup can miss.
                throw new IllegalStateException(
                    String.format(
                        Locale.ROOT,
                        "score normalization processor cannot produce final query result, fetch phase returned no hit for document [%d]",
                        scoreDoc.doc
                    )
                );
            }
            searchHit.score(scoreDoc.score);
            updatedSearchHitArray[i] = searchHit;
        }

        final SearchHits updatedSearchHits = new SearchHits(
            updatedSearchHitArray,
            querySearchResult.getTotalHits(),
            querySearchResult.getMaxScore()
        );
        fetchSearchResult.hits(updatedSearchHits);
    }

    /**
     * Returns {@code true} only if the first query result's shard request exists and has its
     * {@code requestCache} flag explicitly set to {@code true}.
     */
    private static boolean isRequestCacheEnabled(final List<QuerySearchResult> querySearchResults) {
        if (Objects.isNull(querySearchResults) || querySearchResults.isEmpty()) {
            return false;
        }
        final var shardSearchRequest = querySearchResults.get(0).getShardSearchRequest();
        return Objects.nonNull(shardSearchRequest) && Boolean.TRUE.equals(shardSearchRequest.requestCache());
    }

    /**
     * Returns the fetched hits after validating their count against the query-phase doc ids.
     *
     * @throws IllegalStateException if the hits array is null; if it has a different length than
     *         {@code docIds} without the request cache; or if it has more entries than
     *         {@code docIds} with the request cache
     */
    private SearchHit[] getSearchHits(final List<Integer> docIds, final FetchSearchResult fetchSearchResult, final boolean requestCache) {
        final SearchHits searchHits = fetchSearchResult.hits();
        final SearchHit[] searchHitArray = searchHits.getHits();
        if (Objects.isNull(searchHitArray)) {
            throw new IllegalStateException(
                "score normalization processor cannot produce final query result, fetch query phase returns empty results"
            );
        }
        // Without the request cache there must be exactly one hit per doc id; with it, fewer are tolerated.
        final boolean sizeMismatch = requestCache ? searchHitArray.length > docIds.size() : searchHitArray.length != docIds.size();
        if (sizeMismatch) {
            throw new IllegalStateException(
                String.format(
                    Locale.ROOT,
                    "score normalization processor cannot produce final query result, the number of documents after fetch phase [%d] is different from number of documents from query phase [%d]",
                    searchHitArray.length,
                    docIds.size()
                )
            );
        }
        return searchHitArray;
    }

    /**
     * Returns the doc ids of the first query result's score docs in their current order, or an
     * empty list when there are no query results.
     */
    private List<Integer> unprocessedDocIds(final List<QuerySearchResult> querySearchResults) {
        if (querySearchResults.isEmpty()) {
            return List.of();
        }
        return Arrays.stream(querySearchResults.get(0).topDocs().topDocs.scoreDocs)
            .map(scoreDoc -> scoreDoc.doc)
            .collect(Collectors.toList());
    }

    @Generated
    public NormalizationProcessorWorkflow(ScoreNormalizer scoreNormalizer, ScoreCombiner scoreCombiner) {
        this.scoreNormalizer = scoreNormalizer;
        this.scoreCombiner = scoreCombiner;
    }
}

