package org.apache.lucene.util.quantization;

import java.lang.invoke.MethodHandles;
import java.lang.invoke.MethodType;
import java.lang.runtime.ObjectMethods;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.util.VectorUtil;

/* loaded from: input_file:org/apache/lucene/util/quantization/OptimizedScalarQuantizer.class */
/**
 * Scalar quantizer that centers a vector on a centroid and then, for a requested bit width, searches
 * for the interval {@code [lower, upper]} that minimizes a weighted reconstruction loss before
 * rounding every component onto {@code 2^bits} evenly spaced levels inside that interval.
 *
 * <p>The loss (see {@code loss}) blends two terms weighted by {@code lambda}: the squared projection
 * of the quantization error onto the centered vector (normalized by the vector's squared norm) and
 * the plain squared quantization error.
 */
public class OptimizedScalarQuantizer {

    /**
     * Starting offsets, in standard deviations around the centered vector's mean, for the lower and
     * upper interval bound. Row {@code bits - 1} is used for a {@code bits}-bit quantization (1..8).
     * The seeded interval is clamped to the vector's observed range and then refined per vector.
     */
    static final float[][] MINIMUM_MSE_GRID = {
        {-0.798f, 0.798f},
        {-1.493f, 1.493f},
        {-2.051f, 2.051f},
        {-2.514f, 2.514f},
        {-2.916f, 2.916f},
        {-3.278f, 3.278f},
        {-3.611f, 3.611f},
        {-3.922f, 3.922f}
    };

    private static final float DEFAULT_LAMBDA = 0.1f;
    private static final int DEFAULT_ITERS = 5;

    private final VectorSimilarityFunction similarityFunction;
    /** Weight of the squared-error term in the loss; {@code 1 - lambda} weights the projected term. */
    private final float lambda;
    /** Upper bound on the number of interval-refinement rounds per quantization. */
    private final int iters;

    /**
     * Outcome of quantizing one vector at one bit width.
     *
     * @param lowerInterval lower bound of the optimized interval, in centered coordinates
     * @param upperInterval upper bound of the optimized interval, in centered coordinates
     * @param additionalCorrection squared L2 norm of the centered vector for {@code EUCLIDEAN};
     *     otherwise the dot product of the original (uncentered) vector with the centroid
     * @param quantizedComponentSum sum of all codes written to the destination array
     */
    public record QuantizationResult(
            float lowerInterval, float upperInterval, float additionalCorrection, int quantizedComponentSum) {}

    /**
     * Creates a quantizer.
     *
     * @param similarityFunction similarity the quantized vectors will be compared with; selects the
     *     reported correction term and, for {@code COSINE}, enables unit-length assertions
     * @param lambda weight of the squared-error term in the interval loss
     * @param iters maximum number of interval-refinement rounds
     */
    public OptimizedScalarQuantizer(VectorSimilarityFunction similarityFunction, float lambda, int iters) {
        this.similarityFunction = similarityFunction;
        this.lambda = lambda;
        this.iters = iters;
    }

    /** Creates a quantizer with the default {@code lambda} (0.1) and iteration count (5). */
    public OptimizedScalarQuantizer(VectorSimilarityFunction similarityFunction) {
        this(similarityFunction, DEFAULT_LAMBDA, DEFAULT_ITERS);
    }

    /**
     * Quantizes {@code vector} once per entry of {@code bits}, each time with its own optimized
     * interval.
     *
     * <p><b>Side effect:</b> {@code vector} is centered in place ({@code vector[i] -= centroid[i]}).
     *
     * @param vector vector to quantize; overwritten with {@code vector - centroid}
     * @param destinations one output array per entry of {@code bits}; {@code destinations[k][i]}
     *     receives the code of component {@code i} at {@code bits[k]} bits
     * @param bits bit widths to quantize at, each in {@code [1, 8]}
     * @param centroid centroid subtracted from {@code vector}; must be at least as long as it
     * @return one result per entry of {@code bits}, in the same order
     */
    public QuantizationResult[] multiScalarQuantize(
            float[] vector, byte[][] destinations, byte[] bits, float[] centroid) {
        assert similarityFunction != VectorSimilarityFunction.COSINE || VectorUtil.isUnitVector(vector);
        assert similarityFunction != VectorSimilarityFunction.COSINE || VectorUtil.isUnitVector(centroid);
        assert bits.length == destinations.length;

        // One pass: dot product with the centroid (non-EUCLIDEAN only, on the original values), then
        // center in place and track min/max, squared norm, and Welford's running mean/variance.
        float centroidDot = 0f;
        float norm2 = 0f;
        float min = Float.MAX_VALUE;
        float max = -Float.MAX_VALUE;
        double mean = 0d;
        double sumSquaredDiffs = 0d;
        for (int i = 0; i < vector.length; i++) {
            if (similarityFunction != VectorSimilarityFunction.EUCLIDEAN) {
                centroidDot += vector[i] * centroid[i];
            }
            vector[i] = vector[i] - centroid[i];
            float x = vector[i];
            min = Math.min(min, x);
            max = Math.max(max, x);
            norm2 += x * x;
            double delta = x - mean;
            mean += delta / (i + 1);
            sumSquaredDiffs += delta * (x - mean);
        }
        // Population standard deviation of the centered components.
        double std = Math.sqrt(sumSquaredDiffs / vector.length);
        float correction = similarityFunction == VectorSimilarityFunction.EUCLIDEAN ? norm2 : centroidDot;

        QuantizationResult[] results = new QuantizationResult[bits.length];
        float[] interval = new float[2];
        for (int k = 0; k < bits.length; k++) {
            int bitCount = bits[k];
            assert bitCount > 0 && bitCount <= 8;
            assert destinations[k].length >= vector.length;
            int points = 1 << bitCount;
            float[] grid = MINIMUM_MSE_GRID[bitCount - 1];
            // Seed at mean + grid * std, kept inside the observed [min, max] range.
            interval[0] = (float) clamp(grid[0] * std + mean, min, max);
            interval[1] = (float) clamp(grid[1] * std + mean, min, max);
            optimizeIntervals(interval, vector, norm2, points);

            float lower = interval[0];
            float upper = interval[1];
            float step = (upper - lower) / (points - 1);
            byte[] destination = destinations[k];
            int componentSum = 0;
            for (int i = 0; i < vector.length; i++) {
                int code = Math.round(((float) clamp(vector[i], lower, upper) - lower) / step);
                componentSum += code;
                destination[i] = (byte) code;
            }
            results[k] = new QuantizationResult(lower, upper, correction, componentSum);
        }
        return results;
    }

    /**
     * Quantizes {@code vector} at a single bit width. Equivalent to {@link #multiScalarQuantize} with
     * one destination and one bit width.
     *
     * <p><b>Side effect:</b> {@code vector} is centered in place ({@code vector[i] -= centroid[i]}).
     *
     * @param vector vector to quantize; overwritten with {@code vector - centroid}
     * @param destination receives one code per component; at least as long as {@code vector}
     * @param bits bit width in {@code [1, 8]}
     * @param centroid centroid subtracted from {@code vector}
     * @return the optimized interval, correction term and code sum
     */
    public QuantizationResult scalarQuantize(float[] vector, byte[] destination, byte bits, float[] centroid) {
        assert similarityFunction != VectorSimilarityFunction.COSINE || VectorUtil.isUnitVector(vector);
        assert similarityFunction != VectorSimilarityFunction.COSINE || VectorUtil.isUnitVector(centroid);
        assert vector.length <= destination.length;
        assert bits > 0 && bits <= 8;
        // Single-width quantization is the one-entry case; delegating keeps the centering, seeding
        // and rounding arithmetic in exactly one place.
        return multiScalarQuantize(vector, new byte[][] {destination}, new byte[] {bits}, centroid)[0];
    }

    /**
     * Loss of quantizing {@code vector} onto {@code points} levels spanning {@code interval}:
     * {@code (1 - lambda) * (sum x*(x - xq))^2 / norm2 + lambda * sum (x - xq)^2}, where {@code xq}
     * is the quantized-then-dequantized value of {@code x}.
     */
    private double loss(float[] vector, float[] interval, int points, float norm2) {
        double lower = interval[0];
        double upper = interval[1];
        double step = (upper - lower) / (points - 1.0f);
        double stepInv = 1.0 / step;
        double projectedError = 0.0;
        double squaredError = 0.0;
        for (double xi : vector) {
            double dequantized = lower + step * Math.round((clamp(xi, lower, upper) - lower) * stepInv);
            double err = xi - dequantized;
            projectedError += xi * err;
            squaredError += err * err;
        }
        return (1.0 - lambda) * projectedError * projectedError / norm2 + lambda * squaredError;
    }

    /**
     * Refines {@code interval} in place. Each round snaps every component to its nearest level for
     * the current interval, then solves the 2x2 normal equations of the loss for the endpoints with
     * those levels held fixed. Stops when the system is singular, the endpoints move less than 1e-8,
     * the new loss is worse, or {@code iters} rounds have run. An update is committed only if it does
     * not increase the loss. Does nothing when {@code (1 - lambda) / norm2} is not finite.
     */
    private void optimizeIntervals(float[] interval, float[] vector, float norm2, int points) {
        final float scale = (1.0f - lambda) / norm2;
        if (!Float.isFinite(scale)) {
            return;
        }
        double bestLoss = loss(vector, interval, points, norm2);
        for (int round = 0; round < iters; round++) {
            float a = interval[0];
            float b = interval[1];
            float stepInv = (points - 1.0f) / (b - a);
            double daa = 0;
            double dab = 0;
            double dbb = 0;
            double dax = 0;
            double dbx = 0;
            for (float xi : vector) {
                // s in [0, 1]: position of xi's level within the interval. The dequantized value is
                // a * (1 - s) + b * s.
                float s = (float) Math.round((clamp(xi, a, b) - a) * stepInv) / (points - 1);
                daa += (1.0 - s) * (1.0 - s);
                dab += (1.0 - s) * s;
                dbb += s * s;
                dax += xi * (1.0 - s);
                dbx += xi * s;
            }
            double m0 = scale * dax * dax + lambda * daa;
            double m1 = scale * dax * dbx + lambda * dab;
            double m2 = scale * dbx * dbx + lambda * dbb;
            // Solve [m0 m1; m1 m2] * [a; b] = [dax; dbx] by Cramer's rule.
            double det = m0 * m2 - m1 * m1;
            if (det == 0.0d) {
                return;
            }
            float aOpt = (float) ((m2 * dax - m1 * dbx) / det);
            float bOpt = (float) ((m0 * dbx - m1 * dax) / det);
            if (Math.abs(interval[0] - aOpt) < 1e-8 && Math.abs(interval[1] - bOpt) < 1e-8) {
                return;
            }
            double newLoss = loss(vector, new float[] {aOpt, bOpt}, points, norm2);
            // This alternating scheme does not always improve the loss; bail out instead of regressing.
            if (newLoss > bestLoss) {
                return;
            }
            interval[0] = aOpt;
            interval[1] = bOpt;
            bestLoss = newLoss;
        }
    }

    /**
     * Rounds {@code value} up to the next multiple of {@code bucket} (for non-negative {@code value}
     * and positive {@code bucket}).
     */
    public static int discretize(int value, int bucket) {
        return ((value + (bucket - 1)) / bucket) * bucket;
    }

    /**
     * Splits 4-bit codes into four bit planes. Bit {@code k} of {@code q[i]} is written to quarter
     * {@code k} of {@code quantQueryByte}, at byte {@code i / 8} of that quarter and bit position
     * {@code 7 - i % 8}. Quarters start at offsets {@code 0}, {@code length/4}, {@code length/2} and
     * {@code 3*length/4}. NOTE(review): assumes {@code quantQueryByte.length} is four times
     * {@code ceil(q.length / 8)}; this is not checked here.
     *
     * @param q codes, each in {@code [0, 15]}
     * @param quantQueryByte destination for the four bit planes
     */
    public static void transposeHalfByte(byte[] q, byte[] quantQueryByte) {
        for (int i = 0; i < q.length; ) {
            // Every group starts at a multiple of 8, so this is the group's byte index.
            int index = i / 8;
            int bit0 = 0;
            int bit1 = 0;
            int bit2 = 0;
            int bit3 = 0;
            for (int shift = 7; shift >= 0 && i < q.length; shift--) {
                assert q[i] >= 0 && q[i] <= 15;
                bit0 |= (q[i] & 1) << shift;
                bit1 |= ((q[i] >> 1) & 1) << shift;
                bit2 |= ((q[i] >> 2) & 1) << shift;
                bit3 |= ((q[i] >> 3) & 1) << shift;
                i++;
            }
            quantQueryByte[index] = (byte) bit0;
            quantQueryByte[index + (quantQueryByte.length / 4)] = (byte) bit1;
            quantQueryByte[index + (quantQueryByte.length / 2)] = (byte) bit2;
            quantQueryByte[index + ((3 * quantQueryByte.length) / 4)] = (byte) bit3;
        }
    }

    /**
     * Packs 0/1 values eight per byte, most significant bit first: {@code vector[i]} becomes bit
     * {@code 7 - i % 8} of {@code packed[i / 8]}.
     *
     * @param vector values, each 0 or 1
     * @param packed destination; needs at least {@code ceil(vector.length / 8)} bytes
     */
    public static void packAsBinary(byte[] vector, byte[] packed) {
        for (int i = 0; i < vector.length; ) {
            // Every group starts at a multiple of 8, so this is the group's byte index.
            int index = i / 8;
            int bitsInGroup = 0;
            for (int shift = 7; shift >= 0 && i < vector.length; shift--) {
                assert vector[i] == 0 || vector[i] == 1;
                bitsInGroup |= (vector[i] & 1) << shift;
                i++;
            }
            assert index < packed.length;
            packed[index] = (byte) bitsInGroup;
        }
    }

    private static double clamp(double x, double lo, double hi) {
        return Math.min(Math.max(x, lo), hi);
    }
}
