/*
 * Decompiled with CFR 0.152.
 */
package weka.classifiers.bayes.net.search.global;

import java.util.Enumeration;
import java.util.Vector;
import weka.classifiers.bayes.BayesNet;
import weka.classifiers.bayes.net.ParentSet;
import weka.classifiers.bayes.net.search.SearchAlgorithm;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.Option;
import weka.core.SelectedTag;
import weka.core.Tag;
import weka.core.Utils;

public class GlobalScoreSearchAlgorithm
extends SearchAlgorithm {
    static final long serialVersionUID = 7341389867906199781L;
    BayesNet m_BayesNet;
    boolean m_bUseProb = true;
    int m_nNrOfFolds = 10;
    static final int LOOCV = 0;
    static final int KFOLDCV = 1;
    static final int CUMCV = 2;
    public static final Tag[] TAGS_CV_TYPE = new Tag[]{new Tag(0, "LOO-CV"), new Tag(1, "k-Fold-CV"), new Tag(2, "Cumulative-CV")};
    int m_nCVType = 0;

    public double calcScore(BayesNet bayesNet) throws Exception {
        switch (this.m_nCVType) {
            case 0: {
                return this.leaveOneOutCV(bayesNet);
            }
            case 2: {
                return this.cumulativeCV(bayesNet);
            }
            case 1: {
                return this.kFoldCV(bayesNet, this.m_nNrOfFolds);
            }
        }
        throw new Exception("Unrecognized cross validation type encountered: " + this.m_nCVType);
    }

    public double calcScoreWithExtraParent(int n, int n2) throws Exception {
        ParentSet parentSet = this.m_BayesNet.getParentSet(n);
        Instances instances = this.m_BayesNet.m_Instances;
        for (int i = 0; i < parentSet.getNrOfParents(); ++i) {
            if (parentSet.getParent(i) != n2) continue;
            return -1.0E100;
        }
        parentSet.addParent(n2, instances);
        double d = this.calcScore(this.m_BayesNet);
        parentSet.deleteLastParent(instances);
        return d;
    }

    public double calcScoreWithMissingParent(int n, int n2) throws Exception {
        ParentSet parentSet = this.m_BayesNet.getParentSet(n);
        Instances instances = this.m_BayesNet.m_Instances;
        if (!parentSet.contains(n2)) {
            return -1.0E100;
        }
        int n3 = parentSet.deleteParent(n2, instances);
        double d = this.calcScore(this.m_BayesNet);
        parentSet.addParent(n2, n3, instances);
        return d;
    }

    public double calcScoreWithReversedParent(int n, int n2) throws Exception {
        ParentSet parentSet = this.m_BayesNet.getParentSet(n);
        ParentSet parentSet2 = this.m_BayesNet.getParentSet(n2);
        Instances instances = this.m_BayesNet.m_Instances;
        if (!parentSet.contains(n2)) {
            return -1.0E100;
        }
        int n3 = parentSet.deleteParent(n2, instances);
        parentSet2.addParent(n, instances);
        double d = this.calcScore(this.m_BayesNet);
        parentSet2.deleteLastParent(instances);
        parentSet.addParent(n2, n3, instances);
        return d;
    }

    public double leaveOneOutCV(BayesNet bayesNet) throws Exception {
        this.m_BayesNet = bayesNet;
        double d = 0.0;
        double d2 = 0.0;
        Instances instances = bayesNet.m_Instances;
        bayesNet.estimateCPTs();
        for (int i = 0; i < instances.numInstances(); ++i) {
            Instance instance = instances.instance(i);
            instance.setWeight(-instance.weight());
            bayesNet.updateClassifier(instance);
            d += this.accuracyIncrease(instance);
            d2 += instance.weight();
            instance.setWeight(-instance.weight());
            bayesNet.updateClassifier(instance);
        }
        return d / d2;
    }

    public double cumulativeCV(BayesNet bayesNet) throws Exception {
        this.m_BayesNet = bayesNet;
        double d = 0.0;
        double d2 = 0.0;
        Instances instances = bayesNet.m_Instances;
        bayesNet.initCPTs();
        for (int i = 0; i < instances.numInstances(); ++i) {
            Instance instance = instances.instance(i);
            d += this.accuracyIncrease(instance);
            bayesNet.updateClassifier(instance);
            d2 += instance.weight();
        }
        return d / d2;
    }

    public double kFoldCV(BayesNet bayesNet, int n) throws Exception {
        this.m_BayesNet = bayesNet;
        double d = 0.0;
        double d2 = 0.0;
        Instances instances = bayesNet.m_Instances;
        bayesNet.estimateCPTs();
        int n2 = 0;
        int n3 = instances.numInstances() / n;
        int n4 = 1;
        while (n2 < instances.numInstances()) {
            Instance instance;
            int n5;
            for (n5 = n2; n5 < n3; ++n5) {
                instance = instances.instance(n5);
                instance.setWeight(-instance.weight());
                bayesNet.updateClassifier(instance);
            }
            for (n5 = n2; n5 < n3; ++n5) {
                instance = instances.instance(n5);
                instance.setWeight(-instance.weight());
                d += this.accuracyIncrease(instance);
                instance.setWeight(-instance.weight());
                d2 += instance.weight();
            }
            for (n5 = n2; n5 < n3; ++n5) {
                instance = instances.instance(n5);
                instance.setWeight(-instance.weight());
                bayesNet.updateClassifier(instance);
            }
            n2 = n3;
            n3 = ++n4 * instances.numInstances() / n;
        }
        return d / d2;
    }

    double accuracyIncrease(Instance instance) throws Exception {
        if (this.m_bUseProb) {
            double[] dArray = this.m_BayesNet.distributionForInstance(instance);
            return dArray[(int)instance.classValue()] * instance.weight();
        }
        if (this.m_BayesNet.classifyInstance(instance) == instance.classValue()) {
            return instance.weight();
        }
        return 0.0;
    }

    public boolean getUseProb() {
        return this.m_bUseProb;
    }

    public void setUseProb(boolean bl) {
        this.m_bUseProb = bl;
    }

    public void setCVType(SelectedTag selectedTag) {
        if (selectedTag.getTags() == TAGS_CV_TYPE) {
            this.m_nCVType = selectedTag.getSelectedTag().getID();
        }
    }

    public SelectedTag getCVType() {
        return new SelectedTag(this.m_nCVType, TAGS_CV_TYPE);
    }

    public void setMarkovBlanketClassifier(boolean bl) {
        super.setMarkovBlanketClassifier(bl);
    }

    public boolean getMarkovBlanketClassifier() {
        return super.getMarkovBlanketClassifier();
    }

    public Enumeration listOptions() {
        Vector<Option> vector = new Vector<Option>();
        vector.addElement(new Option("\tApplies a Markov Blanket correction to the network structure, \n\tafter a network structure is learned. This ensures that all \n\tnodes in the network are part of the Markov blanket of the \n\tclassifier node.", "mbc", 0, "-mbc"));
        vector.addElement(new Option("\tScore type (LOO-CV,k-Fold-CV,Cumulative-CV)", "S", 1, "-S [LOO-CV|k-Fold-CV|Cumulative-CV]"));
        vector.addElement(new Option("\tUse probabilistic or 0/1 scoring.\n\t(default probabilistic scoring)", "Q", 0, "-Q"));
        Enumeration enumeration = super.listOptions();
        while (enumeration.hasMoreElements()) {
            vector.addElement((Option)enumeration.nextElement());
        }
        return vector.elements();
    }

    public void setOptions(String[] stringArray) throws Exception {
        this.setMarkovBlanketClassifier(Utils.getFlag("mbc", stringArray));
        String string = Utils.getOption('S', stringArray);
        if (string.compareTo("LOO-CV") == 0) {
            this.setCVType(new SelectedTag(0, TAGS_CV_TYPE));
        }
        if (string.compareTo("k-Fold-CV") == 0) {
            this.setCVType(new SelectedTag(1, TAGS_CV_TYPE));
        }
        if (string.compareTo("Cumulative-CV") == 0) {
            this.setCVType(new SelectedTag(2, TAGS_CV_TYPE));
        }
        this.setUseProb(!Utils.getFlag('Q', stringArray));
        super.setOptions(stringArray);
    }

    public String[] getOptions() {
        String[] stringArray = super.getOptions();
        String[] stringArray2 = new String[4 + stringArray.length];
        int n = 0;
        if (this.getMarkovBlanketClassifier()) {
            stringArray2[n++] = "-mbc";
        }
        stringArray2[n++] = "-S";
        switch (this.m_nCVType) {
            case 0: {
                stringArray2[n++] = "LOO-CV";
                break;
            }
            case 1: {
                stringArray2[n++] = "k-Fold-CV";
                break;
            }
            case 2: {
                stringArray2[n++] = "Cumulative-CV";
            }
        }
        if (this.getUseProb()) {
            stringArray2[n++] = "-Q";
        }
        for (int i = 0; i < stringArray.length; ++i) {
            stringArray2[n++] = stringArray[i];
        }
        while (n < stringArray2.length) {
            stringArray2[n++] = "";
        }
        return stringArray2;
    }

    public String CVTypeTipText() {
        return "Select cross validation strategy to be used in searching for networks.LOO-CV = Leave one out cross validation\nk-Fold-CV = k fold cross validation\nCumulative-CV = cumulative cross validation.";
    }

    public String useProbTipText() {
        return "If set to true, the probability of the class if returned in the estimate of the accuracy. If set to false, the accuracy estimate is only increased if the classifier returns exactly the correct class.";
    }

    public String globalInfo() {
        return "This Bayes Network learning algorithm uses cross validation to estimate classification accuracy.";
    }

    public String markovBlanketClassifierTipText() {
        return super.markovBlanketClassifierTipText();
    }
}

