cc.mallet.classify.Trial Java Examples

Example #1
Source file: from the baleen project (Apache License 2.0) · 5 votes
/**
 * Logs evaluation metrics from a MALLET {@code Trial} through {@code getMonitor()}: the overall
 * accuracy, then F1 and precision for every label in the classifier's label alphabet.
 *
 * <p>NOTE(review): this excerpt is truncated after the precision line; the rest of the loop body
 * and the closing braces are not shown here.
 *
 * @param classifier classifier whose label alphabet supplies the per-class labels
 * @param trial evaluation results queried for accuracy, F1 and precision
 */
private void logAccuracyMetrics(Classifier classifier, Trial trial) {
  getMonitor().info("Accuracy: {}", trial.getAccuracy());
  // The String[] cast assumes every entry in the label alphabet is a String.
  for (String label : (String[]) classifier.getLabelAlphabet().toArray(new String[0])) {
    getMonitor().info("F1 for class '{}': {}", label, trial.getF1(label));
    // NOTE(review): message spacing differs from the F1 line ("'{}' :" vs "'{}':").
    getMonitor().info("Precision for class '{}' : {}", label, trial.getPrecision(label));
Example #2
Source file: from the baleen project (Apache License 2.0) · 5 votes
/**
 * Builds a list of strings, one per column, presumably a row of a results table for a single
 * evaluation run (TODO confirm against the caller).
 *
 * <p>NOTE(review): this excerpt is truncated. The body of the per-label loop and its closing
 * brace are missing, so it cannot be seen which of {@code training}, {@code testing}, {@code e}
 * or {@code trial} contribute to the row.
 *
 * @param training training instances (meaning assumed from the name; not used in the visible lines)
 * @param testing testing instances (meaning assumed from the name; not used in the visible lines)
 * @param e a string value whose meaning is not visible in this excerpt
 * @param classifier classifier whose label alphabet is iterated
 * @param trial evaluation results (not used in the visible lines)
 * @return the assembled row
 */
private List<String> createRow(
    InstanceList training, InstanceList testing, String e, Classifier classifier, Trial trial) {
  List<String> row = new ArrayList<>();
  // The String[] cast assumes every entry in the label alphabet is a String.
  for (String label : (String[]) classifier.getLabelAlphabet().toArray(new String[0])) {
  return row;
Example #3
Source file: from the bluima project (Apache License 2.0) · 5 votes
/**
 * Randomly splits {@code instances} into portions weighted 0.9 / 0.1 / 0.0. It then trains a
 * MaxEnt classifier on the training portion and evaluates it on the testing portion.
 *
 * <p>NOTE(review): the closing brace of this method is not shown in this excerpt.
 *
 * @param instances the full instance list to split
 * @return a {@code Trial} of the trained classifier over the testing portion
 */
public static Trial testTrainSplit(InstanceList instances) {

        // Split weights: 0.9 for training, 0.1 for testing, 0.0 for a third portion.
        // NOTE(review): no explicit seed is passed to Randoms, so the split presumably differs
        // between calls. Pass a seed if reproducible evaluation is needed.
        InstanceList[] instanceLists = instances.split(new Randoms(),
                new double[] { 0.9, 0.1, 0.0 });

        // LOG.debug("{} training instance, {} testing instances",
        // instanceLists[0].size(), instanceLists[1].size());

        ClassifierTrainer trainer = new MaxEntTrainer();
        // TRAINING / TESTING are index constants defined outside this excerpt. They are presumably
        // 0 and 1, matching the split weights above (confirm).
        Classifier classifier = trainer.train(instanceLists[TRAINING]);
        return new Trial(classifier, instanceLists[TESTING]);
Example #4
Source file: from the Machine-Learning-in-Java project (MIT License) · 4 votes
/**
 * Builds a MALLET text pipeline and trains a Naive Bayes classifier on files under
 * {@code data/ex6DataEmails/train}. It then prints accuracy, F1, precision and recall from a
 * {@code Trial} over files under {@code data/ex6DataEmails/test}.
 *
 * <p>NOTE(review): this excerpt is truncated in several places and does not compile as shown:
 * <ul>
 *   <li>both {@code FileIterator} constructor calls are unterminated;</li>
 *   <li>the statements that feed the iterators into {@code instances} / {@code testInstances}
 *       are not visible;</li>
 *   <li>the precision/recall {@code println} expressions are incomplete;</li>
 *   <li>the closing brace is missing.</li>
 * </ul>
 *
 * @param args unused in the visible lines
 */
public static void main(String[] args){
   	String stopListFilePath = "data/stoplists/en.txt";
   	String dataFolderPath = "data/ex6DataEmails/train";
   	String testFolderPath = "data/ex6DataEmails/test";
	// Text-processing pipes, applied in the order they are added.
	ArrayList<Pipe> pipeList = new ArrayList<Pipe>();
	pipeList.add(new Input2CharSequence("UTF-8"));
	// Tokens are maximal runs of Unicode letters, Unicode digits and underscores.
	Pattern tokenPattern = Pattern.compile("[\\p{L}\\p{N}_]+");
	pipeList.add(new CharSequence2TokenSequence(tokenPattern));
	pipeList.add(new TokenSequenceLowercase());
	// NOTE(review): the three boolean flags are assumed to be includeDefault, caseSensitive and
	// markDeletions, in that order (confirm against the MALLET TokenSequenceRemoveStopwords API).
	pipeList.add(new TokenSequenceRemoveStopwords(new File(stopListFilePath), "utf-8", false, false, false));
	pipeList.add(new TokenSequence2FeatureSequence());
	pipeList.add(new FeatureSequence2FeatureVector());
	pipeList.add(new Target2Label());
	SerialPipes pipeline = new SerialPipes(pipeList);
	FileIterator folderIterator = new FileIterator(
			new File[] {new File(dataFolderPath)},
	         new TxtFilter(),
	// NOTE(review): excerpt truncated. The remaining FileIterator argument(s) and the call that
	// adds folderIterator's files to `instances` are missing. As shown, `instances` would be
	// empty when passed to train().

	InstanceList instances = new InstanceList(pipeline);
	ClassifierTrainer classifierTrainer = new NaiveBayesTrainer();
	Classifier classifier = classifierTrainer.train(instances);

	// Test instances use the classifier's own instance pipe. Presumably this maps test features
	// through the same alphabets used during training.
	InstanceList testInstances = new InstanceList(classifier.getInstancePipe());
	folderIterator = new FileIterator(
			new File[] {new File(testFolderPath)},
	         new TxtFilter(),
	// NOTE(review): excerpt truncated. The remaining FileIterator argument(s) and the call that
	// fills `testInstances` are missing.
       Trial trial = new Trial(classifier, testInstances);
       System.out.println("Accuracy: " + trial.getAccuracy());
       System.out.println("F1 for class 'spam': " + trial.getF1("spam"));

       // NOTE(review): lookupLabel(1) assumes the label at alphabet index 1 is the class of
       // interest (presumably "spam"). That index presumably depends on the order in which labels
       // were first added to the alphabet.
       System.out.println("Precision for class '" +
                          classifier.getLabelAlphabet().lookupLabel(1) + "': " +
       // NOTE(review): excerpt truncated; the metric operand and the closing ");" are missing.

       System.out.println("Recall for class '" +
                          classifier.getLabelAlphabet().lookupLabel(1) + "': " +
       // NOTE(review): excerpt truncated; the metric operand and the closing ");" are missing.


Example #5
Source file: from the bluima project (Apache License 2.0) · 4 votes
public static void main(String[] args) {

        // pipe instances
        InstanceList instanceList = new InstanceList(
                new SerialPipes(getPipes()));
        FileIterator iterator = new FileIterator(new File[] { CORPUS },
                new TxtFilter(), LAST_DIRECTORY);

        // ////////////////////////////////////////////////////////////////
        // cross-validate
        double f1s = 0;
        for (int i = 0; i < trials; i++) {
            Trial trial = testTrainSplit(instanceList);
            System.out.println(join(new Object[] {//
                    i, trial.getPrecision(TESTING), trial.getRecall(TESTING),
                            trial.getF1(TESTING) }, "\t"));
            f1s += trial.getF1(TESTING);
        System.out.println("mean F1 = " + (f1s / (trials + 0d)));

        // ////////////////////////////////////////////////////////////////
        // train
        ClassifierTrainer trainer = new MaxEntTrainer();
        Classifier c = trainer.train(instanceList);

        String txt = "in the entorhinal cortex of the rat\n"
                + "II: phase relations between unit discharges and theta field potentials.\n"
                + "J. Comp. Neurol. 67, 502–509.\n"
                + "Alonso, A., and Klink, R. (1993).\n"
                + "Differential electroresponsiveness of\n"
                + "stellate and pyramidal-like cells of\n"
                + "medial entorhinal cortex layer II.\n"
                + "J. Neurophysiol. 70, 128–143.\n"
                + "Alonso, A., and Köhler, C. (1984).\n"
                + "A study of the reciprocal connections between the septum and the\n"
                + "entorhinal area using anterograde\n"
                + "and retrograde axonal transport\n"
                + "methods in the rat brain. J. Comp.\n"
                + "Neurol. 225, 327–343.\n"
                + "Alonso, A., and Llinás, R. (1989).\n"
                + "Subthreshold sodium-dependent\n"
                + "theta-like rhythmicity in stellate\n"
                + "cells of entorhinal cortex layer II.\n"
                + "Nature 342, 175–177.\n"
                + "Amaral, D. G., and Kurz, J. (1985).\n"
                + "An analysis of the origins of\n" + "";
        Classification classification = c.classify(c.getInstancePipe()
                .instanceFrom(new Instance(txt, null, null, null)));
        System.out.println("LABELL " + classification.getLabeling());

        try {
            ObjectOutputStream oos = new ObjectOutputStream(
                    new FileOutputStream("target/classifier_"
                            + currentTimeMillis() + ".model"));
        } catch (Exception e) {

        // //////////////////////////////////////////////////////////////////
        // train test
        for (String goldLabel : new String[] { "I", "O" }) {
            ClassifierTrainer trainer2 = new MaxEntTrainer();
            Classifier c2 = trainer2.train(instanceList);

            FileIterator iteratorI = new FileIterator(new File[] { new File(
                    CORPUS, "../annots1/" + goldLabel + "/") },
                    new TxtFilter(), LAST_DIRECTORY);
            Iterator<Instance> instancesI = c2.getInstancePipe()

            Histogram<String> h = new Histogram<String>();
            while (instancesI.hasNext()) {
                Instance inst =;
                Labeling labeling = c2.classify(inst).getLabeling();
                Label bestLabel = labeling.getBestLabel();

                // if (!bestLabel.toString().equals(goldLabel)) {
                // LOG.debug(
                // "\n\n\nMISSCLASSIFIED as {} but gold:{} :: "
                // + inst.getSource(), bestLabel, goldLabel);
                // }
            System.out.println("\nlabel " + goldLabel + "\n" + h);