/*
 * Decompiled with CFR 0.152.
 */
package org.opensearch.neuralsearch.processor.normalization;

import com.google.common.primitives.Floats;
import java.util.HashMap;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Objects;
import lombok.Generated;
import org.apache.commons.math3.stat.descriptive.DescriptiveStatistics;
import org.apache.lucene.search.ScoreDoc;
import org.apache.lucene.search.TopDocs;
import org.opensearch.neuralsearch.processor.CompoundTopDocs;
import org.opensearch.neuralsearch.processor.NormalizeScoresDTO;
import org.opensearch.neuralsearch.processor.explain.DocIdAtSearchShard;
import org.opensearch.neuralsearch.processor.explain.ExplainableTechnique;
import org.opensearch.neuralsearch.processor.explain.ExplanationDetails;
import org.opensearch.neuralsearch.processor.explain.ExplanationUtils;
import org.opensearch.neuralsearch.processor.normalization.ScoreNormalizationTechnique;
import org.opensearch.neuralsearch.processor.normalization.ScoreNormalizationUtil;
import org.opensearch.neuralsearch.processor.util.ProcessorUtils;

/**
 * Score normalization technique that rescales every hit's score using per-sub-query
 * statistics (mean, standard deviation, minimum and maximum) gathered across all
 * non-null {@link CompoundTopDocs} entries passed in.
 *
 * <p>Scores are rewritten in place on the {@link ScoreDoc} instances, both by
 * {@link #normalize(NormalizeScoresDTO)} and by {@link #explain(List)}.
 */
public class ZScoreNormalizationTechnique implements ScoreNormalizationTechnique, ExplainableTechnique {
    public static final String TECHNIQUE_NAME = "z_score";
    // NOTE(review): not referenced anywhere in this class; kept so the declared constants stay unchanged.
    private static final float SINGLE_RESULT_SCORE = 1.0f;
    /** Floor applied to any computed z-score that is zero or negative. */
    private static final float MIN_SCORE = 0.001f;

    /**
     * Replaces the score of every hit in every sub-query with its normalized value.
     * Null entries of the query top docs list are skipped.
     *
     * @param normalizeScoresDTO carrier for the per-shard query results to normalize
     */
    @Override
    public void normalize(NormalizeScoresDTO normalizeScoresDTO) {
        List<CompoundTopDocs> queryTopDocs = normalizeScoresDTO.getQueryTopDocs();
        // Statistics must be captured from the raw scores before any score is overwritten.
        ZScores zScores = getZScoreResults(queryTopDocs);
        for (CompoundTopDocs compoundQueryTopDocs : queryTopDocs) {
            if (Objects.isNull(compoundQueryTopDocs)) {
                continue;
            }
            List<TopDocs> topDocsPerSubQuery = compoundQueryTopDocs.getTopDocs();
            for (int subQueryIndex = 0; subQueryIndex < topDocsPerSubQuery.size(); ++subQueryIndex) {
                for (ScoreDoc scoreDoc : topDocsPerSubQuery.get(subQueryIndex).scoreDocs) {
                    scoreDoc.score = normalizeForSubQuery(scoreDoc.score, zScores, subQueryIndex);
                }
            }
        }
    }

    /**
     * @return the identifier of this technique, {@value #TECHNIQUE_NAME}
     */
    @Override
    public String techniqueName() {
        return TECHNIQUE_NAME;
    }

    /**
     * @return a textual description of this technique; currently just the technique name
     */
    @Override
    public String describe() {
        return String.format(Locale.ROOT, "%s", TECHNIQUE_NAME);
    }

    /**
     * Normalizes all scores exactly as {@link #normalize(NormalizeScoresDTO)} does (the
     * {@link ScoreDoc} scores are overwritten) and additionally records each normalized score
     * keyed by document and shard, at the position of the sub-query that produced it.
     *
     * @param queryTopDocs per-shard query results; null entries are skipped
     * @return explanation details built from the recorded normalized scores
     */
    @Override
    public Map<DocIdAtSearchShard, ExplanationDetails> explain(List<CompoundTopDocs> queryTopDocs) {
        ZScores zScores = getZScoreResults(queryTopDocs);
        var normalizedScores = new HashMap<DocIdAtSearchShard, List<Float>>();
        for (CompoundTopDocs compoundQueryTopDocs : queryTopDocs) {
            if (Objects.isNull(compoundQueryTopDocs)) {
                continue;
            }
            List<TopDocs> topDocsPerSubQuery = compoundQueryTopDocs.getTopDocs();
            int numberOfSubQueries = topDocsPerSubQuery.size();
            for (int subQueryIndex = 0; subQueryIndex < numberOfSubQueries; ++subQueryIndex) {
                for (ScoreDoc scoreDoc : topDocsPerSubQuery.get(subQueryIndex).scoreDocs) {
                    DocIdAtSearchShard docIdAtSearchShard =
                        new DocIdAtSearchShard(scoreDoc.doc, compoundQueryTopDocs.getSearchShard());
                    float normalizedScore = normalizeForSubQuery(scoreDoc.score, zScores, subQueryIndex);
                    ScoreNormalizationUtil.setNormalizedScore(
                        normalizedScores, docIdAtSearchShard, subQueryIndex, numberOfSubQueries, normalizedScore);
                    scoreDoc.score = normalizedScore;
                }
            }
        }
        return ExplanationUtils.getDocIdAtQueryForNormalization(normalizedScores, this);
    }

    /**
     * Feeds every raw score into one {@link DescriptiveStatistics} accumulator per sub-query,
     * pooling the hits of all non-null entries in {@code queryTopDocs}.
     *
     * @param queryTopDocs    per-shard query results; null entries are skipped
     * @param numOfSubqueries number of accumulators to create
     * @return one accumulator per sub-query index
     */
    private static DescriptiveStatistics[] calculateStatsPerSubquery(List<CompoundTopDocs> queryTopDocs, int numOfSubqueries) {
        DescriptiveStatistics[] statsPerSubquery = new DescriptiveStatistics[numOfSubqueries];
        for (int i = 0; i < numOfSubqueries; ++i) {
            statsPerSubquery[i] = new DescriptiveStatistics();
        }
        for (CompoundTopDocs compoundQueryTopDocs : queryTopDocs) {
            if (Objects.isNull(compoundQueryTopDocs)) {
                continue;
            }
            List<TopDocs> topDocsPerSubQuery = compoundQueryTopDocs.getTopDocs();
            for (int subQueryIndex = 0; subQueryIndex < topDocsPerSubQuery.size(); ++subQueryIndex) {
                for (ScoreDoc scoreDoc : topDocsPerSubQuery.get(subQueryIndex).scoreDocs) {
                    statsPerSubquery[subQueryIndex].addValue(scoreDoc.score);
                }
            }
        }
        return statsPerSubquery;
    }

    /**
     * Computes mean, standard deviation, maximum and minimum of the raw scores for each sub-query.
     *
     * <p>The statistics are collected in a single pass over the results and all four values are
     * read from the same accumulators, rather than re-scanning every score once per statistic.
     *
     * @param queryTopDocs per-shard query results
     * @return the per-sub-query statistics, each array indexed by sub-query position
     */
    private ZScores getZScoreResults(List<CompoundTopDocs> queryTopDocs) {
        int numOfSubqueries = ProcessorUtils.getNumOfSubqueries(queryTopDocs);
        DescriptiveStatistics[] statsPerSubquery = calculateStatsPerSubquery(queryTopDocs, numOfSubqueries);

        float[] meanPerSubQuery = new float[numOfSubqueries];
        float[] stdPerSubquery = new float[numOfSubqueries];
        float[] maxPerSubQuery = new float[numOfSubqueries];
        float[] minPerSubQuery = new float[numOfSubqueries];
        for (int i = 0; i < numOfSubqueries; ++i) {
            DescriptiveStatistics stats = statsPerSubquery[i];
            // Statistics are computed in double precision and narrowed to float to match score precision.
            meanPerSubQuery[i] = (float) stats.getMean();
            stdPerSubquery[i] = (float) stats.getStandardDeviation();
            maxPerSubQuery[i] = (float) stats.getMax();
            minPerSubQuery[i] = (float) stats.getMin();
        }
        return new ZScores(meanPerSubQuery, stdPerSubquery, maxPerSubQuery, minPerSubQuery);
    }

    /**
     * Normalizes one score using the statistics stored for the given sub-query.
     *
     * @param score         raw score of the hit
     * @param zScores       statistics for all sub-queries
     * @param subQueryIndex position of the sub-query that produced the hit
     * @return the normalized score
     */
    private static float normalizeForSubQuery(float score, ZScores zScores, int subQueryIndex) {
        return normalizeSingleScore(
            score,
            zScores.stdPerSubquery()[subQueryIndex],
            zScores.meanPerSubQuery()[subQueryIndex],
            zScores.maxPerSubQuery()[subQueryIndex],
            zScores.minPerSubQuery()[subQueryIndex]);
    }

    /**
     * Normalizes a single score. The rules are applied in this order:
     * <ol>
     *   <li>a score equal to the mean yields {@code maxScore};</li>
     *   <li>otherwise, a standard deviation of zero yields {@code minScore};</li>
     *   <li>otherwise the result is {@code (score - mean) / standardDeviation}, with any value
     *       that is zero or negative replaced by {@link #MIN_SCORE}.</li>
     * </ol>
     *
     * @param score             raw score to normalize
     * @param standardDeviation standard deviation of the sub-query's scores
     * @param mean              mean of the sub-query's scores
     * @param maxScore          largest raw score of the sub-query
     * @param minScore          smallest raw score of the sub-query
     * @return the normalized score
     */
    private static float normalizeSingleScore(float score, float standardDeviation, float mean, float maxScore, float minScore) {
        if (Floats.compare(mean, score) == 0) {
            return maxScore;
        }
        // Guard the division below against a zero denominator.
        if (Floats.compare(standardDeviation, 0.0f) == 0) {
            return minScore;
        }
        float normalizedScore = (score - mean) / standardDeviation;
        return normalizedScore <= 0.0f ? MIN_SCORE : normalizedScore;
    }

    @Generated
    public String toString() {
        return "ZScoreNormalizationTechnique(TECHNIQUE_NAME=z_score)";
    }

    /**
     * Per-sub-query score statistics; every array is indexed by sub-query position.
     */
    private record ZScores(float[] meanPerSubQuery, float[] stdPerSubquery, float[] maxPerSubQuery, float[] minPerSubQuery) {
    }
}

