/*
 * Decompiled with CFR 0.152.
 */
package dr.geo.distributions;

import dr.evomodel.treedatalikelihood.TreeDataLikelihood;
import dr.evomodel.treedatalikelihood.continuous.IntegratedFactorAnalysisLikelihood;
import dr.inference.operators.factorAnalysis.FactorAnalysisOperatorAdaptor;
import dr.math.MathUtils;
import dr.math.ModifiedBesselFirstKind;
import dr.math.distributions.MultivariateDistribution;
import dr.math.distributions.RandomGenerator;
import dr.math.matrixAlgebra.Vector;
import dr.xml.AbstractXMLObjectParser;
import dr.xml.ElementRule;
import dr.xml.Reportable;
import dr.xml.XMLObject;
import dr.xml.XMLParseException;
import dr.xml.XMLSyntaxRule;
import org.ejml.data.DenseMatrix64F;
import org.ejml.factory.DecompositionFactory;
import org.ejml.interfaces.decomposition.EigenDecomposition;
import org.ejml.interfaces.decomposition.SingularValueDecomposition;
import org.ejml.ops.CommonOps;

/**
 * Matrix von Mises-Fisher distribution over {@code nRows x nColumns} matrices with orthonormal
 * columns, with unnormalized log-density {@code trace(C^T X)}.
 *
 * <p>The parameter matrix {@code C} is assembled from a factor-analysis model by {@link #updateC()}.
 * Draws come from a column-by-column rejection sampler (this appears to follow Hoff, 2009,
 * "Simulation of the matrix Bingham-von Mises-Fisher distribution"). Each column is a vector
 * von Mises-Fisher draw generated with Wood's (1994) algorithm and placed in the orthogonal
 * complement of the columns drawn before it.
 *
 * <p>All work is done in shared mutable buffers, so an instance must not be used from several
 * threads at once.
 */
public class MatrixVonMisesFisherDistribution implements RandomGenerator, MultivariateDistribution, Reportable {
    /** Source of loadings, factors, data and column precisions used to build {@code C}. */
    private final FactorAnalysisOperatorAdaptor adaptor;
    /** Parameter matrix of the distribution ({@code nRows x nColumns}). */
    private final DenseMatrix64F C;
    /** Number of rows of a sample (number of traits). */
    private final int nRows;
    /** Number of columns of a sample (number of factors). */
    private final int nColumns;
    // Scratch buffers, named by shape: m = nRows, k = nColumns.
    private final DenseMatrix64F mkBuffer1;
    private final DenseMatrix64F mkBuffer2;
    private final DenseMatrix64F kkBuffer1;
    private final DenseMatrix64F kkBuffer2;
    private final DenseMatrix64F kkBuffer3;
    private final DenseMatrix64F mmBuffer;
    private final DenseMatrix64F mBuffer1;
    /** Currently unused scratch vector. */
    private final DenseMatrix64F mBuffer2;
    /** Diagonal matrix holding the Euclidean norm of each factor's loadings (see splitLoadings). */
    private final DenseMatrix64F D;
    /** Normalized loadings; written by splitLoadings but not read anywhere in this class. */
    private final DenseMatrix64F V;
    /** Factor values, taxa x factors. */
    private final DenseMatrix64F F;
    /** Observed trait values, taxa x traits. */
    private final DenseMatrix64F Y;
    /** U * diag(s) from the thin SVD of C; column j is the vMF parameter for sample column j. */
    private final DenseMatrix64F H;
    /** Maximum number of attempts each rejection sampler makes before giving up. */
    private static final int MAX_REJECTS = 100;
    public static final AbstractXMLObjectParser PARSER = new AbstractXMLObjectParser(){
        private static final String MATRIX_VON_MISES_FISHER_DISTRIBUTION = "matrixVonMisesFisherDistribution";

        @Override
        public Object parseXMLObject(XMLObject xMLObject) throws XMLParseException {
            TreeDataLikelihood treeDataLikelihood = (TreeDataLikelihood)xMLObject.getChild(TreeDataLikelihood.class);
            IntegratedFactorAnalysisLikelihood integratedFactorAnalysisLikelihood = (IntegratedFactorAnalysisLikelihood)xMLObject.getChild(IntegratedFactorAnalysisLikelihood.class);
            FactorAnalysisOperatorAdaptor.IntegratedFactors integratedFactors = new FactorAnalysisOperatorAdaptor.IntegratedFactors(integratedFactorAnalysisLikelihood, treeDataLikelihood);
            return new MatrixVonMisesFisherDistribution(integratedFactors);
        }

        @Override
        public XMLSyntaxRule[] getSyntaxRules() {
            return new XMLSyntaxRule[]{new ElementRule(TreeDataLikelihood.class), new ElementRule(IntegratedFactorAnalysisLikelihood.class)};
        }

        @Override
        public String getParserDescription() {
            return null;
        }

        @Override
        public Class getReturnType() {
            return MatrixVonMisesFisherDistribution.class;
        }

        @Override
        public String getParserName() {
            return MATRIX_VON_MISES_FISHER_DISTRIBUTION;
        }
    };

    /**
     * Creates the distribution and allocates every buffer, sized from the adaptor's dimensions.
     *
     * @param factorAnalysisOperatorAdaptor source of the factor-analysis model state
     */
    public MatrixVonMisesFisherDistribution(FactorAnalysisOperatorAdaptor factorAnalysisOperatorAdaptor) {
        this.adaptor = factorAnalysisOperatorAdaptor;
        this.nRows = factorAnalysisOperatorAdaptor.getNumberOfTraits();
        this.nColumns = factorAnalysisOperatorAdaptor.getNumberOfFactors();
        final int nTaxa = factorAnalysisOperatorAdaptor.getNumberOfTaxa();
        this.C = new DenseMatrix64F(this.nRows, this.nColumns);
        this.mkBuffer1 = new DenseMatrix64F(this.nRows, this.nColumns);
        this.mkBuffer2 = new DenseMatrix64F(this.nRows, this.nColumns);
        this.kkBuffer1 = new DenseMatrix64F(this.nColumns, this.nColumns);
        this.kkBuffer2 = new DenseMatrix64F(this.nColumns, this.nColumns);
        this.kkBuffer3 = new DenseMatrix64F(this.nColumns, this.nColumns);
        this.mmBuffer = new DenseMatrix64F(this.nRows, this.nRows);
        this.mBuffer1 = new DenseMatrix64F(this.nRows, 1);
        this.mBuffer2 = new DenseMatrix64F(this.nRows, 1);
        this.V = new DenseMatrix64F(this.nRows, this.nColumns);
        this.D = new DenseMatrix64F(this.nColumns, this.nColumns);
        this.F = new DenseMatrix64F(nTaxa, this.nColumns);
        this.Y = new DenseMatrix64F(nTaxa, this.nTraitsOrRows());
        this.H = new DenseMatrix64F(this.nRows, this.nColumns);
    }

    /** Number of traits, which is also the number of rows of C. */
    private int nTraitsOrRows() {
        return this.nRows;
    }

    /**
     * Rebuilds {@code C} from the current model state and draws one sample.
     *
     * @return the sample as a row-major {@code nRows x nColumns} array, or {@code null} if the
     *     rejection sampler exhausted its attempts
     */
    public double[] nextRandom() {
        this.updateC();
        return this.nextRandomNoUpdate();
    }

    /**
     * Draws one sample using the current {@code C}, without rebuilding it.
     *
     * @return the sample as a row-major {@code nRows x nColumns} array, or {@code null} (after
     *     printing a message to stderr) if no proposal was accepted within {@link #MAX_REJECTS}
     *     attempts
     */
    public double[] nextRandomNoUpdate() {
        final int m = this.C.numRows;
        final int k = this.C.numCols;

        // Thin SVD C = U * diag(s) * V^T; H = U * diag(s) holds one vMF parameter per column.
        SingularValueDecomposition<DenseMatrix64F> svd = DecompositionFactory.svd(m, k, true, true, true);
        svd.decompose(this.C);
        double[] singularValues = svd.getSingularValues();
        svd.getU(this.H, false);
        for (int row = 0; row < this.nRows; ++row) {
            for (int col = 0; col < this.nColumns; ++col) {
                this.H.set(row, col, this.H.get(row, col) * singularValues[col]);
            }
        }

        boolean accepted = false;
        for (int attempt = 0; !accepted && attempt < MAX_REJECTS; ++attempt) {
            final double uniform = MathUtils.nextDouble();

            // Column 0 is an exact vector vMF draw with parameter H[:, 0].
            double[] firstColumn = this.nextVectorVonMisesFisher(this.transferColumns(this.H, 0, 1).data);
            for (int row = 0; row < m; ++row) {
                this.mkBuffer1.set(row, 0, firstColumn[row]);
            }

            double acceptance = 1.0;
            for (int col = 1; col < k; ++col) {
                final int complementDim = m - col;

                // Orthonormal basis of the complement of the columns drawn so far: the trailing
                // m - col columns of the full left-singular-vector matrix of those columns.
                DenseMatrix64F drawn = new DenseMatrix64F(m, col);
                this.transferColumns(this.mkBuffer1, 0, drawn, 0, col);
                SingularValueDecomposition<DenseMatrix64F> drawnSvd = DecompositionFactory.svd(m, col, true, false, false);
                drawnSvd.decompose(drawn);
                drawnSvd.getU(this.mmBuffer, false);
                DenseMatrix64F complement = this.transferColumns(this.mmBuffer, col, complementDim);

                // Project this column's parameter onto the complement, draw there, map back.
                this.transferColumns(this.H, col, this.mBuffer1, 0, 1);
                double fullNorm = this.computeNorm(this.mBuffer1.data);
                DenseMatrix64F projected = new DenseMatrix64F(complementDim, 1);
                CommonOps.multTransA(complement, this.mBuffer1, projected);
                double projectedNorm = this.computeNorm(projected.data);
                projected.setData(this.nextVectorVonMisesFisher(projected.data));
                CommonOps.mult(complement, projected, this.mBuffer1);
                this.transferColumns(this.mBuffer1, 0, this.mkBuffer1, col, 1);

                // The draw lives on the unit sphere of R^p with p = m - col. A vMF there has
                // Bessel order p/2 - 1, which varies with the column and is often half-integer,
                // so it must be computed per column and in floating point.
                // NOTE(review): assumes the third argument of scaledBessIRatio is that order.
                double besselOrder = complementDim / 2.0 - 1.0;
                acceptance *= ModifiedBesselFirstKind.scaledBessIRatio(projectedNorm, fullNorm, besselOrder);
            }
            accepted = uniform < acceptance;
        }

        if (!accepted) {
            System.err.println("Didn't work.");
            return null;
        }

        // X = Z * V^T, so the mode maps to the polar factor U * V^T of C. getV(..., false)
        // returns V itself; multTransB then multiplies by its transpose.
        svd.getV(this.kkBuffer1, false);
        DenseMatrix64F sample = new DenseMatrix64F(m, k);
        CommonOps.multTransB(this.mkBuffer1, this.kkBuffer1, sample);
        return sample.data;
    }

    /**
     * Copies {@code columnCount} consecutive columns of {@code source} into {@code destination}.
     * Iterates over every row of {@code source}, so {@code destination} must have at least as
     * many rows.
     *
     * @param source matrix to read from
     * @param sourceColumn first column of {@code source} to copy
     * @param destination matrix to write into (modified)
     * @param destinationColumn column of {@code destination} receiving {@code sourceColumn}
     * @param columnCount number of columns to copy
     */
    public void transferColumns(DenseMatrix64F source, int sourceColumn, DenseMatrix64F destination, int destinationColumn, int columnCount) {
        for (int row = 0; row < source.numRows; ++row) {
            for (int offset = 0; offset < columnCount; ++offset) {
                destination.set(row, destinationColumn + offset, source.get(row, sourceColumn + offset));
            }
        }
    }

    /**
     * Returns a new matrix holding {@code columnCount} consecutive columns of {@code source}.
     *
     * @param source matrix to read from
     * @param sourceColumn first column to copy
     * @param columnCount number of columns to copy
     * @return a fresh {@code source.numRows x columnCount} matrix
     */
    public DenseMatrix64F transferColumns(DenseMatrix64F source, int sourceColumn, int columnCount) {
        DenseMatrix64F copy = new DenseMatrix64F(source.numRows, columnCount);
        this.transferColumns(source, sourceColumn, copy, 0, columnCount);
        return copy;
    }

    /**
     * Draws a unit vector from a vector von Mises-Fisher distribution whose mean direction is
     * {@code parameter / |parameter|} and whose concentration is {@code |parameter|}.
     *
     * @param parameter unnormalized parameter vector (not modified)
     * @return a fresh unit vector of the same length
     * @throws RuntimeException if the underlying rejection sampler gives up
     */
    private double[] nextVectorVonMisesFisher(double[] parameter) {
        final int dim = parameter.length;
        double[] direction = parameter.clone();
        double concentration = this.makeUnit(direction);
        double[] unitModeDraw = this.nextVectorVonMisesFisherUnitMode(dim, concentration);
        if (unitModeDraw == null) {
            throw new RuntimeException("Vector von Mises-Fisher rejection sampler failed.");
        }

        // Build an orthonormal basis whose first column is +/- the mean direction.
        SingularValueDecomposition<DenseMatrix64F> svd = DecompositionFactory.svd(dim, 1, true, false, false);
        DenseMatrix64F column = new DenseMatrix64F(dim, 1);
        column.setData(direction);
        svd.decompose(column);
        DenseMatrix64F basis = new DenseMatrix64F(dim, dim);
        svd.getU(basis, false);

        // Singular vectors are only defined up to sign. Orient the mode column along the
        // parameter, otherwise draws would be centred on the antipodal direction.
        double alignment = 0.0;
        for (int i = 0; i < dim; ++i) {
            alignment += basis.get(i, 0) * parameter[i];
        }
        final double sign = alignment < 0.0 ? -1.0 : 1.0;

        // The unit-mode draw puts the mode in its last coordinate, so move the mode column last.
        for (int i = 0; i < dim; ++i) {
            double modeEntry = basis.get(i, 0);
            double lastEntry = basis.get(i, dim - 1);
            basis.set(i, 0, lastEntry);
            basis.set(i, dim - 1, sign * modeEntry);
        }

        DenseMatrix64F draw = new DenseMatrix64F(dim, 1);
        draw.setData(unitModeDraw);
        DenseMatrix64F rotated = new DenseMatrix64F(dim, 1);
        CommonOps.mult(basis, draw, rotated);
        return rotated.getData();
    }

    /**
     * Wood's (1994) rejection sampler for a vector vMF on the unit sphere in {@code R^n}, with
     * mode {@code e_n} (the last coordinate) and concentration {@code kappa}.
     *
     * @param n ambient dimension
     * @param kappa concentration
     * @return a unit vector of length {@code n}, or {@code null} if no proposal was accepted
     *     within {@link #MAX_REJECTS} attempts
     */
    private double[] nextVectorVonMisesFisherUnitMode(int n, double kappa) {
        // Envelope constants do not depend on the proposal, so compute them once.
        final double dimMinusOne = n - 1;
        final double b = (-2.0 * kappa + Math.sqrt(4.0 * kappa * kappa + dimMinusOne * dimMinusOne)) / dimMinusOne;
        final double x0 = (1.0 - b) / (1.0 + b);
        final double c = kappa * x0 + dimMinusOne * Math.log(1.0 - x0 * x0);

        for (int attempt = 0; attempt < MAX_REJECTS; ++attempt) {
            double z = MathUtils.nextBeta(dimMinusOne / 2.0, dimMinusOne / 2.0);
            double u = MathUtils.nextDouble();
            double w = (1.0 - (1.0 + b) * z) / (1.0 - (1.0 - b) * z);
            if (kappa * w + dimMinusOne * Math.log(1.0 - x0 * w) - c > Math.log(u)) {
                // Accepted: spread the remaining mass uniformly over the orthogonal directions.
                double[] tangent = this.nextUniformVector(n - 1);
                double[] sample = new double[n];
                double scale = Math.sqrt(1.0 - w * w);
                for (int j = 0; j < tangent.length; ++j) {
                    sample[j] = scale * tangent[j];
                }
                sample[n - 1] = w;
                return sample;
            }
        }
        return null;
    }

    /** Returns a uniformly distributed unit vector of length {@code n} (normalized Gaussians). */
    private double[] nextUniformVector(int n) {
        double[] vector = new double[n];
        for (int i = 0; i < n; ++i) {
            vector[i] = MathUtils.nextGaussian();
        }
        this.makeUnit(vector);
        return vector;
    }

    /**
     * Scales {@code vector} in place to unit Euclidean norm.
     *
     * @return the norm before scaling
     */
    private double makeUnit(double[] vector) {
        double norm = this.computeNorm(vector);
        double inverse = 1.0 / norm;
        for (int i = 0; i < vector.length; ++i) {
            vector[i] *= inverse;
        }
        return norm;
    }

    /** Returns the Euclidean norm of {@code vector}. */
    private double computeNorm(double[] vector) {
        double sumOfSquares = 0.0;
        for (double value : vector) {
            sumOfSquares += value * value;
        }
        return Math.sqrt(sumOfSquares);
    }

    /**
     * Alternative sampler: proposes uniform matrices and accepts with probability
     * {@code exp(trace(C^T X) - sum(singular values of C))}. Prints diagnostics to stdout.
     *
     * @return a freshly allocated row-major sample
     * @throws RuntimeException if no proposal is accepted within {@link #MAX_REJECTS} attempts
     */
    private double[] slowNextRandom() {
        this.updateC();
        // C is not modified inside the loop, so its singular values are computed once.
        SingularValueDecomposition<DenseMatrix64F> svd = DecompositionFactory.svd(this.C.numRows, this.C.numCols, false, false, true);
        svd.decompose(this.C);
        double[] singularValues = svd.getSingularValues();
        for (int attempt = 0; attempt < MAX_REJECTS; ++attempt) {
            DenseMatrix64F proposal = this.nextUniform();
            System.out.println("Rejects: " + attempt);
            CommonOps.multTransA(this.C, proposal, this.kkBuffer1);
            double logRatio = 0.0;
            for (int j = 0; j < this.C.numCols; ++j) {
                logRatio += this.kkBuffer1.get(j, j) - singularValues[j];
            }
            System.out.println("ExpTrace: " + Math.exp(logRatio));
            if (MathUtils.nextDouble() < Math.exp(logRatio)) {
                // proposal is an internal buffer; hand out a copy.
                return proposal.getData().clone();
            }
            System.out.println("");
        }
        throw new RuntimeException("Rejection sampler failed.");
    }

    /**
     * Draws a uniformly distributed matrix with orthonormal columns as {@code X (X^T X)^{-1/2}},
     * where {@code X} has i.i.d. standard normal entries.
     *
     * @return the internal buffer {@code mkBuffer2} holding the result (overwritten on the next call)
     */
    private DenseMatrix64F nextUniform() {
        final int k = this.nColumns;
        final int elementCount = this.C.getNumElements();
        for (int i = 0; i < elementCount; ++i) {
            this.mkBuffer1.set(i, MathUtils.nextGaussian());
        }

        // X^T X = Q diag(lambda) Q^T
        CommonOps.multTransA(this.mkBuffer1, this.mkBuffer1, this.kkBuffer1);
        EigenDecomposition<DenseMatrix64F> eig = DecompositionFactory.eig(k, true, true);
        eig.decompose(this.kkBuffer1);
        for (int i = 0; i < k; ++i) {
            DenseMatrix64F eigenvector = eig.getEigenVector(i);
            double inverseRoot = 1.0 / Math.sqrt(eig.getEigenvalue(i).getReal());
            for (int j = 0; j < k; ++j) {
                this.kkBuffer2.set(j, i, eigenvector.get(j, 0));
                // kkBuffer3 is also the multTransB output below, so its off-diagonals must be
                // cleared here rather than left over from the previous call.
                this.kkBuffer3.set(j, i, j == i ? inverseRoot : 0.0);
            }
        }

        // (X^T X)^{-1/2} = Q diag(lambda^{-1/2}) Q^T, then X (X^T X)^{-1/2}.
        CommonOps.mult(this.kkBuffer2, this.kkBuffer3, this.kkBuffer1);
        CommonOps.multTransB(this.kkBuffer1, this.kkBuffer2, this.kkBuffer3);
        CommonOps.mult(this.mkBuffer1, this.kkBuffer3, this.mkBuffer2);
        return this.mkBuffer2;
    }

    @Override
    public double logPdf(Object object) {
        throw new RuntimeException("Not yet implemented.");
    }

    /**
     * Replaces the data array of {@code C}. The array is installed by reference via
     * {@code setData}, so later updates to {@code C} write into it.
     *
     * @param dArray row-major {@code nRows x nColumns} values
     */
    public void setC(double[] dArray) {
        this.C.setData(dArray);
    }

    /**
     * Rebuilds {@code C} from the adaptor: {@code C = Y^T F}. Each column j is then scaled by the
     * norm of factor j's loadings times the largest column precision.
     */
    public void updateC() {
        this.splitLoadings();
        this.fillFactors();
        this.fillTraits();
        final double maxPrecision = this.getMaximumPrecision();
        CommonOps.multTransA(this.Y, this.F, this.C);
        final int nFactors = this.adaptor.getNumberOfFactors();
        final int nTraits = this.adaptor.getNumberOfTraits();
        for (int factor = 0; factor < nFactors; ++factor) {
            double scale = this.D.get(factor, factor) * maxPrecision;
            for (int trait = 0; trait < nTraits; ++trait) {
                this.C.set(trait, factor, this.C.get(trait, factor) * scale);
            }
        }
    }

    /**
     * For each factor, stores the Euclidean norm of its block of loadings on the diagonal of
     * {@code D}, and the normalized loadings into {@code V} by linear index.
     * NOTE(review): assumes the adaptor's loadings are laid out factor by factor, nTraits each.
     */
    private void splitLoadings() {
        final int nFactors = this.adaptor.getNumberOfFactors();
        final int nTraits = this.adaptor.getNumberOfTraits();
        for (int factor = 0; factor < nFactors; ++factor) {
            final int offset = factor * nTraits;
            double sumOfSquares = 0.0;
            for (int trait = 0; trait < nTraits; ++trait) {
                double loading = this.adaptor.getLoadingsValue(offset + trait);
                sumOfSquares += loading * loading;
            }
            double norm = Math.sqrt(sumOfSquares);
            this.D.set(factor, factor, norm);
            double inverseNorm = 1.0 / norm;
            for (int trait = 0; trait < nTraits; ++trait) {
                this.V.set(offset + trait, this.adaptor.getLoadingsValue(offset + trait) * inverseNorm);
            }
        }
    }

    /** Asks the adaptor to draw factors, then copies them into {@code F} (taxa x factors). */
    private void fillFactors() {
        this.adaptor.drawFactors();
        final int nTaxa = this.adaptor.getNumberOfTaxa();
        final int nFactors = this.adaptor.getNumberOfFactors();
        for (int taxon = 0; taxon < nTaxa; ++taxon) {
            for (int factor = 0; factor < nFactors; ++factor) {
                this.F.set(taxon, factor, this.adaptor.getFactorValue(factor, taxon));
            }
        }
    }

    /** Copies the adaptor's data values into {@code Y} (taxa x traits). */
    private void fillTraits() {
        final int nTaxa = this.adaptor.getNumberOfTaxa();
        final int nTraits = this.adaptor.getNumberOfTraits();
        for (int taxon = 0; taxon < nTaxa; ++taxon) {
            for (int trait = 0; trait < nTraits; ++trait) {
                this.Y.set(taxon, trait, this.adaptor.getDataValue(trait, taxon));
            }
        }
    }

    /** Returns the largest column precision reported by the adaptor, or 0 if none is positive. */
    private double getMaximumPrecision() {
        double max = 0.0;
        final int nTraits = this.adaptor.getNumberOfTraits();
        for (int trait = 0; trait < nTraits; ++trait) {
            double precision = this.adaptor.getColumnPrecision(trait);
            if (precision > max) {
                max = precision;
            }
        }
        return max;
    }

    /**
     * Rebuilds {@code C} and returns the unnormalized log-density {@code trace(C^T X)}.
     * The normalizing constant is not included.
     *
     * @param dArray row-major {@code nRows x nColumns} matrix X (read only, never retained)
     */
    @Override
    public double logPdf(double[] dArray) {
        this.updateC();
        // trace(C^T X) is the element-wise inner product of C and X. Computing it directly avoids
        // installing the caller's array into a sampler buffer, which would later overwrite it.
        final int elementCount = this.C.getNumElements();
        double trace = 0.0;
        for (int i = 0; i < elementCount; ++i) {
            trace += this.C.data[i] * dArray[i];
        }
        return trace;
    }

    @Override
    public double[][] getScaleMatrix() {
        throw new RuntimeException("Not yet implemented.");
    }

    @Override
    public double[] getMean() {
        throw new RuntimeException("Not yet implemented.");
    }

    @Override
    public String getType() {
        return "MatrixVonMises-Fisher";
    }

    /**
     * Diagnostic report. It prints the empirical mean and variance of 100 vector vMF draws around
     * a random direction with concentration 1000. It then overwrites {@code C} with 1000 times a
     * uniform orthonormal matrix and prints one matrix draw next to that matrix.
     */
    @Override
    public String getReport() {
        final int sampleCount = 100;
        final int dim = this.C.numRows;
        final double concentration = 1000.0;

        double[] parameter = this.nextUniformVector(dim);
        for (int i = 0; i < dim; ++i) {
            parameter[i] *= concentration;
        }

        double[] mean = new double[dim];
        double[] variance = new double[dim];
        for (int s = 0; s < sampleCount; ++s) {
            double[] draw = this.nextVectorVonMisesFisher(parameter);
            for (int i = 0; i < dim; ++i) {
                mean[i] += draw[i];
                variance[i] += draw[i] * draw[i];
            }
            double norm = this.makeUnit(draw);
            if (!MathUtils.isClose(norm, 1.0, 1.0E-8)) {
                System.err.println("Norm: " + norm);
            }
        }
        for (int i = 0; i < dim; ++i) {
            mean[i] /= (double) sampleCount;
            variance[i] = variance[i] / (double) sampleCount - mean[i] * mean[i];
        }

        StringBuilder report = new StringBuilder("matrix von Mises-Fisher distribution:\n");
        this.makeUnit(parameter);
        report.append("original: " + new Vector(parameter) + "\n");
        report.append("mean: " + new Vector(mean));
        report.append("\n");
        report.append("variance: " + new Vector(variance));
        report.append("\n\n");

        DenseMatrix64F mode = this.nextUniform();
        for (int i = 0; i < mode.data.length; ++i) {
            this.C.data[i] = mode.data[i] * concentration;
        }
        double[] matrixDraw = this.nextRandomNoUpdate();
        if (matrixDraw == null) {
            report.append("Rejection sampler failed.");
        } else {
            report.append(new Vector(matrixDraw));
        }
        report.append("\n");
        report.append(new Vector(mode.data));
        return report.toString();
    }
}

