/*
 * Decompiled with CFR 0.152.
 */
package cc.mallet.classify.tests;

import cc.mallet.classify.Classification;
import cc.mallet.classify.Classifier;
import cc.mallet.classify.NaiveBayes;
import cc.mallet.classify.NaiveBayesTrainer;
import cc.mallet.pipe.CharSequence2TokenSequence;
import cc.mallet.pipe.FeatureSequence2FeatureVector;
import cc.mallet.pipe.Input2CharSequence;
import cc.mallet.pipe.Noop;
import cc.mallet.pipe.Pipe;
import cc.mallet.pipe.SerialPipes;
import cc.mallet.pipe.Target2Label;
import cc.mallet.pipe.TokenSequence2FeatureSequence;
import cc.mallet.pipe.TokenSequenceLowercase;
import cc.mallet.pipe.TokenSequenceRemoveStopwords;
import cc.mallet.pipe.iterator.ArrayIterator;
import cc.mallet.pipe.iterator.FileIterator;
import cc.mallet.types.Alphabet;
import cc.mallet.types.FeatureVector;
import cc.mallet.types.Instance;
import cc.mallet.types.InstanceList;
import cc.mallet.types.LabelAlphabet;
import cc.mallet.types.LabelVector;
import cc.mallet.types.Multinomial;
import cc.mallet.util.Randoms;
import java.io.File;
import junit.framework.Test;
import junit.framework.TestCase;
import junit.framework.TestSuite;
import junit.textui.TestRunner;

/**
 * JUnit 3 tests exercising {@link NaiveBayes} and {@link NaiveBayesTrainer}.
 *
 * <p>The tests cover a hand-built (untrained) classifier, a classifier trained
 * on in-memory strings, a classifier trained on a randomly generated
 * {@link InstanceList}, and several incremental-training scenarios that read
 * documents from the {@code NaiveBayesData/learn} directories. The directory
 * paths are relative, so the file-based tests depend on the working directory
 * the tests are launched from.
 */
public class TestNaiveBayes extends TestCase {
    /** First directory of training documents used by the incremental tests. */
    private static final String LEARN_DIR_A = "src/cc/mallet/classify/tests/NaiveBayesData/learn/a";

    /** Second directory of training documents; also used alone for the second training pass. */
    private static final String LEARN_DIR_B = "src/cc/mallet/classify/tests/NaiveBayesData/learn/b";

    /**
     * Creates a test case with the given name.
     *
     * @param name the name of the test method to run
     */
    public TestNaiveBayes(String name) {
        super(name);
    }

    /**
     * Builds a two-class classifier ("sports" vs. "politics") directly from
     * hand-filled estimators and a uniform prior, then checks that an instance
     * containing "speech" and "win" is labelled "politics" with a best value
     * above 0.6.
     */
    public void testNonTrained() {
        Alphabet fdict = new Alphabet();
        System.out.println("fdict.size=" + fdict.size());
        LabelAlphabet ldict = new LabelAlphabet();
        Multinomial.LaplaceEstimator me1 = new Multinomial.LaplaceEstimator(fdict);
        Multinomial.LaplaceEstimator me2 = new Multinomial.LaplaceEstimator(fdict);

        // Label order fixes the class indices: 0 = sports, 1 = politics.
        ldict.lookupIndex("sports");
        ldict.lookupIndex("politics");
        ldict.stopGrowth();
        System.out.println("ldict.size=" + ldict.size());
        Multinomial prior = new Multinomial(new double[]{0.5, 0.5}, ldict);

        // "win" is shared by both classes; the other words are class-specific.
        me1.increment("win", 5.0);
        me1.increment("puck", 5.0);
        me1.increment("team", 5.0);
        System.out.println("fdict.size=" + fdict.size());
        me2.increment("win", 5.0);
        me2.increment("speech", 5.0);
        me2.increment("vote", 5.0);
        Multinomial sports = me1.estimate();
        Multinomial politics = me2.estimate();

        Pipe pipe = new Noop(fdict, ldict);
        Classifier c = new NaiveBayes(pipe, prior, new Multinomial[]{sports, politics});

        FeatureVector features =
                new FeatureVector(fdict, new Object[]{"speech", "win"}, new double[]{1.0, 1.0});
        Instance inst = c.getInstancePipe().instanceFrom(
                new Instance(features, ldict.lookupLabel("politics"), null, null));
        System.out.println("inst.data = " + inst.getData());

        Classification cf = c.classify(inst);
        LabelVector l = (LabelVector) cf.getLabeling();
        System.out.println("l.getBestIndex=" + l.getBestIndex());

        // Identity comparison, exactly as the original == check did.
        assertSame("instance with 'speech' and 'win' should be labelled politics",
                ldict.lookupLabel("politics"), cf.getLabeling().getBestLabel());
        assertTrue("best label value should exceed 0.6",
                cf.getLabeling().getBestValue() > 0.6);
    }

    /**
     * Trains on a handful of in-memory sentences labelled "africa" or "asia"
     * and checks that an unseen sentence sharing words with the africa set is
     * labelled "africa".
     */
    public void testStringTrained() {
        String[] africaTraining = {
            "on the plains of africa the lions roar",
            "in swahili ngoma means to dance",
            "nelson mandela became president of south africa",
            "the saraha dessert is expanding"
        };
        String[] asiaTraining = {
            "panda bears eat bamboo",
            "china's one child policy has resulted in a surplus of boys",
            "tigers live in the jungle"
        };

        Pipe pipe = new SerialPipes(new Pipe[]{
            new Target2Label(),
            new CharSequence2TokenSequence(),
            new TokenSequence2FeatureSequence(),
            new FeatureSequence2FeatureVector()
        });
        InstanceList instances = new InstanceList(pipe);
        instances.addThruPipe(new ArrayIterator(africaTraining, "africa"));
        instances.addThruPipe(new ArrayIterator(asiaTraining, "asia"));

        NaiveBayes c = new NaiveBayesTrainer().train(instances);
        Classification cf = c.classify("nelson mandela never eats lions");

        LabelAlphabet targets = (LabelAlphabet) instances.getTargetAlphabet();
        assertSame("sentence mentioning mandela and lions should be labelled africa",
                targets.lookupLabel("africa"), cf.getLabeling().getBestLabel());
    }

    /**
     * Smoke test: trains on a randomly generated instance list (seed 1) and
     * prints each classification plus the accuracy on the training set. No
     * accuracy threshold is asserted.
     */
    public void testRandomTrained() {
        InstanceList ilist = new InstanceList(new Randoms(1), 10, 2);
        Classifier c = new NaiveBayesTrainer().train(ilist);
        int numCorrect = 0;
        for (int i = 0; i < ilist.size(); ++i) {
            Instance inst = (Instance) ilist.get(i);
            Classification cf = c.classify(inst);
            cf.print();
            if (cf.getLabeling().getBestLabel() == inst.getLabeling().getBestLabel()) {
                ++numCorrect;
            }
        }
        System.out.println("Accuracy on training set = " + (double) numCorrect / (double) ilist.size());
    }

    /**
     * Trains incrementally on directories a and b, then trains the same
     * trainer again on directory b alone, printing alphabet sizes along the
     * way. The second list is built through the same pipe as the first.
     */
    public void testIncrementallyTrainedGrowingAlphabets() {
        System.out.println("testIncrementallyTrainedGrowingAlphabets");
        Pipe instPipe = newDocumentPipe();
        InstanceList instList = loadInitialInstances(instPipe);

        System.out.println("Training 1");
        NaiveBayesTrainer trainer = new NaiveBayesTrainer();
        trainer.trainIncremental(instList);
        printAlphabetSizes(instList);

        String[] t2directories = {LEARN_DIR_B};
        InstanceList instList2 = loadInstances(
                instPipe, new FileIterator(t2directories, FileIterator.STARTING_DIRECTORIES));

        System.out.println("Training 2");
        printAlphabetSizes(instList2);
        assertNotNull("second incremental training pass should return a classifier",
                trainer.trainIncremental(instList2));
    }

    /**
     * Trains incrementally on directories a and b, prints two sample
     * classifications and the classifier's alphabets, then trains the same
     * trainer again on directory b alone.
     */
    public void testIncrementallyTrained() {
        System.out.println("testIncrementallyTrained");
        Pipe instPipe = newDocumentPipe();
        InstanceList instList = loadInitialInstances(instPipe);

        System.out.println("Training 1");
        NaiveBayesTrainer trainer = new NaiveBayesTrainer();
        NaiveBayes classifier = trainer.trainIncremental(instList);
        printInitialClassifications(classifier);
        printAlphabetSizes(instList);

        String[] t2directories = {LEARN_DIR_B};
        InstanceList instList2 = loadInstances(
                instPipe, new FileIterator(t2directories, FileIterator.STARTING_DIRECTORIES));

        System.out.println("Training 2");
        printAlphabetSizes(instList2);
        assertNotNull("second incremental training pass should return a classifier",
                trainer.trainIncremental(instList2));
    }

    /**
     * Same scenario as {@link #testIncrementallyTrained()}, except that the
     * second-pass {@link FileIterator} is built with a third argument of
     * {@code true}, and a further classification is printed after the second
     * training pass. Judging by the method name this guards against a
     * regression involving empty strings -- TODO confirm the exact bug.
     */
    public void testEmptyStringBug() {
        System.out.println("testEmptyStringBug");
        Pipe instPipe = newDocumentPipe();
        InstanceList instList = loadInitialInstances(instPipe);

        System.out.println("Training 1");
        NaiveBayesTrainer trainer = new NaiveBayesTrainer();
        NaiveBayes classifier = trainer.trainIncremental(instList);
        printInitialClassifications(classifier);
        printAlphabetSizes(instList);

        String[] t2directories = {LEARN_DIR_B};
        // Third argument is presumably FileIterator's "remove common prefix" flag -- verify.
        InstanceList instList2 = loadInstances(
                instPipe, new FileIterator(t2directories, FileIterator.STARTING_DIRECTORIES, true));

        System.out.println("Training 2");
        printAlphabetSizes(instList2);
        assertNotNull("second incremental training pass should return a classifier",
                trainer.trainIncremental(instList2));

        // NOTE(review): this classifies with the classifier returned by the FIRST
        // training pass, not the second; confirm that this is intentional.
        Classification secondClassification = classifier.classify("Goodbye now");
        secondClassification.print();
    }

    /**
     * Returns a suite containing every test method of this class.
     *
     * @return the suite built from {@code TestNaiveBayes.class}
     */
    public static Test suite() {
        return new TestSuite(TestNaiveBayes.class);
    }

    /** No per-test fixture is required. */
    @Override
    protected void setUp() {
    }

    /**
     * Runs the suite with the JUnit text runner.
     *
     * @param args ignored
     */
    public static void main(String[] args) {
        TestRunner.run(suite());
    }

    /** Builds the document-processing pipe shared by the file-based incremental tests. */
    private static SerialPipes newDocumentPipe() {
        return new SerialPipes(new Pipe[]{
            new Target2Label(),
            new Input2CharSequence(),
            new CharSequence2TokenSequence(),
            new TokenSequenceLowercase(),
            new TokenSequenceRemoveStopwords(),
            new TokenSequence2FeatureSequence(),
            new FeatureSequence2FeatureVector()
        });
    }

    /** Loads the documents under directories a and b through {@code instPipe}. */
    private static InstanceList loadInitialInstances(Pipe instPipe) {
        String[] paths = {LEARN_DIR_A, LEARN_DIR_B};
        File[] directories = new File[paths.length];
        for (int i = 0; i < paths.length; ++i) {
            directories[i] = new File(paths[i]);
        }
        return loadInstances(instPipe, new FileIterator(directories, FileIterator.STARTING_DIRECTORIES));
    }

    /** Creates an instance list over {@code instPipe} and feeds it everything {@code source} yields. */
    private static InstanceList loadInstances(Pipe instPipe, FileIterator source) {
        InstanceList instances = new InstanceList(instPipe);
        instances.addThruPipe(source);
        return instances;
    }

    /** Prints the data and target alphabet sizes of {@code instances}. */
    private static void printAlphabetSizes(InstanceList instances) {
        System.out.println("data alphabet size " + instances.getDataAlphabet().size());
        System.out.println("target alphabet size " + instances.getTargetAlphabet().size());
    }

    /** Classifies two fixed sample strings and prints the results plus the classifier's alphabets. */
    private static void printInitialClassifications(NaiveBayes classifier) {
        Classification initialClassification = classifier.classify("Hello Everybody");
        Classification initial2Classification = classifier.classify("Goodbye now");
        System.out.println("Initial Classification = ");
        initialClassification.print();
        initial2Classification.print();
        System.out.println("data alphabet " + classifier.getAlphabet());
        System.out.println("label alphabet " + classifier.getLabelAlphabet());
    }
}

