/*
 * Decompiled with CFR 0.152.
 */
package moa.classifiers.core.statisticaltests;

import com.github.javacliparser.IntOption;
import com.yahoo.labs.samoa.instances.Instance;
import java.util.Arrays;
import java.util.Comparator;
import java.util.List;
import java.util.PriorityQueue;
import moa.classifiers.core.statisticaltests.Cramer;
import moa.classifiers.core.statisticaltests.StatisticalTest;
import moa.core.ObjectRepository;
import moa.options.AbstractOptionHandler;
import moa.tasks.TaskMonitor;

/**
 * Multivariate two-sample test based on k nearest neighbours.
 *
 * <p>Both samples are pooled and every point is labelled with the sample it came from. For each
 * point, the test counts how many of its k nearest neighbours (squared Euclidean distance over all
 * attributes except the last) carry the same label. That count is normalised, standardised to a Z
 * score, and the upper-tail standard-normal probability of Z is reported as the p-value.
 * NOTE(review): the statistic appears to follow Schilling's nearest-neighbour test as implemented
 * in the R package {@code mtsknn}; confirm against that reference.
 *
 * <p>The normal-CDF helpers ({@code pnorm}, {@code pnorm_both}, {@code do_del},
 * {@code swap_tail}, {@code R_DT}, {@code R_DT_1}) are a port of R's {@code nmath/pnorm.c}.
 */
public class KNN
extends AbstractOptionHandler
implements StatisticalTest {
    /** sqrt(32): upper bound of |x| for the second rational approximation in pnorm_both. */
    private static final double M_SQRT_32 = 5.656854249492381;
    /** 1 / sqrt(2 * pi). */
    private static final double M_1_SQRT_2PI = 0.3989422804014327;
    /** Scale used by do_del to split x into a coarse part (multiple of 1/16) and a remainder. */
    private static final double SIXTEN = 16.0;

    // Rational-approximation coefficients for pnorm_both (same values as R's pnorm.c).
    private static final double[] PNORM_A = new double[]{2.2352520354606837, 161.02823106855587, 1067.6894854603709, 18154.98125334356, 0.06568233791820745};
    private static final double[] PNORM_B = new double[]{47.202581904688245, 976.0985517377767, 10260.932208618979, 45507.78933502673};
    private static final double[] PNORM_C = new double[]{0.39894151208813466, 8.883149794388377, 93.50665613217785, 597.2702763948002, 2494.5375852903726, 6848.190450536283, 11602.65143764735, 9842.714838383978, 1.0765576773720192E-8};
    private static final double[] PNORM_D = new double[]{22.266688044328117, 235.387901782625, 1519.3775994075547, 6485.558298266761, 18615.571640885097, 34900.95272114598, 38912.00328609327, 19685.429676859992};
    private static final double[] PNORM_P = new double[]{0.215898534057957, 0.12740116116024736, 0.022235277870649807, 0.0014216191932278934, 2.9112874951168793E-5, 0.023073441764940174};
    private static final double[] PNORM_Q = new double[]{1.284260096144911, 0.4682382124808651, 0.06598813786892856, 0.0037823963320275824, 7.297515550839662E-5};

    /** Orders neighbour candidates so the farthest one sits at the head of a PriorityQueue. */
    private static final Comparator<DIPair> FARTHEST_FIRST = new HigherComparator();

    private List<Instance> sample1i;
    private List<Instance> sample2i;
    public IntOption kValueOption = new IntOption("kValue", 'k', "K value of the K nearest neighbours algorithm.", 5, 1, Integer.MAX_VALUE);

    /**
     * Labels the pooled points, counts same-label neighbours and standardises the count.
     *
     * @param set pooled data: rows 0..d-1 hold attribute values, one column per point; the first
     *            n1 columns come from sample 1 and the next n2 from sample 2. Row d is overwritten
     *            with the labels 1.0 (sample 1) and 2.0 (sample 2).
     * @param d   index of the label row (also the number of attribute rows used for distances)
     * @param n1  size of the first sample
     * @param n2  size of the second sample
     * @return {@code {Tk, Z, P}}: the normalised same-label count, its Z score, and the upper-tail
     *         standard-normal probability of Z
     * @throws IllegalArgumentException if k is not smaller than n1 + n2 (not enough neighbours)
     * @throws InterruptedException     if the current thread is interrupted during the search
     */
    private double[] compute(double[][] set, int d, int n1, int n2) throws InterruptedException {
        final int k = this.kValueOption.getValue();
        final int total = n1 + n2;
        if (k >= total) {
            // Every point needs k neighbours other than itself.
            throw new IllegalArgumentException("K value (" + k + ") must be smaller than the combined sample size (" + total + ").");
        }
        Arrays.fill(set[d], 0, n1, 1.0);
        Arrays.fill(set[d], n1, total, 2.0);
        int[] counts = this.knn(set, total, d, k);
        double sameLabel = 0.0;
        for (int count : counts) {
            sameLabel += count;
        }
        // All arithmetic in double: the former int products such as (n1 - 1) * (n2 - 1) overflowed
        // for large samples.
        final double n = total;
        final double nk = n * k;
        final double tk = sameLabel / nk;
        final double ratio1 = (n1 - 1.0) * (n1 - 2.0) / ((n - 1.0) * (n - 2.0));
        final double ratio2 = (n2 - 1.0) * (n2 - 2.0) / ((n - 1.0) * (n - 2.0));
        final double v = (n1 - 1.0) * (n2 - 1.0) / ((n - 1.0) * (n - 1.0)) + 4.0 * ratio1 * ratio2;
        final double z = Math.sqrt(nk) * (tk - ratio1 - ratio2) / Math.sqrt(v);
        final double p = this.pnorm(z, 0.0, 1.0, false, false);
        return new double[]{tk, z, p};
    }

    /** Copies attribute {@code attIndex} of every instance in {@code list} into a new array. */
    private double[] attributeToDoubleArray(List<Instance> list, int attIndex) {
        double[] ret = new double[list.size()];
        for (int i = 0; i < list.size(); ++i) {
            ret[i] = list.get(i).value(attIndex);
        }
        return ret;
    }

    /**
     * Runs the k-nearest-neighbour two-sample test on two lists of instances.
     *
     * <p>NOTE(review): the last attribute (index {@code numAttributes() - 1}) is excluded from the
     * distance, and its row in the pooled matrix is reused for the sample labels. This presumably
     * assumes the class attribute is last; confirm.
     *
     * @param x first sample
     * @param y second sample
     * @return {@code {Tk, Z, P}} as produced by {@link #compute}, or {@code null} (after printing a
     *         message) when the two samples have different numbers of attributes
     * @throws IllegalArgumentException if either sample is empty, or k is too large for the samples
     * @throws InterruptedException     if the current thread is interrupted during the search
     */
    public double[] mtsknn(List<Instance> x, List<Instance> y) throws InterruptedException {
        if (x.isEmpty() || y.isEmpty()) {
            throw new IllegalArgumentException("Both samples must contain at least one instance.");
        }
        if (x.get(0).numAttributes() != y.get(0).numAttributes()) {
            System.out.println("The dimensions of two samples must match!!!");
            return null;
        }
        int d = x.get(0).numAttributes() - 1;
        int n1 = x.size();
        int n2 = y.size();
        double[][] set = new double[d + 1][n1 + n2];
        for (int i = 0; i < d; ++i) {
            double[] t1 = this.attributeToDoubleArray(x, i);
            double[] t2 = this.attributeToDoubleArray(y, i);
            System.arraycopy(t1, 0, set[i], 0, t1.length);
            System.arraycopy(t2, 0, set[i], t1.length, t2.length);
        }
        return this.compute(set, d, n1, n2);
    }

    /**
     * Normal cumulative distribution function (port of R's {@code pnorm5}).
     *
     * @param x          quantile
     * @param mu         mean
     * @param sigma      standard deviation (negative gives NaN; zero gives a point mass at mu)
     * @param lower_tail if true return P[X <= x], otherwise P[X > x]
     * @param log_p      if true return the natural log of the probability
     * @return the requested (log-)probability
     */
    private double pnorm(double x, double mu, double sigma, boolean lower_tail, boolean log_p) {
        if (Double.isNaN(x) || Double.isNaN(mu) || Double.isNaN(sigma)) {
            return x + mu + sigma;
        }
        if (Double.isInfinite(x) && mu == x) {
            // x - mu would be NaN.
            return Double.NaN;
        }
        if (sigma < 0.0) {
            return Double.NaN;
        }
        if (sigma == 0.0) {
            // Degenerate distribution concentrated at mu.
            return x < mu ? this.R_DT(lower_tail, log_p) : this.R_DT_1(lower_tail, log_p);
        }
        double p = (x - mu) / sigma;
        if (Double.isInfinite(p)) {
            return x < mu ? this.R_DT(lower_tail, log_p) : this.R_DT_1(lower_tail, log_p);
        }
        double[] ret = this.pnorm_both(p, lower_tail ? 0 : 1, log_p);
        return lower_tail ? ret[0] : ret[1];
    }

    /**
     * Squared Euclidean distance between columns v1 and v2 of {@code points}, over rows 0..d-1.
     */
    private double dist(double[][] points, int v1, int v2, int d) {
        double sum = 0.0;
        for (int i = 0; i != d; ++i) {
            sum += (points[i][v1] - points[i][v2]) * (points[i][v1] - points[i][v2]);
        }
        return sum;
    }

    /**
     * For each of the n points, counts how many of its k nearest other points share its label.
     *
     * @param points pooled data; rows 0..d-1 are attributes, row d holds the labels
     * @param n      number of points (columns)
     * @param d      label row index
     * @param k      neighbours per point; callers must ensure {@code 1 <= k < n}
     * @return per-point same-label neighbour counts
     * @throws InterruptedException if the current thread is interrupted (the flag is cleared)
     */
    private int[] knn(double[][] points, int n, int d, int k) throws InterruptedException {
        int[] counts = new int[n];
        double[] labels = points[d];
        for (int i = 0; i != n; ++i) {
            if (Thread.interrupted()) {
                throw new InterruptedException();
            }
            // Max-heap on distance holding the (up to) k closest candidates seen so far.
            PriorityQueue<DIPair> nearest = new PriorityQueue<DIPair>(k + 1, FARTHEST_FIRST);
            for (int j = 0; j != n; ++j) {
                if (i == j) {
                    continue;
                }
                DIPair candidate = new DIPair(this.dist(points, i, j, d), j);
                if (nearest.size() < k) {
                    nearest.add(candidate);
                } else if (Double.compare(candidate.getE(), nearest.peek().getE()) < 0) {
                    // Insert before evicting the farthest, so ties resolve as they always have.
                    nearest.add(candidate);
                    nearest.poll();
                }
            }
            // Neighbour order is irrelevant: only the number of matching labels is needed.
            for (DIPair neighbour : nearest) {
                if (labels[neighbour.getI()] == labels[i]) {
                    ++counts[i];
                }
            }
        }
        return counts;
    }

    /** R's {@code R_DT_0}: the value of a probability of 0 for the requested tail/log scale. */
    private double R_DT(boolean lower_tail, boolean log_p) {
        return lower_tail ? (log_p ? Double.NEGATIVE_INFINITY : 0.0) : (double)(!log_p ? 1 : 0);
    }

    /** R's {@code R_DT_1}: the value of a probability of 1 for the requested tail/log scale. */
    private double R_DT_1(boolean lower_tail, boolean log_p) {
        return lower_tail ? (log_p ? 0.0 : 1.0) : (log_p ? Double.NEGATIVE_INFINITY : 0.0);
    }

    /**
     * Standard normal CDF for finite or NaN x (port of R's {@code pnorm_both}, Cody's algorithm).
     *
     * @param x      standardised quantile
     * @param i_tail 0 = lower tail only, 1 = upper tail only, 2 = both
     * @param log_p  whether to return log-probabilities
     * @return {@code {cum, ccum}}; only the entries for the requested tail(s) are meaningful
     */
    private double[] pnorm_both(double x, int i_tail, boolean log_p) {
        double cum = 0.0;
        double ccum = 0.0;
        double min = Double.MIN_VALUE;
        if (Double.isNaN(x)) {
            return new double[]{x, x};
        }
        double eps = 5.0E-10;
        boolean lower = i_tail != 1;
        boolean upper = i_tail != 0;
        double y = Math.abs(x);
        if (y <= 0.67448975) {
            double xden;
            double xnum;
            if (y > eps) {
                double xsq = x * x;
                xnum = PNORM_A[4] * xsq;
                xden = xsq;
                for (int i = 0; i < 3; ++i) {
                    xnum = (xnum + PNORM_A[i]) * xsq;
                    xden = (xden + PNORM_B[i]) * xsq;
                }
            } else {
                xden = 0.0;
                xnum = 0.0;
            }
            double temp = x * (xnum + PNORM_A[3]) / (xden + PNORM_B[3]);
            if (lower) {
                cum = 0.5 + temp;
            }
            if (upper) {
                ccum = 0.5 - temp;
            }
            if (log_p) {
                if (lower) {
                    cum = Math.log(cum);
                }
                if (upper) {
                    ccum = Math.log(ccum);
                }
            }
        } else if (y <= M_SQRT_32) {
            double xnum = PNORM_C[8] * y;
            double xden = y;
            for (int i = 0; i < 7; ++i) {
                xnum = (xnum + PNORM_C[i]) * y;
                xden = (xden + PNORM_D[i]) * y;
            }
            double temp = (xnum + PNORM_C[7]) / (xden + PNORM_D[7]);
            double[] retorno = this.do_del(y, log_p, cum, ccum, lower, x, temp, upper);
            retorno = this.swap_tail(x, retorno[0], lower, retorno[1]);
            cum = retorno[0];
            ccum = retorno[1];
        } else if (log_p && y < 1e170 || lower && -37.5193 < x && x < 8.2924 || upper && -8.2924 < x && x < 37.5193) {
            // y < 1e170 avoids overflow of x * x in the log-scale computation (as in R).
            double xsq = 1.0 / (x * x);
            double xnum = PNORM_P[5] * xsq;
            double xden = xsq;
            for (int i = 0; i < 4; ++i) {
                xnum = (xnum + PNORM_P[i]) * xsq;
                xden = (xden + PNORM_Q[i]) * xsq;
            }
            double temp = xsq * (xnum + PNORM_P[4]) / (xden + PNORM_Q[4]);
            temp = (M_1_SQRT_2PI - temp) / y;
            double[] retorno = this.do_del(x, log_p, cum, ccum, lower, x, temp, upper);
            retorno = this.swap_tail(x, retorno[0], lower, retorno[1]);
            cum = retorno[0];
            ccum = retorno[1];
        } else {
            // |x| so large that the probabilities are exactly 0 or 1 (R_D__0 / R_D__1).
            double one = log_p ? 0.0 : 1.0;
            double zero = log_p ? Double.NEGATIVE_INFINITY : 0.0;
            if (x > 0.0) {
                cum = one;
                ccum = zero;
            } else {
                cum = zero;
                ccum = one;
            }
        }
        if (log_p) {
            if (cum > -min) {
                cum = -0.0;
            }
            if (ccum > -min) {
                ccum = -0.0;
            }
        } else {
            if (cum < min) {
                cum = 0.0;
            }
            if (ccum < min) {
                ccum = 0.0;
            }
        }
        return new double[]{cum, ccum};
    }

    /**
     * R's {@code do_del}: evaluates exp(-X^2/2) * temp accurately by splitting X^2 into a part
     * built from X truncated to a multiple of 1/16 and a small remainder.
     *
     * @return {@code {cum, ccum}}; in log scale ccum is only updated when it will be used
     */
    private double[] do_del(double X, boolean log_p, double cum, double ccum, boolean lower, double x, double temp, boolean upper) {
        // R truncates toward zero; Math.ceil (used previously) rounds positive X the wrong way.
        double scaled = X * SIXTEN;
        double xsq = (scaled < 0.0 ? Math.ceil(scaled) : Math.floor(scaled)) / SIXTEN;
        double del = (X - xsq) * (X + xsq);
        if (log_p) {
            cum = -xsq * xsq * 0.5 + -del * 0.5 + Math.log(temp);
            if (lower && x > 0.0 || upper && x <= 0.0) {
                ccum = Math.log1p(-Math.exp(-xsq * xsq * 0.5) * Math.exp(-del * 0.5) * temp);
            }
        } else {
            cum = Math.exp(-xsq * xsq * 0.5) * Math.exp(-del * 0.5) * temp;
            ccum = 1.0 - cum;
        }
        return new double[]{cum, ccum};
    }

    /**
     * R's {@code swap_tail}: for positive x, the tail computed by do_del is the upper one, so the
     * roles of cum and ccum are exchanged (cum is only overwritten when the lower tail is wanted).
     */
    private double[] swap_tail(double x, double cum, boolean lower, double ccum) {
        if (x > 0.0) {
            double saved = cum;
            if (lower) {
                cum = ccum;
            }
            ccum = saved;
        }
        return new double[]{cum, ccum};
    }

    /**
     * Returns the p-value of the test for samples x and y.
     *
     * @return the upper-tail probability of the Z score, or 0.0 if the computation was interrupted
     *         (the thread's interrupt flag is restored in that case)
     * @throws IllegalArgumentException if the samples are empty, have different dimensions, or are
     *                                  too small for the configured k
     */
    @Override
    public double test(List<Instance> x, List<Instance> y) {
        try {
            double[] result = this.mtsknn(x, y);
            if (result == null) {
                throw new IllegalArgumentException("The dimensions of two samples must match");
            }
            return result[2];
        }
        catch (InterruptedException ie) {
            // knn() cleared the flag via Thread.interrupted(); restore it so callers can react.
            Thread.currentThread().interrupt();
            return 0.0;
        }
    }

    @Override
    public void getDescription(StringBuilder sb, int indent) {
    }

    @Override
    protected void prepareForUseImpl(TaskMonitor monitor, ObjectRepository repository) {
    }

    /** Runs {@link #test} on the samples previously supplied through {@link #set}. */
    @Override
    public Double call() throws Exception {
        return this.test(this.sample1i, this.sample2i);
    }

    /** Stores the two samples used by {@link #call()}. */
    @Override
    public void set(List<Instance> x, List<Instance> y) {
        this.sample1i = x;
        this.sample2i = y;
    }

    /**
     * Ad-hoc check against reference values. The two ARFF paths may be passed as arguments; the
     * original developer paths are used otherwise.
     */
    public static void main(String[] args) throws Exception {
        String xPath = args.length > 0 ? args[0] : "c:\\Users\\Paulo\\Documents\\test1-x.arff";
        String yPath = args.length > 1 ? args[1] : "c:\\Users\\Paulo\\Documents\\test1-y.arff";
        List<Instance> x = Cramer.fileToInstances(xPath);
        List<Instance> y = Cramer.fileToInstances(yPath);
        KNN c = new KNN();
        double[] ct = c.mtsknn(x, y);
        if (ct == null) {
            // mtsknn has already reported the dimension mismatch.
            return;
        }
        // Output labels are Portuguese: "Resultado esperado" = expected result,
        // "Resultado obtido" = obtained result.
        // NOTE(review): ct[0] is Tk and ct[1] is the Z score; the "Critical value" / "Statistic"
        // labels below may not describe them accurately.
        System.out.println("p Value [Resultado esperado: 0.09866699171730517] [Resultado obtido..: " + ct[2] + "]");
        System.out.println("Critical value [Resultado esperado: 0.521] [Resultado obtido: " + ct[0] + "]");
        System.out.println("Statistic [Resultado esperado: 1.2891844104764096] [Resultado obtido: " + ct[1] + "]");
    }

    /** Reverse distance order: larger distances compare as "smaller" so they head the queue. */
    private static final class HigherComparator
    implements Comparator<DIPair> {
        private HigherComparator() {
        }

        @Override
        public int compare(DIPair o1, DIPair o2) {
            return Double.compare(o2.e, o1.e);
        }
    }

    /** A (distance, point index) pair used during the neighbour search. */
    private static final class DIPair {
        double e = 0.0;
        int i = 0;

        public DIPair(double e0, int i0) {
            this.e = e0;
            this.i = i0;
        }

        public double getE() {
            return this.e;
        }

        public int getI() {
            return this.i;
        }
    }
}

