/*
 * Decompiled with CFR 0.152.
 */
package dr.evomodel.treedatalikelihood.continuous.cdi;

import dr.evomodel.treedatalikelihood.continuous.cdi.MultivariateIntegrator;
import dr.evomodel.treedatalikelihood.continuous.cdi.PrecisionType;
import dr.math.matrixAlgebra.WrappedVector;
import dr.math.matrixAlgebra.missingData.InversionResult;
import dr.math.matrixAlgebra.missingData.MissingOps;
import org.ejml.data.D1Matrix64F;
import org.ejml.data.DenseMatrix64F;
import org.ejml.ops.CommonOps;

/**
 * Multivariate-normal integrator that uses the "safe" inversion helpers from {@link MissingOps}
 * and special-cases partials whose precision has infinite diagonal entries.
 *
 * <p>Within each trait slot of a partial buffer, this class reads a mean vector of length
 * {@code dimTrait} at the slot offset. A {@code dimTrait x dimTrait} precision block follows it,
 * then a {@code dimTrait x dimTrait} variance block. Branch precisions and variances are kept in
 * {@code precisions} / {@code variances}, one {@code dimTrait x dimTrait} block per matrix index.
 */
public class SafeMultivariateIntegrator
extends MultivariateIntegrator {
    private static final boolean DEBUG = false;
    private static final boolean TIMING = false;
    /** Index, within a buffer's partial, of the stored effective dimension (FULL layout). */
    private final int effectiveDimensionOffset;
    /** Index, within a buffer's partial, of the stored determinant value (FULL layout). */
    private final int determinantOffset;
    /** Scratch copy of the sibling's (actualized) precision used during pre-order updates. */
    private DenseMatrix64F matrixQjPjp;
    /** Scratch vector receiving the sibling "delta" (a copy of its mean in this class). */
    private double[] vectorDelta;
    /** Scratch vector for the precision-weighted sum of child means (see partialMean). */
    double[] vectorPMk;

    /**
     * Creates the integrator.
     *
     * <p>All arguments are forwarded unchanged to {@link MultivariateIntegrator}. The second one
     * is also the dimension used to locate the effective-dimension and determinant slots of the
     * FULL precision layout. NOTE(review): presumably (numTraits, dimTrait, dimProcess,
     * bufferCount, diffusionCount); confirm against the superclass.
     */
    public SafeMultivariateIntegrator(PrecisionType precisionType, int n, int n2, int n3, int n4, int n5) {
        super(precisionType, n, n2, n3, n4, n5);
        this.allocateStorage();
        this.effectiveDimensionOffset = PrecisionType.FULL.getEffectiveDimensionOffset(n2);
        this.determinantOffset = PrecisionType.FULL.getDeterminantOffset(n2);
        System.err.println("Trying SafeMultivariateIntegrator");
    }

    /** (Re)allocates per-branch matrix storage and the scratch buffers used by this class. */
    private void allocateStorage() {
        final int matrixSize = this.dimTrait * this.dimTrait;
        this.precisions = new double[matrixSize * this.bufferCount];
        this.variances = new double[matrixSize * this.bufferCount];
        this.vectorDelta = new double[this.dimTrait];
        this.vectorPMk = new double[this.dimTrait];
        this.matrixQjPjp = new DenseMatrix64F(this.dimTrait, this.dimTrait);
    }

    /** Copies the branch precision for {@code matrixIndex}; the second index is ignored. */
    @Override
    public void getBranchPrecision(int matrixIndex, int ignoredIndex, double[] precision) {
        this.getBranchPrecision(matrixIndex, precision);
    }

    /**
     * Copies the {@code dimTrait x dimTrait} branch precision block for {@code matrixIndex}
     * into {@code precision}.
     *
     * @throws RuntimeException if {@code matrixIndex == -1} (not implemented)
     */
    public void getBranchPrecision(int matrixIndex, double[] precision) {
        if (matrixIndex == -1) {
            throw new RuntimeException("Not yet implemented");
        }
        final int matrixSize = this.dimTrait * this.dimTrait;
        assert (precision != null);
        assert (precision.length >= matrixSize);
        System.arraycopy(this.precisions, matrixIndex * matrixSize, precision, 0, matrixSize);
    }

    /** Copies the branch variance for {@code matrixIndex}; the second index is ignored. */
    @Override
    public void getBranchVariance(int matrixIndex, int ignoredIndex, double[] variance) {
        this.getBranchVariance(matrixIndex, variance);
    }

    /**
     * Copies the {@code dimTrait x dimTrait} branch variance block for {@code matrixIndex}
     * into {@code variance}.
     *
     * @throws RuntimeException if {@code matrixIndex == -1} (not implemented)
     */
    public void getBranchVariance(int matrixIndex, double[] variance) {
        if (matrixIndex == -1) {
            throw new RuntimeException("Not yet implemented");
        }
        final int matrixSize = this.dimTrait * this.dimTrait;
        assert (variance != null);
        assert (variance.length >= matrixSize);
        System.arraycopy(this.variances, matrixIndex * matrixSize, variance, 0, matrixSize);
    }

    /** Copies the precision block of partial buffer {@code bufferIndex}; second index ignored. */
    @Override
    public void getRootPrecision(int bufferIndex, int ignoredIndex, double[] precision) {
        this.getRootPrecision(bufferIndex, precision);
    }

    /** Copies the first trait's precision block of partial buffer {@code bufferIndex}. */
    private void getRootPrecision(int bufferIndex, double[] precision) {
        assert (precision != null);
        assert (precision.length >= this.dimTrait * this.dimTrait);
        final int bufferOffset = this.dimPartial * bufferIndex;
        System.arraycopy(this.partials, bufferOffset + this.dimTrait, precision, 0, this.dimTrait * this.dimTrait);
    }

    /** Reads the effective dimension stored in partial buffer {@code bufferIndex}. */
    private double getEffectiveDimension(int bufferIndex) {
        return this.partials[bufferIndex * this.dimPartial + this.effectiveDimensionOffset];
    }

    /** Stores {@code value} as the effective dimension of partial buffer {@code bufferIndex}. */
    private void setEffectiveDimension(int bufferIndex, double value) {
        this.partials[bufferIndex * this.dimPartial + this.effectiveDimensionOffset] = value;
    }

    /** Reads the determinant value stored in partial buffer {@code bufferIndex}. */
    private double getPartialDeterminant(int bufferIndex) {
        return this.partials[bufferIndex * this.dimPartial + this.determinantOffset];
    }

    /**
     * Updates the per-branch precision and variance blocks after the superclass update.
     *
     * <p>For each of the first {@code updateCount} entries, the branch precision is the
     * diffusion matrix scaled by {@code 1 / branchScalars[i]}. The branch variance is the
     * inverse-diffusion matrix scaled by {@code branchScalars[i]}. NOTE(review): the scalars
     * are presumably branch lengths; confirm with callers.
     *
     * @param diffusionIndex index of the diffusion matrix to use
     * @param matrixIndices  destination matrix index for each update
     * @param branchScalars  scaling factor for each update
     * @param unused         forwarded to the superclass only
     * @param updateCount    number of entries to process
     */
    @Override
    public void updateBrownianDiffusionMatrices(int diffusionIndex, int[] matrixIndices, double[] branchScalars, double[] unused, int updateCount) {
        super.updateBrownianDiffusionMatrices(diffusionIndex, matrixIndices, branchScalars, unused, updateCount);
        assert (this.diffusions != null);
        assert (matrixIndices.length >= updateCount);
        assert (branchScalars.length >= updateCount);
        final int matrixSize = this.dimProcess * this.dimProcess;
        final int diffusionOffset = matrixSize * diffusionIndex;
        for (int update = 0; update < updateCount; ++update) {
            final double scalar = branchScalars[update];
            final int matrixOffset = matrixSize * matrixIndices[update];
            SafeMultivariateIntegrator.scale(this.diffusions, diffusionOffset, 1.0 / scalar, this.precisions, matrixOffset, matrixSize);
            SafeMultivariateIntegrator.scale(this.inverseDiffusions, diffusionOffset, scalar, this.variances, matrixOffset, matrixSize);
        }
    }

    /**
     * Writes {@code factor * source[sourceOffset + k]} to {@code destination[destinationOffset + k]}
     * for {@code k} in {@code [0, length)}.
     */
    static void scale(double[] source, int sourceOffset, double factor, double[] destination, int destinationOffset, int length) {
        for (int k = 0; k < length; ++k) {
            destination[destinationOffset + k] = factor * source[sourceOffset + k];
        }
    }

    /**
     * Computes the pre-order partial of a node from its parent's pre-order partial and its
     * sibling's post-order partial, for every trait.
     *
     * <p>Per trait, the steps are:
     * <ol>
     *   <li>Move the sibling's precision up its branch ({@link #increaseVariances}).</li>
     *   <li>Add the result to the parent's pre-order precision and invert the sum.</li>
     *   <li>Take the precision-weighted average of the parent mean and the sibling delta.</li>
     *   <li>Inflate the resulting variance by the node's own branch variance.</li>
     *   <li>Store the precision and variance in the node's pre-order buffer.</li>
     * </ol>
     */
    @Override
    public void updatePreOrderPartial(int parentBufferIndex, int nodeBufferIndex, int nodeMatrixIndex, int siblingBufferIndex, int siblingMatrixIndex) {
        int parentOffset = this.dimPartial * parentBufferIndex;
        int nodeOffset = this.dimPartial * nodeBufferIndex;
        int siblingOffset = this.dimPartial * siblingBufferIndex;
        final int nodeMatrixOffset = this.dimTrait * this.dimTrait * nodeMatrixIndex;
        final int siblingMatrixOffset = this.dimTrait * this.dimTrait * siblingMatrixIndex;
        final int nodeVectorOffset = this.dimTrait * nodeMatrixIndex;
        final int siblingVectorOffset = this.dimTrait * siblingMatrixIndex;

        final DenseMatrix64F nodeBranchVariance = MissingOps.wrap(this.variances, nodeMatrixOffset, this.dimTrait, this.dimTrait);
        final DenseMatrix64F siblingBranchVariance = MissingOps.wrap(this.variances, siblingMatrixOffset, this.dimTrait, this.dimTrait);
        final DenseMatrix64F siblingBranchPrecision = MissingOps.wrap(this.precisions, siblingMatrixOffset, this.dimTrait, this.dimTrait);

        for (int trait = 0; trait < this.numTraits; ++trait) {
            final DenseMatrix64F parentPrecision = MissingOps.wrap(this.preOrderPartials, parentOffset + this.dimTrait, this.dimTrait, this.dimTrait);

            // Sibling precision after accounting for its branch (written into matrixPjp).
            final DenseMatrix64F siblingPrecision = this.matrixPjp;
            this.increaseVariances(siblingOffset, siblingBufferIndex, siblingBranchVariance, siblingBranchPrecision, siblingPrecision, false);

            final DenseMatrix64F actualizedSiblingPrecision = this.matrixQjPjp;
            this.actualizePrecision(siblingPrecision, actualizedSiblingPrecision, siblingOffset, siblingMatrixOffset, siblingVectorOffset);

            final DenseMatrix64F combinedPrecision = this.matrixPip;
            CommonOps.add((D1Matrix64F) parentPrecision, siblingPrecision, (D1Matrix64F) combinedPrecision);

            final DenseMatrix64F variance = this.matrix1;
            MissingOps.safeInvertPrecision(combinedPrecision, variance, false);

            final double[] delta = this.vectorDelta;
            this.computeDelta(siblingOffset, siblingVectorOffset, delta);

            MissingOps.safeWeightedAverage(
                    new WrappedVector.Raw(this.preOrderPartials, parentOffset, this.dimTrait), parentPrecision,
                    new WrappedVector.Raw(delta, 0, this.dimTrait), actualizedSiblingPrecision,
                    new WrappedVector.Raw(this.preOrderPartials, nodeOffset, this.dimTrait), variance,
                    this.dimTrait);

            this.scaleAndDriftMean(nodeOffset, nodeMatrixOffset, nodeVectorOffset);
            this.actualizeVariance(variance, nodeOffset, nodeMatrixOffset, nodeVectorOffset);
            // Add the node's branch variance in place (output aliases the second operand).
            this.inflateBranch(nodeBranchVariance, variance, variance);

            final DenseMatrix64F precision = this.matrixPk;
            MissingOps.safeInvert2(variance, precision, false);
            MissingOps.unwrap(precision, this.preOrderPartials, nodeOffset + this.dimTrait);
            MissingOps.unwrap(variance, this.preOrderPartials, nodeOffset + this.dimTrait + this.dimTrait * this.dimTrait);

            parentOffset += this.dimPartialForTrait;
            nodeOffset += this.dimPartialForTrait;
            siblingOffset += this.dimPartialForTrait;
        }
    }

    /** Writes {@code branchVariance + variance} into {@code result} (element-wise). */
    private void inflateBranch(DenseMatrix64F branchVariance, DenseMatrix64F variance, DenseMatrix64F result) {
        CommonOps.add((D1Matrix64F) branchVariance, variance, (D1Matrix64F) result);
    }

    /** Hook for subclasses; here it simply copies {@code precision} into {@code actualized}. */
    void actualizePrecision(DenseMatrix64F precision, DenseMatrix64F actualized, int partialOffset, int matrixOffset, int vectorOffset) {
        CommonOps.scale(1.0, precision, actualized);
    }

    /** Hook for subclasses; a no-op here. */
    void actualizeVariance(DenseMatrix64F variance, int partialOffset, int matrixOffset, int vectorOffset) {
    }

    /** Hook for subclasses; a no-op here. */
    void scaleAndDriftMean(int partialOffset, int matrixOffset, int vectorOffset) {
    }

    /** Copies the {@code dimTrait} mean entries at {@code partialOffset} into {@code delta}. */
    void computeDelta(int partialOffset, int vectorOffset, double[] delta) {
        System.arraycopy(this.partials, partialOffset, delta, 0, this.dimTrait);
    }

    /**
     * Combines the post-order partials of children {@code iBuffer} and {@code jBuffer} into
     * parent {@code kBuffer}, optionally accumulating the log-likelihood remainder.
     *
     * @throws RuntimeException if {@code incrementOuterProducts} is requested (unsupported)
     */
    @Override
    protected void updatePartial(int kBuffer, int iBuffer, int iMatrix, int jBuffer, int jMatrix, boolean computeRemainders, boolean incrementOuterProducts) {
        if (incrementOuterProducts) {
            throw new RuntimeException("Outer-products are not supported.");
        }
        int kOffset = this.dimPartial * kBuffer;
        int iOffset = this.dimPartial * iBuffer;
        int jOffset = this.dimPartial * jBuffer;
        final int iMatrixOffset = this.dimTrait * this.dimTrait * iMatrix;
        final int jMatrixOffset = this.dimTrait * this.dimTrait * jMatrix;
        final int iVectorOffset = this.dimTrait * iMatrix;
        final int jVectorOffset = this.dimTrait * jMatrix;

        final DenseMatrix64F iBranchVariance = MissingOps.wrap(this.variances, iMatrixOffset, this.dimTrait, this.dimTrait);
        final DenseMatrix64F jBranchVariance = MissingOps.wrap(this.variances, jMatrixOffset, this.dimTrait, this.dimTrait);
        final DenseMatrix64F iBranchPrecision = MissingOps.wrap(this.precisions, iMatrixOffset, this.dimTrait, this.dimTrait);
        final DenseMatrix64F jBranchPrecision = MissingOps.wrap(this.precisions, jMatrixOffset, this.dimTrait, this.dimTrait);

        for (int trait = 0; trait < this.numTraits; ++trait) {
            // Child precisions moved up their branches (into matrixPip / matrixPjp).
            final DenseMatrix64F Pip = this.matrixPip;
            final DenseMatrix64F Pjp = this.matrixPjp;
            final InversionResult ci = this.increaseVariances(iOffset, iBuffer, iBranchVariance, iBranchPrecision, Pip, computeRemainders);
            final InversionResult cj = this.increaseVariances(jOffset, jBuffer, jBranchVariance, jBranchPrecision, Pjp, computeRemainders);

            final DenseMatrix64F Pk = this.matrixPk;
            this.computePartialPrecision(iVectorOffset, jVectorOffset, iMatrixOffset, jMatrixOffset, Pip, Pjp, Pk);
            this.partialMean(iOffset, jOffset, kOffset, iVectorOffset, jVectorOffset);
            MissingOps.unwrap(Pk, this.partials, kOffset + this.dimTrait);

            double remainder = 0.0;
            if (computeRemainders) {
                // ci / cj are non-null here: increaseVariances always returns a result when asked for determinants.
                final boolean iObserved = ci.getReturnCode() != InversionResult.Code.NOT_OBSERVED;
                final boolean jObserved = cj.getReturnCode() != InversionResult.Code.NOT_OBSERVED;
                if (iObserved && jObserved) {
                    final double sumOfSquares = this.computeSS(iOffset, Pip, jOffset, Pjp, kOffset, Pk, this.dimTrait);
                    remainder += -0.5 * sumOfSquares;
                }
                final double effectiveDimension = this.getEffectiveDimension(iBuffer) + this.getEffectiveDimension(jBuffer);
                remainder += -effectiveDimension * LOG_SQRT_2_PI;
                final double iLogDeterminant = iObserved ? ci.getLogDeterminant() : 0.0;
                final double jLogDeterminant = jObserved ? cj.getLogDeterminant() : 0.0;
                remainder += -0.5 * (iLogDeterminant + jLogDeterminant);
            }
            this.remainders[kBuffer * this.numTraits + trait] = remainder
                    + this.remainders[iBuffer * this.numTraits + trait]
                    + this.remainders[jBuffer * this.numTraits + trait];

            kOffset += this.dimPartialForTrait;
            iOffset += this.dimPartialForTrait;
            jOffset += this.dimPartialForTrait;
        }
    }

    /** Debug helper that prints both inversion statuses and the two child precisions to stderr. */
    private void reportInversions(InversionResult inversionResult, InversionResult inversionResult2, DenseMatrix64F denseMatrix64F, DenseMatrix64F denseMatrix64F2) {
        System.err.println("i status: " + inversionResult);
        System.err.println("j status: " + inversionResult2);
        System.err.println("Pip: " + denseMatrix64F);
        System.err.println("Pjp: " + denseMatrix64F2);
    }

    /**
     * Writes into {@code result} the precision of the partial at {@code partialOffset} after
     * adding a branch's variance.
     *
     * <p>If the partial precision has any infinite diagonal entry, the partial variance and
     * {@code branchVariance} are summed and inverted with {@link MissingOps#safeInvert2}.
     * Otherwise it computes {@code P (I - (P + Pb)^-1 P)}, with {@code P} the partial precision
     * and {@code Pb = branchPrecision}. When {@code Pb} is the inverse of {@code branchVariance}
     * (the case set up by updateBrownianDiffusionMatrices) this equals
     * {@code (P^-1 + Pb^-1)^-1}. Uses {@code matrix0} and {@code matrix1} as scratch.
     *
     * @return a determinant/status result when {@code getDeterminant} is true (may also be
     *         non-null in the infinite-diagonal branch); {@code null} otherwise
     * @throws RuntimeException if the summed variance is entirely zero or non-finite
     */
    private InversionResult increaseVariances(int partialOffset, int bufferIndex, DenseMatrix64F branchVariance, DenseMatrix64F branchPrecision, DenseMatrix64F result, boolean getDeterminant) {
        final DenseMatrix64F partialPrecision = MissingOps.wrap(this.partials, partialOffset + this.dimTrait, this.dimTrait, this.dimTrait);
        InversionResult inversion = null;
        if (MissingOps.anyDiagonalInfinities(partialPrecision)) {
            final DenseMatrix64F totalVariance = this.matrix0;
            final DenseMatrix64F partialVariance = MissingOps.wrap(this.partials, partialOffset + this.dimTrait + this.dimTrait * this.dimTrait, this.dimTrait, this.dimTrait);
            CommonOps.add((D1Matrix64F) partialVariance, branchVariance, (D1Matrix64F) totalVariance);
            if (SafeMultivariateIntegrator.allZeroOrInfinite(totalVariance)) {
                throw new RuntimeException("Zero-length branch on data is not allowed.");
            }
            inversion = MissingOps.safeInvert2(totalVariance, result, getDeterminant);
        } else {
            final DenseMatrix64F work = this.matrix0;
            CommonOps.add((D1Matrix64F) partialPrecision, branchPrecision, (D1Matrix64F) work);
            final DenseMatrix64F sumInverse = this.matrix1;
            MissingOps.safeInvertPrecision(work, sumInverse, false);
            CommonOps.mult(sumInverse, partialPrecision, work);
            SafeMultivariateIntegrator.idMinusA(work);
            if (getDeterminant) {
                inversion = MissingOps.safeDeterminant(work, true);
            }
            CommonOps.mult(partialPrecision, work, result);
            final int effectiveDimension = (int) Math.round(this.getEffectiveDimension(bufferIndex));
            if (getDeterminant && effectiveDimension > 0) {
                final double storedDeterminant = this.getPartialDeterminant(bufferIndex);
                final InversionResult partialInversion;
                if (PrecisionType.FULL.isMissingDeterminantValue(storedDeterminant)) {
                    partialInversion = MissingOps.safeDeterminant(partialPrecision, true);
                } else {
                    final InversionResult.Code code = InversionResult.getCode(this.dimTrait, effectiveDimension);
                    partialInversion = new InversionResult(code, effectiveDimension, -storedDeterminant);
                }
                inversion = InversionResult.mult(inversion, partialInversion);
            }
        }
        return inversion;
    }

    /** Replaces square matrix {@code a} with {@code I - a} in place. */
    private static void idMinusA(DenseMatrix64F a) {
        CommonOps.scale(-1.0, a);
        for (int d = 0; d < a.numCols; ++d) {
            a.set(d, d, 1.0 + a.get(d, d));
        }
    }

    /** Returns true when every element of {@code matrix} is zero or non-finite (NaN / +-Inf). */
    private static boolean allZeroOrInfinite(DenseMatrix64F matrix) {
        // Iterate over getNumElements(): the backing array may be longer than the matrix.
        for (int k = 0; k < matrix.getNumElements(); ++k) {
            final double value = matrix.get(k);
            if (Double.isFinite(value) && value != 0.0) {
                return false;
            }
        }
        return true;
    }

    /** Hook for subclasses; here the parent precision is simply {@code Pip + Pjp}. */
    void computePartialPrecision(int iVectorOffset, int jVectorOffset, int iMatrixOffset, int jMatrixOffset, DenseMatrix64F Pip, DenseMatrix64F Pjp, DenseMatrix64F Pk) {
        CommonOps.add((D1Matrix64F) Pip, Pjp, (D1Matrix64F) Pk);
    }

    /**
     * Computes the parent mean at {@code kOffset} by solving {@code Pk * mean = Pip*mi + Pjp*mj}.
     * The right-hand side is left in {@link #vectorPMk}.
     */
    void partialMean(int iOffset, int jOffset, int kOffset, int iVectorOffset, int jVectorOffset) {
        final double[] weightedSum = this.vectorPMk;
        MissingOps.weightedSum(this.partials, iOffset, this.matrixPip, this.partials, jOffset, this.matrixPjp, this.dimTrait, weightedSum);
        final WrappedVector.Raw parentMean = new WrappedVector.Raw(this.partials, kOffset, this.dimTrait);
        final WrappedVector.Raw rightHandSide = new WrappedVector.Raw(weightedSum, 0, this.dimTrait);
        MissingOps.safeSolve(this.matrixPk, rightHandSide, parentMean, false);
    }

    /**
     * Computes the per-trait root log-likelihood: the root partial at {@code rootBufferIndex}
     * combined with the prior at {@code priorBufferIndex}, plus the accumulated remainder.
     *
     * <p>The prior precision is first left-multiplied by the diffusion precision selected by
     * {@code precisionIndex}. For an integrated process, that diffusion precision is expanded
     * block-diagonally to {@code dimTrait x dimTrait}.
     */
    @Override
    public void calculateRootLogLikelihood(int rootBufferIndex, int priorBufferIndex, int precisionIndex, double[] logLikelihoods, boolean incrementOuterProducts, boolean isIntegratedProcess) {
        assert (logLikelihoods.length == this.numTraits);
        assert (!incrementOuterProducts);
        this.updatePrecisionOffsetAndDeterminant(precisionIndex);
        int rootOffset = this.dimPartial * rootBufferIndex;
        int priorOffset = this.dimPartial * priorBufferIndex;
        final DenseMatrix64F diffusionPrecision = MissingOps.wrap(this.diffusions, this.precisionOffset, this.dimProcess, this.dimProcess);

        // The multiplier applied to the prior precision does not depend on the trait, so build it once.
        final DenseMatrix64F priorMultiplier;
        if (!isIntegratedProcess) {
            priorMultiplier = diffusionPrecision;
        } else {
            priorMultiplier = new DenseMatrix64F(this.dimTrait, this.dimTrait);
            MissingOps.blockUnwrap(diffusionPrecision, priorMultiplier.data, 0, 0, 0, this.dimTrait);
            MissingOps.blockUnwrap(diffusionPrecision, priorMultiplier.data, this.dimProcess, this.dimProcess, 0, this.dimTrait);
        }
        // Scratch for the product; CommonOps.mult overwrites it completely on every trait.
        final DenseMatrix64F scaledPriorPrecision = new DenseMatrix64F(this.dimTrait, this.dimTrait);

        for (int trait = 0; trait < this.numTraits; ++trait) {
            final DenseMatrix64F priorPrecision = MissingOps.wrap(this.partials, priorOffset + this.dimTrait, this.dimTrait, this.dimTrait);
            final DenseMatrix64F priorVariance = MissingOps.wrap(this.partials, priorOffset + this.dimTrait + this.dimTrait * this.dimTrait, this.dimTrait, this.dimTrait);

            CommonOps.mult(priorMultiplier, priorPrecision, scaledPriorPrecision);
            priorPrecision.set(scaledPriorPrecision);

            // Filled by increaseVariances. The former inversion of an all-zero matrix into
            // this buffer was dead work, since its result was always overwritten.
            final DenseMatrix64F totalPrecision = new DenseMatrix64F(this.dimTrait, this.dimTrait);
            final InversionResult inversion = this.increaseVariances(rootOffset, rootBufferIndex, priorVariance, priorPrecision, totalPrecision, true);

            final double sumOfSquares = MissingOps.weightedInnerProductOfDifferences(this.partials, rootOffset, this.partials, priorOffset, totalPrecision, this.dimTrait);
            final double logDeterminant = inversion.getReturnCode() == InversionResult.Code.NOT_OBSERVED ? 0.0 : inversion.getLogDeterminant();
            final double rootLogLikelihood = -0.5 * logDeterminant - 0.5 * sumOfSquares;
            final double remainder = this.remainders[rootBufferIndex * this.numTraits + trait];
            logLikelihoods[trait] = rootLogLikelihood + remainder;

            rootOffset += this.dimPartialForTrait;
            priorOffset += this.dimPartialForTrait;
        }
    }

    /** Three-way weighted inner product used for the remainder; uses {@link #vectorPMk}. */
    double computeSS(int iOffset, DenseMatrix64F Pip, int jOffset, DenseMatrix64F Pjp, int kOffset, DenseMatrix64F Pk, int dimension) {
        return MissingOps.weightedThreeInnerProductNormalized(this.partials, iOffset, Pip, this.partials, jOffset, Pjp, this.partials, kOffset, this.vectorPMk, 0, dimension);
    }
}

