/*
 * Decompiled with CFR 0.152.
 */
package bartMachine;

import OpenSourceExtensions.UnorderedPair;
import bartMachine.Classifier;
import bartMachine.StatToolbox;
import bartMachine.Tools;
import bartMachine.bartMachineClassification;
import bartMachine.bartMachineRegression;
import bartMachine.bartMachineTreeNode;
import bartMachine.bartMachine_b_hyperparams;
import it.unimi.dsi.fastutil.doubles.DoubleArrayList;
import it.unimi.dsi.fastutil.doubles.DoubleList;
import it.unimi.dsi.fastutil.ints.IntOpenHashSet;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.HashSet;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import jdk.incubator.vector.DoubleVector;
import jdk.incubator.vector.Vector;
import jdk.incubator.vector.VectorSpecies;

/**
 * BART regression model that runs {@code num_cores} independent {@link bartMachineRegression}
 * Gibbs chains concurrently and pools their post-burn-in samples.
 *
 * <p>Every chain runs the full burn-in followed by roughly {@code 1 / num_cores} of the requested
 * post-burn-in iterations (see {@link #setNumGibbsTotalIterations(int)}). The post-burn-in tree
 * samples of all chains are then concatenated in chain order by
 * {@link #ConstructBurnedChainForTreesAndOtherInformation()}.
 */
public class bartMachineRegressionMultThread
extends Classifier
implements Serializable {
    /** Arrays at least this long are sorted with {@link Arrays#parallelSort(double[])}. */
    private static final int PARALLEL_SORT_THRESHOLD = 16384;
    /** Per-thread scratch buffer reused by the single-threaded prediction paths. */
    private static final ThreadLocal<double[]> EVAL_BUFFER = ThreadLocal.withInitial(() -> new double[0]);
    /** Number of independent Gibbs chains that are built concurrently. */
    protected int num_cores = 1;
    protected int num_trees = 50;
    /** One chain per core; populated by {@link #SetupBARTModels()}. */
    protected bartMachineRegression[] bart_gibbs_chain_threads;
    /** Pooled post-burn-in tree samples of all chains, indexed {@code [sample][tree]}. */
    protected bartMachineTreeNode[][] gibbs_samples_of_bart_trees_after_burn_in;
    private Double sample_var_y;
    protected int num_gibbs_burn_in = 250;
    /** Total Gibbs iterations requested across all chains (burn-in included). */
    protected int num_gibbs_total_iterations = 1250;
    /** Iterations run by each individual chain (burn-in included). */
    protected int total_iterations_multithreaded;
    protected double[] cov_split_prior;
    protected Double alpha = 0.95;
    protected Double beta = 2.0;
    protected Double hyper_k = 2.0;
    protected Double hyper_q = 0.9;
    protected Double hyper_nu = 3.0;
    protected Double prob_grow = 0.2777777777777778;
    protected Double prob_prune = 0.2777777777777778;
    protected boolean verbose = true;
    protected boolean mem_cache_for_speed = true;
    protected boolean flush_indices_to_save_ram = true;
    private boolean tree_illust;
    private HashMap<Integer, IntOpenHashSet> interaction_constraints;
    protected Integer seed;
    protected boolean use_xoshiro;

    /** Creates a model with the default settings and derives the per-chain iteration count. */
    public bartMachineRegressionMultThread() {
        this.setNumGibbsTotalIterations(this.num_gibbs_total_iterations);
    }

    /**
     * Sets the total number of Gibbs iterations (burn-in included) requested across all chains.
     *
     * <p>Each chain runs the full burn-in plus {@code ceil((n - burn_in) / num_cores)} further
     * iterations, so at least {@code n - burn_in} post-burn-in samples are produced overall.
     *
     * @param n total number of Gibbs iterations, burn-in included
     */
    public void setNumGibbsTotalIterations(int n) {
        this.num_gibbs_total_iterations = n;
        this.recomputeIterationsPerChain();
    }

    /**
     * Recomputes {@link #total_iterations_multithreaded} from the current total iterations,
     * burn-in and core count. Every setter that feeds the formula calls this, so the per-chain
     * count no longer depends on the order in which the setters are invoked.
     */
    private void recomputeIterationsPerChain() {
        // A zero/negative core count would otherwise divide by zero.
        int cores = Math.max(1, this.num_cores);
        this.total_iterations_multithreaded = this.num_gibbs_burn_in
                + (int) Math.ceil((double) (this.num_gibbs_total_iterations - this.num_gibbs_burn_in) / (double) cores);
    }

    /** @return the number of pooled post-burn-in samples ({@code total - burn_in}) */
    public int numSamplesAfterBurning() {
        return this.num_gibbs_total_iterations - this.num_gibbs_burn_in;
    }

    /** Creates and configures one {@link bartMachineRegression} chain per core. */
    protected void SetupBARTModels() {
        this.bart_gibbs_chain_threads = new bartMachineRegression[this.num_cores];
        for (int chain = 0; chain < this.num_cores; ++chain) {
            this.SetupBartModel(new bartMachineRegression(), chain);
        }
    }

    /**
     * Copies this object's hyperparameters and data into one chain and stores it in slot {@code n}.
     *
     * @param bartMachineRegression2 the chain to configure
     * @param n                      zero-based chain index; the seed for the chain is {@code seed + n}
     */
    protected void SetupBartModel(bartMachineRegression bartMachineRegression2, int n) {
        bartMachineRegression2.setVerbose(this.verbose);
        bartMachineRegression2.num_trees = this.num_trees;
        bartMachineRegression2.num_gibbs_total_iterations = this.total_iterations_multithreaded;
        bartMachineRegression2.num_gibbs_burn_in = this.num_gibbs_burn_in;
        bartMachineRegression2.sample_var_y = this.sample_var_y;
        bartMachineRegression2.setAlpha(this.alpha);
        bartMachineRegression2.setBeta(this.beta);
        bartMachineRegression2.setK(this.hyper_k);
        bartMachineRegression2.setProbGrow(this.prob_grow);
        bartMachineRegression2.setProbPrune(this.prob_prune);
        bartMachineRegression2.setThreadNum(n);
        if (this.seed != null) {
            // Distinct but reproducible seed per chain.
            bartMachineRegression2.setSeed(this.seed + n);
        }
        bartMachineRegression2.setTotalNumThreads(this.num_cores);
        bartMachineRegression2.setMemCacheForSpeed(this.mem_cache_for_speed);
        bartMachineRegression2.setFlushIndicesToSaveRAM(this.flush_indices_to_save_ram);
        bartMachineRegression2.setUseXoshiro(this.use_xoshiro);
        if (this.cov_split_prior != null) {
            bartMachineRegression2.setCovSplitPrior(this.cov_split_prior);
        }
        if (this.interaction_constraints != null) {
            bartMachineRegression2.setInteractionConstraints(this.interaction_constraints);
        }
        if (!(bartMachineRegression2 instanceof bartMachineClassification)) {
            bartMachineRegression2.setNu(this.hyper_nu);
            bartMachineRegression2.setQ(this.hyper_q);
        }
        bartMachineRegression2.setData(this.X_y);
        bartMachineRegression2.tree_illust = this.tree_illust;
        this.bart_gibbs_chain_threads[n] = bartMachineRegression2;
    }

    /**
     * Installs the pool of standard-normal draws on the static fields of
     * {@code bartMachine_b_hyperparams}; this affects every model in the JVM.
     */
    public void setNormSamples(double[] dArray) {
        bartMachine_b_hyperparams.samps_std_normal = dArray;
        bartMachine_b_hyperparams.samps_std_normal_length = dArray.length;
    }

    /**
     * Installs the pool of chi-square draws (per the field name, presumably df = nu + n) on the
     * static fields of {@code bartMachine_b_hyperparams}; this affects every model in the JVM.
     */
    public void setGammaSamples(double[] dArray) {
        bartMachine_b_hyperparams.samps_chi_sq_df_eq_nu_plus_n = dArray;
        bartMachine_b_hyperparams.samps_chi_sq_df_eq_nu_plus_n_length = dArray.length;
    }

    /** Sets up the chains, builds them concurrently, and pools their post-burn-in samples. */
    @Override
    public void Build() {
        this.SetupBARTModels();
        long startMillis = System.currentTimeMillis();
        if (this.verbose) {
            System.out.println("building BART " + (this.mem_cache_for_speed ? "with" : "without") + " mem-cache speedup...");
        }
        this.BuildOnAllThreads();
        long endMillis = System.currentTimeMillis();
        if (this.verbose) {
            System.out.println("done building BART in " + (double) (endMillis - startMillis) / 1000.0 + " sec \n");
        }
        this.ConstructBurnedChainForTreesAndOtherInformation();
    }

    /**
     * Concatenates the post-burn-in tree samples of all chains, in chain order, into
     * {@link #gibbs_samples_of_bart_trees_after_burn_in}.
     *
     * <p>Chain {@code t}'s iteration {@code j} (for {@code burn_in <= j < total_iterations_multithreaded})
     * lands at pooled index {@code t * (total_iterations_multithreaded - burn_in) + (j - burn_in)}.
     * Surplus samples beyond {@link #numSamplesAfterBurning()} (from rounding up per chain) are dropped.
     */
    protected void ConstructBurnedChainForTreesAndOtherInformation() {
        final int pooledSamples = this.numSamplesAfterBurning();
        final int perChain = this.total_iterations_multithreaded - this.num_gibbs_burn_in;
        this.gibbs_samples_of_bart_trees_after_burn_in = new bartMachineTreeNode[pooledSamples][this.num_trees];
        if (this.verbose) {
            System.out.print("burning and aggregating chains from all threads... ");
        }
        for (int chain = 0; chain < this.num_cores; ++chain) {
            bartMachineRegression model = this.bart_gibbs_chain_threads[chain];
            int offset = chain * perChain;
            for (int iter = this.num_gibbs_burn_in; iter < this.total_iterations_multithreaded; ++iter) {
                int dest = offset + (iter - this.num_gibbs_burn_in);
                if (dest >= pooledSamples) {
                    break;
                }
                this.gibbs_samples_of_bart_trees_after_burn_in[dest] = model.gibbs_samples_of_bart_trees[iter];
            }
        }
        if (this.verbose) {
            System.out.print("done\n");
        }
    }

    /**
     * Builds every chain on its own virtual thread and waits for all of them.
     *
     * @throws RuntimeException if any chain's {@code Build()} fails or the wait is interrupted
     */
    private void BuildOnAllThreads() {
        ArrayList<Future<?>> builds = new ArrayList<>(this.num_cores);
        try (ExecutorService executor = Executors.newVirtualThreadPerTaskExecutor()) {
            for (int chain = 0; chain < this.num_cores; ++chain) {
                bartMachineRegression model = this.bart_gibbs_chain_threads[chain];
                // submit (not execute) so a failure inside Build() reaches the caller.
                builds.add(executor.submit(() -> model.Build()));
            }
            for (Future<?> build : builds) {
                awaitResult(build);
            }
        }
    }

    /** @return the training rows of {@code X_y} (as stored, i.e. including the trailing response). */
    private double[][] trainingRecords() {
        double[][] records = new double[this.n][];
        for (int i = 0; i < this.n; ++i) {
            records[i] = (double[]) this.X_y.get(i);
        }
        return records;
    }

    /**
     * For every record, sample and tree, marks which training rows share the record's leaf.
     *
     * @param dArray records to evaluate; {@code null} means the training rows
     * @return indicator array indexed {@code [record][sample][tree][training row]}
     */
    protected boolean[][][][] getNodePredictionTrainingIndicies(double[][] dArray) {
        double[][] records = dArray == null ? this.trainingRecords() : dArray;
        int numRecords = records.length;
        int numSamples = this.numSamplesAfterBurning();
        boolean[][][][] inLeaf = new boolean[numRecords][numSamples][this.num_trees][this.n];
        for (int r = 0; r < numRecords; ++r) {
            for (int g = 0; g < numSamples; ++g) {
                bartMachineTreeNode[] trees = this.gibbs_samples_of_bart_trees_after_burn_in[g];
                for (int m = 0; m < this.num_trees; ++m) {
                    for (int trainingRow : trees[m].EvaluateNode(records[r]).indicies) {
                        inLeaf[r][g][m][trainingRow] = true;
                    }
                }
            }
        }
        return inLeaf;
    }

    /**
     * Computes, for every record, the weight each training row receives: the average over
     * samples and trees of {@code 1 / (leaf size * num_trees)} for rows sharing the record's leaf.
     * Leaves that contain no training rows contribute nothing (instead of NaN from 0/0).
     *
     * @param dArray records to evaluate; {@code null} means the training rows
     * @return weights indexed {@code [record][training row]}
     */
    protected double[][] getProjectionWeights(double[][] dArray) {
        double[][] records = dArray == null ? this.trainingRecords() : dArray;
        int numRecords = records.length;
        int numSamples = this.numSamplesAfterBurning();
        boolean[][][][] inLeaf = this.getNodePredictionTrainingIndicies(records);
        double[][] weights = new double[numRecords][];
        for (int r = 0; r < numRecords; ++r) {
            double[] rowWeights = new double[this.n];
            for (int g = 0; g < numSamples; ++g) {
                for (int m = 0; m < this.num_trees; ++m) {
                    boolean[] leafMembers = inLeaf[r][g][m];
                    int leafSize = Tools.sum_array(leafMembers);
                    if (leafSize == 0) {
                        continue;
                    }
                    double share = 1.0 / ((double) leafSize * (double) this.num_trees);
                    for (int k = 0; k < this.n; ++k) {
                        if (leafMembers[k]) {
                            rowWeights[k] += share;
                        }
                    }
                }
            }
            for (int k = 0; k < this.n; ++k) {
                rowWeights[k] *= 1.0 / (double) numSamples;
            }
            weights[r] = rowWeights;
        }
        return weights;
    }

    /**
     * Computes the un-transformed prediction of every pooled post-burn-in sample.
     *
     * @param dArray records to predict, one row per record
     * @param n      number of worker tasks; values {@code <= 1} run on the calling thread
     * @return matrix indexed {@code [record][sample]}
     * @throws RuntimeException if a worker task fails or the wait is interrupted
     */
    protected double[][] getGibbsSamplesForPrediction(double[][] dArray, int n) {
        final int numSamples = this.numSamplesAfterBurning();
        final int numRecords = dArray.length;
        final bartMachineRegression transformer = this.bart_gibbs_chain_threads[0];
        final double[][] samples = new double[numRecords][numSamples];
        final int[] recordIndices = identityIndices(numRecords);
        if (n <= 1) {
            double[] buffer = getThreadLocalBuffer(numRecords);
            for (int g = 0; g < numSamples; ++g) {
                this.evaluateSampleInto(g, dArray, recordIndices, buffer, transformer);
                copyIntoColumn(buffer, samples, g);
            }
            return samples;
        }
        // Contiguous chunks of samples; each task writes only its own columns of `samples`.
        final int chunk = Math.max(1, numSamples / n);
        ArrayList<Future<?>> tasks = new ArrayList<>();
        try (ExecutorService executor = Executors.newVirtualThreadPerTaskExecutor()) {
            for (int start = 0; start < numSamples; start += chunk) {
                final int from = start;
                final int to = Math.min(numSamples, start + chunk);
                tasks.add(executor.submit(() -> {
                    double[] scratch = new double[numRecords];
                    for (int sample = from; sample < to; ++sample) {
                        // Fresh index copy per sample, since recordIndices is shared between
                        // tasks (NOTE: presumably evaluateBatch may reorder it).
                        this.evaluateSampleInto(sample, dArray, recordIndices.clone(), scratch, transformer);
                        copyIntoColumn(scratch, samples, sample);
                    }
                }));
            }
            for (Future<?> task : tasks) {
                awaitResult(task);
            }
        }
        return samples;
    }

    /**
     * Computes the posterior mean prediction (average over pooled samples) for each record.
     *
     * @param dArray records to predict, one row per record
     * @param n      number of worker tasks; values {@code <= 1} run on the calling thread
     * @return one mean per record; all zeros when there are no records or no samples
     * @throws RuntimeException if a worker task fails or the wait is interrupted
     */
    public double[] getPosteriorMeanForPrediction(double[][] dArray, int n) {
        final int numSamples = this.numSamplesAfterBurning();
        final int numRecords = dArray.length;
        final double[] means = new double[numRecords];
        if (numRecords == 0 || numSamples <= 0) {
            return means;
        }
        final bartMachineRegression transformer = this.bart_gibbs_chain_threads[0];
        final int[] recordIndices = identityIndices(numRecords);
        if (n <= 1) {
            double[] buffer = getThreadLocalBuffer(numRecords);
            for (int g = 0; g < numSamples; ++g) {
                this.evaluateSampleInto(g, dArray, recordIndices, buffer, transformer);
                addInPlace(means, buffer);
            }
        } else {
            final int chunk = Math.max(1, numSamples / n);
            ArrayList<Future<double[]>> partials = new ArrayList<>();
            try (ExecutorService executor = Executors.newVirtualThreadPerTaskExecutor()) {
                for (int start = 0; start < numSamples; start += chunk) {
                    final int from = start;
                    final int to = Math.min(numSamples, start + chunk);
                    partials.add(executor.submit(() -> {
                        double[] partialSum = new double[numRecords];
                        double[] scratch = new double[numRecords];
                        for (int sample = from; sample < to; ++sample) {
                            this.evaluateSampleInto(sample, dArray, recordIndices.clone(), scratch, transformer);
                            addInPlace(partialSum, scratch);
                        }
                        return partialSum;
                    }));
                }
                // Combine partial sums in submission order for deterministic rounding.
                for (Future<double[]> partial : partials) {
                    addInPlace(means, awaitResult(partial));
                }
            }
        }
        for (int r = 0; r < numRecords; ++r) {
            means[r] /= (double) numSamples;
        }
        return means;
    }

    /**
     * Writes the un-transformed sum-of-trees prediction of one pooled sample into {@code out}.
     * {@code out} is zeroed first; each tree's {@code evaluateBatch} presumably adds its leaf
     * values for the listed records; the result is then un-transformed in place.
     */
    private void evaluateSampleInto(int sampleIndex, double[][] records, int[] recordIndices,
                                    double[] out, bartMachineRegression transformer) {
        Arrays.fill(out, 0.0);
        bartMachineTreeNode[] trees = this.gibbs_samples_of_bart_trees_after_burn_in[sampleIndex];
        for (int m = 0; m < this.num_trees; ++m) {
            trees[m].evaluateBatch(records, recordIndices, out);
        }
        transformer.un_transform_y_batch(out, out);
    }

    /** @return {@code {0, 1, ..., count - 1}} */
    private static int[] identityIndices(int count) {
        int[] indices = new int[count];
        for (int i = 0; i < count; ++i) {
            indices[i] = i;
        }
        return indices;
    }

    /** Copies {@code values[r]} into {@code matrix[r][column]} for every row {@code r} of {@code matrix}. */
    private static void copyIntoColumn(double[] values, double[][] matrix, int column) {
        for (int r = 0; r < matrix.length; ++r) {
            matrix[r][column] = values[r];
        }
    }

    /**
     * Waits for a worker task and returns its result.
     *
     * @throws RuntimeException wrapping the {@link ExecutionException} if the task failed, or the
     *                          {@link InterruptedException} (with the interrupt flag restored)
     */
    private static <T> T awaitResult(Future<T> future) {
        try {
            return future.get();
        }
        catch (InterruptedException exception) {
            Thread.currentThread().interrupt();
            throw new RuntimeException(exception);
        }
        catch (ExecutionException exception) {
            throw new RuntimeException(exception);
        }
    }

    /**
     * Returns the equal-tailed credible interval of the posterior predictions for one record.
     *
     * @param dArray the record
     * @param d      coverage, e.g. 0.95
     * @param n      number of worker tasks
     * @return {@code {lower, upper}} type-7 quantiles of the posterior samples
     */
    protected double[] getPostPredictiveIntervalForPrediction(double[] dArray, double d, int n) {
        double[][] single = new double[][]{dArray};
        double[] samples = this.getGibbsSamplesForPrediction(single, n)[0];
        sortInPlace(samples);
        double lowerQ = (1.0 - d) / 2.0;
        double upperQ = 1.0 - lowerQ;
        return new double[]{quantileType7Sorted(samples, lowerQ), quantileType7Sorted(samples, upperQ)};
    }

    /** Shorthand for {@link #getPostPredictiveIntervalForPrediction(double[], double, int)} at 95%. */
    protected double[] get95PctPostPredictiveIntervalForPrediction(double[] dArray, int n) {
        return this.getPostPredictiveIntervalForPrediction(dArray, 0.95, n);
    }

    /**
     * Returns equal-tailed credible intervals of the posterior predictions for each record.
     *
     * @param dArray records, one row each
     * @param d      coverage, e.g. 0.95
     * @param n      number of worker tasks
     * @return rows of {@code {lower, upper}}
     */
    public double[][] getCredibleIntervalsForPrediction(double[][] dArray, double d, int n) {
        double[][] samples = this.getGibbsSamplesForPrediction(dArray, n);
        double[][] intervals = new double[samples.length][2];
        double lowerQ = (1.0 - d) / 2.0;
        double upperQ = 1.0 - lowerQ;
        for (int r = 0; r < samples.length; ++r) {
            // quantileType7 partially reorders its input; samples[r] is a private copy.
            intervals[r][0] = quantileType7(samples[r], lowerQ);
            intervals[r][1] = quantileType7(samples[r], upperQ);
        }
        return intervals;
    }

    /**
     * Returns Monte-Carlo prediction intervals. For every record, {@code n} draws are made by
     * picking a pooled posterior sample uniformly at random and drawing from
     * {@code StatToolbox.sample_from_norm_dist(prediction, sigsq)} with that same sample's sigsq.
     *
     * @param dArray records, one row each
     * @param d      coverage, e.g. 0.95
     * @param n      number of predictive draws per record
     * @param n2     number of worker tasks used to compute the posterior predictions
     * @return rows of {@code {lower, upper}}; NaN rows when no samples or {@code n <= 0}
     */
    public double[][] getPredictionIntervalsForPrediction(double[][] dArray, double d, int n, int n2) {
        double[][] posterior = this.getGibbsSamplesForPrediction(dArray, n2);
        double[][] intervals = new double[posterior.length][2];
        if (posterior.length == 0) {
            return intervals;
        }
        // Variances aligned index-for-index with the pooled tree samples.
        double[] sigsqs = this.getPostBurnInSigsqs();
        double lowerQ = (1.0 - d) / 2.0;
        double upperQ = 1.0 - lowerQ;
        int usable = Math.min(posterior[0].length, sigsqs.length);
        if (usable == 0 || n <= 0) {
            for (double[] interval : intervals) {
                interval[0] = Double.NaN;
                interval[1] = Double.NaN;
            }
            return intervals;
        }
        double[] draws = new double[n];
        for (int r = 0; r < posterior.length; ++r) {
            for (int j = 0; j < n; ++j) {
                int s = (int) Math.floor(StatToolbox.rand() * (double) usable);
                draws[j] = StatToolbox.sample_from_norm_dist(posterior[r][s], sigsqs[s]);
            }
            intervals[r][0] = quantileType7(draws, lowerQ);
            intervals[r][1] = quantileType7(draws, upperQ);
        }
        return intervals;
    }

    /**
     * Builds the post-burn-in sigsq vector aligned with
     * {@link #gibbs_samples_of_bart_trees_after_burn_in}: chain {@code t}'s iteration {@code j} is
     * placed at the same pooled index its trees were placed at by
     * {@link #ConstructBurnedChainForTreesAndOtherInformation()}.
     */
    private double[] getPostBurnInSigsqs() {
        int pooledSamples = Math.max(0, this.numSamplesAfterBurning());
        int perChain = this.total_iterations_multithreaded - this.num_gibbs_burn_in;
        double[] aligned = new double[pooledSamples];
        for (int chain = 0; chain < this.num_cores; ++chain) {
            double[] chainSigsqs = this.bart_gibbs_chain_threads[chain].getGibbsSamplesSigsqs();
            int offset = chain * perChain;
            for (int iter = this.num_gibbs_burn_in; iter < this.total_iterations_multithreaded; ++iter) {
                int dest = offset + (iter - this.num_gibbs_burn_in);
                if (dest >= pooledSamples) {
                    break;
                }
                aligned[dest] = chainSigsqs[iter];
            }
        }
        return aligned;
    }

    /** Type-7 (linear interpolation) quantile of an already ascending-sorted array; NaN if empty. */
    private static double quantileType7Sorted(double[] dArray, double d) {
        int len = dArray.length;
        if (len == 0) {
            return Double.NaN;
        }
        if (len == 1 || d <= 0.0) {
            return dArray[0];
        }
        if (d >= 1.0) {
            return dArray[len - 1];
        }
        double h = (double) (len - 1) * d + 1.0;
        int lo = (int) Math.floor(h);
        double below = dArray[lo - 1];
        if (h == (double) lo) {
            return below;
        }
        double above = dArray[lo];
        return below + (h - (double) lo) * (above - below);
    }

    /**
     * Type-7 quantile of an unsorted array via quickselect; NaN if empty.
     * The array is partially reordered in place.
     */
    private static double quantileType7(double[] dArray, double d) {
        int len = dArray.length;
        if (len == 0) {
            return Double.NaN;
        }
        if (len == 1) {
            return dArray[0];
        }
        if (d <= 0.0) {
            return min(dArray);
        }
        if (d >= 1.0) {
            return max(dArray);
        }
        double h = (double) (len - 1) * d + 1.0;
        int lo = (int) Math.floor(h);
        double below = select(dArray, lo - 1);
        if (h == (double) lo) {
            return below;
        }
        double above = select(dArray, lo);
        return below + (h - (double) lo) * (above - below);
    }

    /** Returns the element that would be at index {@code n} if the array were sorted (quickselect). */
    private static double select(double[] dArray, int n) {
        int left = 0;
        int right = dArray.length - 1;
        while (left != right) {
            int pivot = partition(dArray, left, right, left + (right - left >>> 1));
            if (n == pivot) {
                return dArray[n];
            }
            if (n < pivot) {
                right = pivot - 1;
            } else {
                left = pivot + 1;
            }
        }
        return dArray[left];
    }

    /** Lomuto partition of {@code [n, n2]} around the value at {@code n3}; returns its final index. */
    private static int partition(double[] dArray, int n, int n2, int n3) {
        double pivotValue = dArray[n3];
        swap(dArray, n3, n2);
        int store = n;
        for (int i = n; i < n2; ++i) {
            if (dArray[i] < pivotValue) {
                swap(dArray, store, i);
                ++store;
            }
        }
        swap(dArray, n2, store);
        return store;
    }

    private static void swap(double[] dArray, int n, int n2) {
        double tmp = dArray[n];
        dArray[n] = dArray[n2];
        dArray[n2] = tmp;
    }

    private static double min(double[] dArray) {
        double best = Double.POSITIVE_INFINITY;
        for (double value : dArray) {
            if (value < best) {
                best = value;
            }
        }
        return best;
    }

    private static double max(double[] dArray) {
        double best = Double.NEGATIVE_INFINITY;
        for (double value : dArray) {
            if (value > best) {
                best = value;
            }
        }
        return best;
    }

    /** Sorts ascending, using a parallel sort for arrays of at least {@link #PARALLEL_SORT_THRESHOLD}. */
    private static void sortInPlace(double[] dArray) {
        if (dArray.length >= PARALLEL_SORT_THRESHOLD) {
            Arrays.parallelSort(dArray);
        } else {
            Arrays.sort(dArray);
        }
    }

    /** @return this thread's scratch buffer, grown to at least {@code n} elements (may be longer) */
    private static double[] getThreadLocalBuffer(int n) {
        double[] buffer = EVAL_BUFFER.get();
        if (buffer.length < n) {
            buffer = new double[n];
            EVAL_BUFFER.set(buffer);
        }
        return buffer;
    }

    /** Adds {@code dArray2[i]} to {@code dArray[i]} for every index of {@code dArray} (SIMD where possible). */
    private static void addInPlace(double[] dArray, double[] dArray2) {
        VectorSpecies<Double> species = DoubleVector.SPECIES_PREFERRED;
        int upper = species.loopBound(dArray.length);
        int i = 0;
        for (; i < upper; i += species.length()) {
            DoubleVector.fromArray(species, dArray, i)
                    .add(DoubleVector.fromArray(species, dArray2, i))
                    .intoArray(dArray, i);
        }
        for (; i < dArray.length; ++i) {
            dArray[i] += dArray2[i];
        }
    }

    /**
     * Returns chain 0's full sigsq trace (burn-in included) followed by the
     * {@code [burn_in, total_iterations_multithreaded)} portion of every other chain.
     */
    public double[] getGibbsSamplesSigsqs() {
        DoubleArrayList combined = new DoubleArrayList(this.num_gibbs_total_iterations);
        for (int chain = 0; chain < this.num_cores; ++chain) {
            DoubleArrayList chainSigsqs = new DoubleArrayList(this.bart_gibbs_chain_threads[chain].getGibbsSamplesSigsqs());
            if (chain == 0) {
                combined.addAll((DoubleList) chainSigsqs);
            } else {
                combined.addAll(chainSigsqs.subList(this.num_gibbs_burn_in, this.total_iterations_multithreaded));
            }
        }
        return combined.toDoubleArray();
    }

    /** @return rows {@code 1..burn_in} of chain 0's Metropolis-Hastings accept/reject matrix */
    public boolean[][] getAcceptRejectMHsBurnin() {
        boolean[][] all = this.bart_gibbs_chain_threads[0].getAcceptRejectMH();
        boolean[][] burnIn = new boolean[this.num_gibbs_burn_in][];
        System.arraycopy(all, 1, burnIn, 0, this.num_gibbs_burn_in);
        return burnIn;
    }

    /**
     * @param n one-based chain number
     * @return rows {@code [burn_in, total_iterations_multithreaded)} of that chain's accept/reject matrix
     */
    public boolean[][] getAcceptRejectMHsAfterBurnIn(int n) {
        boolean[][] all = this.bart_gibbs_chain_threads[n - 1].getAcceptRejectMH();
        int count = this.total_iterations_multithreaded - this.num_gibbs_burn_in;
        boolean[][] afterBurnIn = new boolean[count][];
        System.arraycopy(all, this.num_gibbs_burn_in, afterBurnIn, 0, count);
        return afterBurnIn;
    }

    /**
     * Counts, per pooled sample, how often each attribute is used.
     *
     * @param string {@code "splits"} sums split counts over trees; {@code "trees"} adds the
     *               per-tree result via {@code Tools.binary_add_arrays}; any other value yields zeros
     * @return counts indexed {@code [sample][attribute]}
     */
    public int[][] getCountsForAllAttribute(String string) {
        int numSamples = this.numSamplesAfterBurning();
        int[][] counts = new int[numSamples][this.p];
        for (int g = 0; g < numSamples; ++g) {
            int[] row = new int[this.p];
            for (bartMachineTreeNode tree : this.gibbs_samples_of_bart_trees_after_burn_in[g]) {
                if (string.equals("splits")) {
                    row = Tools.add_arrays(row, tree.attributeSplitCounts());
                } else if (string.equals("trees")) {
                    row = Tools.binary_add_arrays(row, tree.attributeSplitCounts());
                }
            }
            counts[g] = row;
        }
        return counts;
    }

    /** @return attribute usage from {@link #getCountsForAllAttribute(String)}, summed over samples and normalized */
    public double[] getAttributeProps(String string) {
        int[][] counts = this.getCountsForAllAttribute(string);
        double[] props = new double[this.p];
        for (int g = 0; g < this.numSamplesAfterBurning(); ++g) {
            props = Tools.add_arrays(props, counts[g]);
        }
        Tools.normalize_array(props);
        return props;
    }

    /** @return {@code [first][second]} counts of attribute pairs reported by {@code findInteractions} over all pooled trees */
    public int[][] getInteractionCounts() {
        int[][] counts = new int[this.p][this.p];
        for (bartMachineTreeNode[] trees : this.gibbs_samples_of_bart_trees_after_burn_in) {
            for (bartMachineTreeNode tree : trees) {
                // Default capacity: pre-sizing to p*p per tree caused large, pointless allocations.
                HashSet<UnorderedPair<Integer>> pairs = new HashSet<>();
                tree.findInteractions(pairs);
                for (UnorderedPair<Integer> pair : pairs) {
                    counts[pair.getFirst()][pair.getSecond()]++;
                }
            }
        }
        return counts;
    }

    /** Flushes the data held by every chain. */
    @Override
    protected void FlushData() {
        for (int chain = 0; chain < this.num_cores; ++chain) {
            this.bart_gibbs_chain_threads[chain].FlushData();
        }
    }

    /** Posterior-mean prediction for one record, single-threaded. */
    @Override
    public double Evaluate(double[] dArray) {
        return this.EvaluateViaSampAvg(dArray, 1);
    }

    /** Posterior-mean prediction for one record using {@code n} worker tasks. */
    @Override
    public double Evaluate(double[] dArray, int n) {
        return this.EvaluateViaSampAvg(dArray, n);
    }

    /** @return the average of the posterior samples for one record */
    public double EvaluateViaSampAvg(double[] dArray, int n) {
        double[][] samples = this.getGibbsSamplesForPrediction(new double[][]{dArray}, n);
        return StatToolbox.sample_average(samples[0]);
    }

    /** @return the type-7 median of the posterior samples for one record (previously returned the mean) */
    public double EvaluateViaSampMed(double[] dArray, int n) {
        double[][] samples = this.getGibbsSamplesForPrediction(new double[][]{dArray}, n);
        return quantileType7(samples[0], 0.5);
    }

    /** @return the average of the posterior samples for each record */
    @Override
    public double[] Evaluate(double[][] dArray, int n) {
        double[][] samples = this.getGibbsSamplesForPrediction(dArray, n);
        double[] predictions = new double[dArray.length];
        for (int r = 0; r < dArray.length; ++r) {
            predictions[r] = StatToolbox.sample_average(samples[r]);
        }
        return predictions;
    }

    /** @param n one-based chain number */
    public int[][] getDepthsForTreesInGibbsSampAfterBurnIn(int n) {
        return this.bart_gibbs_chain_threads[n - 1].getDepthsForTrees(this.num_gibbs_burn_in, this.total_iterations_multithreaded);
    }

    /** @param n one-based chain number */
    public int[][] getNumNodesAndLeavesForTreesInGibbsSampAfterBurnIn(int n) {
        return this.bart_gibbs_chain_threads[n - 1].getNumNodesAndLeavesForTrees(this.num_gibbs_burn_in, this.total_iterations_multithreaded);
    }

    /**
     * Stores the training rows; the last column of each row is treated as the response, so
     * {@code p} is the row length minus one. The list must be non-empty.
     */
    @Override
    public void setData(ArrayList<double[]> arrayList) {
        this.X_y = arrayList;
        this.n = arrayList.size();
        this.p = arrayList.get(0).length - 1;
    }

    public void printTreeIllustations() {
        this.tree_illust = true;
    }

    public void setCovSplitPrior(double[] dArray) {
        this.cov_split_prior = dArray;
    }

    /** Creates an empty interaction-constraint map with initial capacity {@code n}. */
    public void intializeInteractionConstraints(int n) {
        this.interaction_constraints = new HashMap<>(n);
    }

    /**
     * Adds the attributes in {@code nArray} to the set of attributes allowed to interact with
     * attribute {@code n}. Initializes the constraint map if that has not been done yet.
     */
    public void addInteractionConstraint(int n, int[] nArray) {
        if (this.interaction_constraints == null) {
            this.interaction_constraints = new HashMap<>();
        }
        IntOpenHashSet allowed = this.interaction_constraints.computeIfAbsent(n, key -> new IntOpenHashSet());
        for (int attribute : nArray) {
            allowed.add(attribute);
        }
    }

    /** Sets the burn-in length and refreshes the per-chain iteration count. */
    public void setNumGibbsBurnIn(int n) {
        this.num_gibbs_burn_in = n;
        this.recomputeIterationsPerChain();
    }

    public void setNumTrees(int n) {
        this.num_trees = n;
    }

    public void setSampleVarY(double d) {
        this.sample_var_y = d;
    }

    public void setAlpha(double d) {
        this.alpha = d;
    }

    public void setBeta(double d) {
        this.beta = d;
    }

    public void setK(double d) {
        this.hyper_k = d;
    }

    public void setQ(double d) {
        this.hyper_q = d;
    }

    public void setNU(double d) {
        this.hyper_nu = d;
    }

    public void setProbGrow(double d) {
        this.prob_grow = d;
    }

    public void setProbPrune(double d) {
        this.prob_prune = d;
    }

    public void setVerbose(boolean bl) {
        this.verbose = bl;
    }

    /** Stores the base seed (chain {@code i} uses {@code seed + i}) and also seeds {@code StatToolbox}. */
    public void setSeed(int n) {
        this.seed = n;
        StatToolbox.setSeed(n);
    }

    /** Sets the number of chains and refreshes the per-chain iteration count. */
    public void setNumCores(int n) {
        this.num_cores = n;
        this.recomputeIterationsPerChain();
    }

    public void setMemCacheForSpeed(boolean bl) {
        this.mem_cache_for_speed = bl;
    }

    public void setFlushIndicesToSaveRAM(boolean bl) {
        this.flush_indices_to_save_ram = bl;
    }

    public void setUseXoshiro(boolean bl) {
        this.use_xoshiro = bl;
    }

    /** No-op: building cannot be interrupted. */
    @Override
    public void StopBuilding() {
    }

    /** @return a copy of the first {@code num_trees} tree roots of pooled sample {@code n} */
    public bartMachineTreeNode[] extractRawNodeInformation(int n) {
        bartMachineTreeNode[] roots = new bartMachineTreeNode[this.num_trees];
        System.arraycopy(this.gibbs_samples_of_bart_trees_after_burn_in[n], 0, roots, 0, this.num_trees);
        return roots;
    }
}

