/*
 * Decompiled with CFR 0.152.
 */
package org.opensearch.neuralsearch.processor;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Objects;
import lombok.Generated;
import org.apache.commons.lang3.tuple.Pair;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.apache.lucene.search.Explanation;
import org.opensearch.action.search.SearchRequest;
import org.opensearch.action.search.SearchResponse;
import org.opensearch.neuralsearch.processor.SearchShard;
import org.opensearch.neuralsearch.processor.explain.CombinedExplanationDetails;
import org.opensearch.neuralsearch.processor.explain.ExplanationDetails;
import org.opensearch.neuralsearch.processor.explain.ExplanationPayload;
import org.opensearch.search.SearchHit;
import org.opensearch.search.SearchHits;
import org.opensearch.search.SearchShardTarget;
import org.opensearch.search.pipeline.PipelineProcessingContext;
import org.opensearch.search.pipeline.SearchResponseProcessor;

public class ExplanationResponseProcessor
implements SearchResponseProcessor {
    @Generated
    private static final Logger log = LogManager.getLogger(ExplanationResponseProcessor.class);
    public static final String TYPE = "hybrid_score_explanation";
    private final String description;
    private final String tag;
    private final boolean ignoreFailure;

    public SearchResponse processResponse(SearchRequest request, SearchResponse response) {
        return this.processResponse(request, response, null);
    }

    public SearchResponse processResponse(SearchRequest request, SearchResponse response, PipelineProcessingContext requestContext) {
        if (Objects.isNull(requestContext) || Objects.isNull(requestContext.getAttribute("explanation_response")) || !(requestContext.getAttribute("explanation_response") instanceof ExplanationPayload)) {
            return response;
        }
        ExplanationPayload explanationPayload = (ExplanationPayload)requestContext.getAttribute("explanation_response");
        Map<ExplanationPayload.PayloadType, Object> explainPayload = explanationPayload.getExplainPayload();
        if (explainPayload.containsKey((Object)ExplanationPayload.PayloadType.NORMALIZATION_PROCESSOR)) {
            SearchHits searchHits = response.getHits();
            SearchHit[] searchHitsArray = searchHits.getHits();
            HashMap<SearchShard, List> searchHitsByShard = new HashMap<SearchShard, List>();
            HashMap<SearchShard, Integer> explainsByShardCount = new HashMap<SearchShard, Integer>();
            for (int i = 0; i < searchHitsArray.length; ++i) {
                SearchHit searchHit = searchHitsArray[i];
                SearchShardTarget searchShardTarget = searchHit.getShard();
                SearchShard searchShard = SearchShard.createSearchShard(searchShardTarget);
                searchHitsByShard.computeIfAbsent(searchShard, k -> new ArrayList()).add(i);
                explainsByShardCount.putIfAbsent(searchShard, -1);
            }
            if (explainPayload.get((Object)ExplanationPayload.PayloadType.NORMALIZATION_PROCESSOR) instanceof Map) {
                Map combinedExplainDetails = (Map)explainPayload.get((Object)ExplanationPayload.PayloadType.NORMALIZATION_PROCESSOR);
                for (SearchHit searchHit : searchHitsArray) {
                    SearchShard searchShard = SearchShard.createSearchShard(searchHit.getShard());
                    int explanationIndexByShard = (Integer)explainsByShardCount.get(searchShard) + 1;
                    CombinedExplanationDetails combinedExplainDetail = (CombinedExplanationDetails)((List)combinedExplainDetails.get(searchShard)).get(explanationIndexByShard);
                    Explanation queryLevelExplanation = searchHit.getExplanation();
                    ExplanationDetails normalizationExplanation = combinedExplainDetail.getNormalizationExplanations();
                    ExplanationDetails combinationExplanation = combinedExplainDetail.getCombinationExplanations();
                    if (normalizationExplanation.getScoreDetails().size() != queryLevelExplanation.getDetails().length) {
                        log.error(String.format(Locale.ROOT, "length of query level explanations %d must match length of explanations after normalization %d", queryLevelExplanation.getDetails().length, normalizationExplanation.getScoreDetails().size()));
                        throw new IllegalStateException("mismatch in number of query level explanations and normalization explanations");
                    }
                    ArrayList<Explanation> normalizedExplanation = new ArrayList<Explanation>(queryLevelExplanation.getDetails().length);
                    int normalizationExplanationIndex = 0;
                    for (Explanation queryExplanation : queryLevelExplanation.getDetails()) {
                        if (Float.compare(queryExplanation.getValue().floatValue(), 0.0f) > 0) {
                            Pair<Float, String> normalizedScoreDetails = normalizationExplanation.getScoreDetails().get(normalizationExplanationIndex);
                            if (Objects.isNull(normalizedScoreDetails)) {
                                throw new IllegalStateException("normalized score details must not be null");
                            }
                            normalizedExplanation.add(Explanation.match((Number)((Number)normalizedScoreDetails.getKey()), (String)((String)normalizedScoreDetails.getValue()), (Explanation[])new Explanation[]{queryExplanation}));
                        }
                        ++normalizationExplanationIndex;
                    }
                    Float finalScore = Float.valueOf(Float.isNaN(searchHit.getScore()) ? 0.0f : searchHit.getScore());
                    Explanation finalExplanation = Explanation.match((Number)finalScore, (String)((String)combinationExplanation.getScoreDetails().get(0).getValue()), normalizedExplanation);
                    searchHit.explanation(finalExplanation);
                    explainsByShardCount.put(searchShard, explanationIndexByShard);
                }
            }
        }
        return response;
    }

    public String getType() {
        return TYPE;
    }

    @Generated
    public String getDescription() {
        return this.description;
    }

    @Generated
    public String getTag() {
        return this.tag;
    }

    @Generated
    public boolean isIgnoreFailure() {
        return this.ignoreFailure;
    }

    @Generated
    public ExplanationResponseProcessor(String description, String tag, boolean ignoreFailure) {
        this.description = description;
        this.tag = tag;
        this.ignoreFailure = ignoreFailure;
    }
}

