/*
 * Decompiled with CFR 0.152.
 */
package org.elasticsearch.xpack.ml.inference.nlp;

import java.util.Comparator;
import java.util.Objects;
import java.util.PriorityQueue;
import org.elasticsearch.search.aggregations.pipeline.MovingFunctions;

public final class NlpHelpers {
    private NlpHelpers() {
    }

    static double[][] convertToProbabilitiesBySoftMax(double[][] scores) {
        int i;
        double[][] probabilities = new double[scores.length][];
        double[] sum = new double[scores.length];
        for (i = 0; i < scores.length; ++i) {
            probabilities[i] = new double[scores[i].length];
            double maxScore = MovingFunctions.max((double[])scores[i]);
            for (int j = 0; j < scores[i].length; ++j) {
                probabilities[i][j] = Math.exp(scores[i][j] - maxScore);
                int n = i;
                sum[n] = sum[n] + probabilities[i][j];
            }
        }
        for (i = 0; i < scores.length; ++i) {
            int j = 0;
            while (j < scores[i].length) {
                double[] dArray = probabilities[i];
                int n = j++;
                dArray[n] = dArray[n] / sum[i];
            }
        }
        return probabilities;
    }

    static double[] convertToProbabilitiesBySoftMax(double[] scores) {
        int i;
        double[] probabilities = new double[scores.length];
        double sum = 0.0;
        double maxScore = MovingFunctions.max((double[])scores);
        for (i = 0; i < scores.length; ++i) {
            probabilities[i] = Math.exp(scores[i] - maxScore);
            sum += probabilities[i];
        }
        i = 0;
        while (i < scores.length) {
            int n = i++;
            probabilities[n] = probabilities[n] / sum;
        }
        return probabilities;
    }

    static int argmax(double[] arr) {
        int maxIndex = 0;
        for (int i = 1; i < arr.length; ++i) {
            if (!(arr[i] > arr[maxIndex])) continue;
            maxIndex = i;
        }
        return maxIndex;
    }

    static ScoreAndIndex[] topK(int k, double[] arr) {
        if (k > arr.length) {
            k = arr.length;
        }
        PriorityQueue<ScoreAndIndex> minHeap = new PriorityQueue<ScoreAndIndex>(k, Comparator.comparingDouble(o -> o.score));
        for (int i = 0; i < k; ++i) {
            minHeap.add(new ScoreAndIndex(arr[i], i));
        }
        double minValue = minHeap.peek().score;
        for (int i = k; i < arr.length; ++i) {
            if (!(arr[i] > minValue)) continue;
            minHeap.poll();
            minHeap.add(new ScoreAndIndex(arr[i], i));
            minValue = minHeap.peek().score;
        }
        ScoreAndIndex[] result = new ScoreAndIndex[k];
        for (int i = k - 1; i >= 0; --i) {
            result[i] = minHeap.poll();
        }
        return result;
    }

    public static class ScoreAndIndex {
        final double score;
        final int index;

        ScoreAndIndex(double value, int index) {
            this.score = value;
            this.index = index;
        }

        public boolean equals(Object o) {
            if (this == o) {
                return true;
            }
            if (o == null || this.getClass() != o.getClass()) {
                return false;
            }
            ScoreAndIndex that = (ScoreAndIndex)o;
            return Double.compare(that.score, this.score) == 0 && this.index == that.index;
        }

        public int hashCode() {
            return Objects.hash(this.score, this.index);
        }
    }
}

