/*
 * Decompiled with CFR 0.152.
 */
package bartMachine;

import OpenSourceExtensions.StatUtil;
import bartMachine.bartMachineClassification;
import bartMachine.bartMachineRegressionMultThread;
import java.io.Serializable;

/**
 * Multi-threaded BART binary classifier.
 *
 * <p>Reuses the multi-threaded regression machinery from {@link bartMachineRegressionMultThread},
 * but builds one {@link bartMachineClassification} Gibbs chain per core. It maps every posterior
 * sample through {@link StatUtil#normal_cdf(double)} before it is used for prediction. A point
 * prediction is 1.0 when the sample-averaged value strictly exceeds {@code classification_rule},
 * and 0.0 otherwise.
 */
public class bartMachineClassificationMultThread
extends bartMachineRegressionMultThread
implements Serializable {
    /**
     * Default decision threshold applied by {@link #Evaluate(double[], int)}.
     *
     * <p>Intentionally not declared {@code final}. A {@code static final double} with a literal
     * initializer becomes a compile-time constant, which removes this class's static initializer.
     * That would change the default {@code serialVersionUID} computed for this {@code Serializable}
     * class and could break deserialization of previously saved models.
     */
    private static double DEFAULT_CLASSIFICATION_RULE = 0.5;

    /**
     * Threshold compared against the sample-averaged prediction in {@link #Evaluate(double[], int)}.
     * Initialized to the default so the field is meaningful even before
     * {@link #SetupBARTModels()} runs (which resets it to the default again).
     */
    private double classification_rule = DEFAULT_CLASSIFICATION_RULE;

    /**
     * Allocates one {@link bartMachineClassification} chain per core and registers each one via
     * the inherited {@code SetupBartModel(model, index)}. It then resets the classification rule to
     * its default.
     *
     * <p>Note: any threshold set through {@link #setClassificationRule(double)} before this method
     * runs is overwritten here.
     */
    @Override
    protected void SetupBARTModels() {
        this.bart_gibbs_chain_threads = new bartMachineClassification[this.num_cores];
        for (int i = 0; i < this.num_cores; ++i) {
            this.SetupBartModel(new bartMachineClassification(), i);
        }
        this.classification_rule = DEFAULT_CLASSIFICATION_RULE;
    }

    /**
     * Produces a hard 0/1 prediction for a single record.
     *
     * @param dArray the record to predict; it is passed unchanged to the inherited
     *     {@code EvaluateViaSampAvg} (presumably the covariate values of one observation — confirm
     *     against the superclass)
     * @param n passed unchanged to {@code EvaluateViaSampAvg} (presumably the number of cores to
     *     use — confirm against the superclass)
     * @return {@code 1.0} if the sample-averaged value is strictly greater than the current
     *     classification rule, otherwise {@code 0.0}
     */
    @Override
    public double Evaluate(double[] dArray, int n) {
        return this.EvaluateViaSampAvg(dArray, n) > this.classification_rule ? 1.0 : 0.0;
    }

    /**
     * Returns the superclass's posterior samples with {@link StatUtil#normal_cdf(double)} applied
     * element-wise. Presumably this maps latent probit-scale draws to probabilities — confirm
     * against {@code StatUtil}.
     *
     * <p>Each output row is sized from its own input row. This handles an empty sample matrix
     * (returns an empty matrix instead of failing on {@code [0]}) and rows of differing lengths.
     *
     * @param dArray passed unchanged to the superclass implementation
     * @param n passed unchanged to the superclass implementation
     * @return a new matrix with the same shape as the superclass result; the input is not modified
     */
    @Override
    protected double[][] getGibbsSamplesForPrediction(double[][] dArray, int n) {
        double[][] latentSamples = super.getGibbsSamplesForPrediction(dArray, n);
        double[][] transformed = new double[latentSamples.length][];
        for (int i = 0; i < latentSamples.length; ++i) {
            double[] row = latentSamples[i];
            double[] outRow = new double[row.length];
            for (int j = 0; j < row.length; ++j) {
                outRow[j] = StatUtil.normal_cdf(row[j]);
            }
            transformed[i] = outRow;
        }
        return transformed;
    }

    /**
     * Sets the threshold used by {@link #Evaluate(double[], int)}. A subsequent call to
     * {@link #SetupBARTModels()} resets it to the default.
     *
     * @param d the new threshold; predictions are 1.0 only when the averaged value is strictly
     *     greater than it
     */
    public void setClassificationRule(double d) {
        this.classification_rule = d;
    }
}

