cc.mallet.classify.ClassifierTrainer Java Examples

The following examples show how to use cc.mallet.classify.ClassifierTrainer. Each example is preceded by its source file name, the project it was taken from, that project's license and its vote count; refer to the original project for the surrounding code.
Example #1
Source File: EngineMBMalletClass.java — from gateplugin-LearningFramework (GNU Lesser General Public License v2.1), 5 votes
/**
 * Trains the Mallet classifier on the corpus representation and persists the model metadata.
 *
 * <p>After training, {@code updateInfo()} is called. The info object then records the training
 * timestamp (formatted as {@code yyyy-MM-dd HH:mm:ss}) and the raw parameter string. Finally,
 * both {@code info} and {@code featureInfo} are saved into {@code dataDirectory}.
 *
 * @param dataDirectory directory the model info and feature info are saved to
 * @param instanceType not used by this implementation
 * @param parms algorithm parameter string, stored verbatim in the model info
 */
@Override
public void trainModel(File dataDirectory, String instanceType, String parms) {
  // The generic trainer field is expected to hold a Mallet ClassifierTrainer here.
  ClassifierTrainer malletTrainer = (ClassifierTrainer) trainer;
  model = malletTrainer.train(corpusRepresentation.getRepresentationMallet());
  updateInfo();

  String trainedAt = new SimpleDateFormat("yyyy-MM-dd HH:mm:ss").format(new Date());
  info.modelWhenTrained = trainedAt;
  info.algorithmParameters = parms;

  info.save(dataDirectory);
  featureInfo.save(dataDirectory);
}
 
Example #2
Source File: ClassifierTrainerFactory.java — from baleen (Apache License 2.0), 5 votes
/**
 * Creates a {@link ClassifierTrainer} according to the specification.
 *
 * <p>The descriptor is split on commas. The first field is resolved to a trainer class name
 * and instantiated. The whole field array is then handed to {@code setParameterValues}, which
 * applies the remaining fields (e.g. {@code numIterations=20}) to the new trainer.
 *
 * @param <T> classifier type produced by the trainer
 * @return the configured trainer
 */
@SuppressWarnings("unchecked")
public <T extends Classifier> ClassifierTrainer<T> createTrainer() {
  final String[] parts = trainerDescriptor.split(",");
  final String trainerClassName = resolveTrainerClassName(parts[0]);
  // Unchecked: the reflective factory only guarantees ClassifierTrainer<?>.
  final ClassifierTrainer<T> created = (ClassifierTrainer<T>) createTrainer(trainerClassName);
  setParameterValues(parts, created);
  return created;
}
 
Example #3
Source File: ClassifierTrainerFactory.java — from baleen (Apache License 2.0), 5 votes
/**
 * Reflectively instantiates the trainer class with the given name using its no-arg constructor.
 *
 * @param className trainer class name, resolved by {@code BuilderUtils.getClassFromString}
 *     together with {@code "cc.mallet.classify"} (presumably a default package — confirm)
 * @return a new, unconfigured trainer instance
 * @throws IllegalArgumentException if the class cannot be resolved, is not a
 *     {@link ClassifierTrainer}, or cannot be instantiated; the cause is preserved
 */
private ClassifierTrainer<?> createTrainer(String className) {
  try {
    // getDeclaredConstructor().newInstance() wraps constructor failures in
    // InvocationTargetException. The deprecated Class.newInstance() would instead rethrow
    // checked constructor exceptions unwrapped, escaping the translation below.
    return (ClassifierTrainer<?>)
        BuilderUtils.getClassFromString(className, "cc.mallet.classify")
            .getDeclaredConstructor()
            .newInstance();
  } catch (ClassCastException
      | ReflectiveOperationException
      | InvalidParameterException e) {
    throw new IllegalArgumentException(String.format("Unknown trainer %s", className), e);
  }
}
 
Example #4
Source File: ClassifierTrainerTest.java — from baleen (Apache License 2.0), 5 votes
/** The factory builds a non-null MaxEnt trainer from a descriptor with two parameters. */
@Test
public void testFactory() throws ResourceInitializationException {
  final String descriptor = "MaxEnt,gaussianPriorVariance=10.0,numIterations=20";
  assertNotNull(new ClassifierTrainerFactory(descriptor).createTrainer());
}
 
Example #5
Source File: ReferencesClassifierTrainer.java — from bluima (Apache License 2.0), 5 votes
/**
 * Runs one random train/test split and evaluates a MaxEnt classifier on it.
 *
 * <p>Uses a freshly constructed {@link Randoms}. See
 * {@link #testTrainSplit(InstanceList, Randoms)} to supply a seeded generator for
 * reproducible splits.
 *
 * @param instances piped instances to split
 * @return the trial of the trained classifier on {@code instanceLists[TESTING]}
 */
public static Trial testTrainSplit(InstanceList instances) {
    return testTrainSplit(instances, new Randoms());
}

/**
 * Splits {@code instances} 0.9 / 0.1 / 0.0 using {@code random}, trains a MaxEnt classifier
 * on the {@code TRAINING} partition and evaluates it on the {@code TESTING} partition.
 *
 * @param instances piped instances to split
 * @param random random source for the split; pass a seeded instance for reproducibility
 * @return the trial of the trained classifier on {@code instanceLists[TESTING]}
 */
public static Trial testTrainSplit(InstanceList instances, Randoms random) {
    InstanceList[] instanceLists = instances.split(random,
            new double[] { 0.9, 0.1, 0.0 });

    // The concrete trainer type avoids the raw ClassifierTrainer and its rawtypes suppression.
    MaxEntTrainer trainer = new MaxEntTrainer();
    Classifier classifier = trainer.train(instanceLists[TRAINING]);
    return new Trial(classifier, instanceLists[TESTING]);
}
 
Example #6
Source File: SpamDetector.java — from Machine-Learning-in-Java (MIT License), 4 votes
/**
 * Trains a Naive Bayes spam classifier on the e-mails under {@code data/ex6DataEmails/train}.
 * It then evaluates the classifier on {@code data/ex6DataEmails/test} and prints accuracy,
 * F1 for the {@code spam} class, and precision/recall for the label at index 1.
 *
 * @param args unused
 */
public static void main(String[] args) {
    String stopListFilePath = "data/stoplists/en.txt";
    String dataFolderPath = "data/ex6DataEmails/train";
    String testFolderPath = "data/ex6DataEmails/test";

    // Pipeline: UTF-8 text -> tokens of letters/digits/underscore -> lowercase
    // -> stopwords removed -> feature sequence -> feature vector; the target becomes a Label.
    ArrayList<Pipe> pipeList = new ArrayList<Pipe>();
    pipeList.add(new Input2CharSequence("UTF-8"));
    Pattern tokenPattern = Pattern.compile("[\\p{L}\\p{N}_]+");
    pipeList.add(new CharSequence2TokenSequence(tokenPattern));
    pipeList.add(new TokenSequenceLowercase());
    pipeList.add(new TokenSequenceRemoveStopwords(new File(stopListFilePath), "utf-8", false, false, false));
    pipeList.add(new TokenSequence2FeatureSequence());
    pipeList.add(new FeatureSequence2FeatureVector());
    pipeList.add(new Target2Label());
    SerialPipes pipeline = new SerialPipes(pipeList);

    // LAST_DIRECTORY: presumably each file's parent directory name is used as its target.
    FileIterator folderIterator = new FileIterator(
            new File[] {new File(dataFolderPath)},
            new TxtFilter(),
            FileIterator.LAST_DIRECTORY);

    InstanceList instances = new InstanceList(pipeline);
    instances.addThruPipe(folderIterator);

    // Concrete trainer type instead of the raw ClassifierTrainer.
    NaiveBayesTrainer classifierTrainer = new NaiveBayesTrainer();
    Classifier classifier = classifierTrainer.train(instances);

    // Test instances must go through the classifier's own pipe so feature indices match.
    InstanceList testInstances = new InstanceList(classifier.getInstancePipe());
    folderIterator = new FileIterator(
            new File[] {new File(testFolderPath)},
            new TxtFilter(),
            FileIterator.LAST_DIRECTORY);
    testInstances.addThruPipe(folderIterator);

    Trial trial = new Trial(classifier, testInstances);

    System.out.println("Accuracy: " + trial.getAccuracy());
    System.out.println("F1 for class 'spam': " + trial.getF1("spam"));

    // NOTE(review): index 1 is whichever label was added to the alphabet second; it is not
    // necessarily "spam". The printed name shows which label the numbers refer to.
    System.out.println("Precision for class '" +
                       classifier.getLabelAlphabet().lookupLabel(1) + "': " +
                       trial.getPrecision(1));

    System.out.println("Recall for class '" +
                       classifier.getLabelAlphabet().lookupLabel(1) + "': " +
                       trial.getRecall(1));
}
 
Example #7
Source File: ClassifierTrainerFactory.java — from baleen (Apache License 2.0), 4 votes
/**
 * Applies the parameter fields of a split trainer descriptor to {@code trainer}.
 *
 * <p>Entry 0 is skipped; in {@code createTrainer()} it holds the trainer name. Every
 * subsequent entry is passed to {@code setParameterValue} in array order.
 *
 * @param fields the comma-split trainer descriptor
 * @param trainer trainer receiving the parameter values
 */
@VisibleForTesting
protected void setParameterValues(String[] fields, ClassifierTrainer<?> trainer) {
  int index = 1;
  while (index < fields.length) {
    setParameterValue(trainer, fields[index]);
    index++;
  }
}
 
Example #8
Source File: ParameterClassifierTrainerTest.java — from baleen (Apache License 2.0), 4 votes
/** The factory builds a non-null trainer from the parameterised {@code trainerDescriptor}. */
@Test
public void testFactory() throws ResourceInitializationException {
  ClassifierTrainer<?> created =
      new ClassifierTrainerFactory(trainerDescriptor).createTrainer();
  assertNotNull(created);
}
 
Example #9
Source File: ReferencesClassifierTrainer.java — from bluima (Apache License 2.0), 4 votes
/**
 * Builds, evaluates, demos and saves a MaxEnt classifier for reference text.
 *
 * <ol>
 *   <li>Pipes the files under {@code CORPUS} accepted by {@code TxtFilter} into an instance list.
 *   <li>Runs {@code trials} random train/test splits and prints precision, recall, F1 and
 *       the mean F1.
 *   <li>Trains on all instances, classifies a sample snippet and prints the model.
 *   <li>Serializes the model to {@code target/classifier_<millis>.model}; failures are reported
 *       on stderr, but the run continues.
 *   <li>For gold labels {@code I} and {@code O}, classifies the files in
 *       {@code CORPUS/../annots1/<label>/} and prints a histogram of the predicted labels.
 * </ol>
 *
 * @param args unused
 */
public static void main(String[] args) {

    // pipe instances
    InstanceList instanceList = new InstanceList(
            new SerialPipes(getPipes()));
    FileIterator iterator = new FileIterator(new File[] { CORPUS },
            new TxtFilter(), LAST_DIRECTORY);
    instanceList.addThruPipe(iterator);

    // ////////////////////////////////////////////////////////////////
    // cross-validate
    System.out.println("trial\tprec\trecall\tF-score");
    double f1s = 0;
    for (int i = 0; i < trials; i++) {
        Trial trial = testTrainSplit(instanceList);
        System.out.println(join(new Object[] {//
                i, trial.getPrecision(TESTING), trial.getRecall(TESTING),
                        trial.getF1(TESTING) }, "\t"));
        f1s += trial.getF1(TESTING);
    }
    System.out.println("mean F1 = " + (f1s / (trials + 0d)));

    // ////////////////////////////////////////////////////////////////
    // train (concrete trainer type instead of the raw ClassifierTrainer)
    MaxEntTrainer trainer = new MaxEntTrainer();
    Classifier c = trainer.train(instanceList);

    String txt = "in the entorhinal cortex of the rat\n"
            + "II: phase relations between unit discharges and theta field potentials.\n"
            + "J. Comp. Neurol. 67, 502–509.\n"
            + "Alonso, A., and Klink, R. (1993).\n"
            + "Differential electroresponsiveness of\n"
            + "stellate and pyramidal-like cells of\n"
            + "medial entorhinal cortex layer II.\n"
            + "J. Neurophysiol. 70, 128–143.\n"
            + "Alonso, A., and Köhler, C. (1984).\n"
            + "A study of the reciprocal connections between the septum and the\n"
            + "entorhinal area using anterograde\n"
            + "and retrograde axonal transport\n"
            + "methods in the rat brain. J. Comp.\n"
            + "Neurol. 225, 327–343.\n"
            + "Alonso, A., and Llinás, R. (1989).\n"
            + "Subthreshold sodium-dependent\n"
            + "theta-like rhythmicity in stellate\n"
            + "cells of entorhinal cortex layer II.\n"
            + "Nature 342, 175–177.\n"
            + "Amaral, D. G., and Kurz, J. (1985).\n"
            + "An analysis of the origins of\n" + "";
    Classification classification = c.classify(c.getInstancePipe()
            .instanceFrom(new Instance(txt, null, null, null)));
    System.out.println("LABELL " + classification.getLabeling());
    c.print();

    // Both streams are closed even if construction or writeObject fails.
    try (FileOutputStream fos = new FileOutputStream("target/classifier_"
                    + currentTimeMillis() + ".model");
            ObjectOutputStream oos = new ObjectOutputStream(fos)) {
        oos.writeObject(c);
    } catch (Exception e) {
        // Saving is best-effort here, but the failure must be visible rather than discarded.
        System.err.println("Could not save classifier: " + e);
        e.printStackTrace();
    }

    // //////////////////////////////////////////////////////////////////
    // train test
    for (String goldLabel : new String[] { "I", "O" }) {
        // NOTE(review): this retrains on the same instanceList for each gold label; the model
        // may be reusable once training is confirmed to be deterministic.
        MaxEntTrainer trainer2 = new MaxEntTrainer();
        Classifier c2 = trainer2.train(instanceList);

        FileIterator iteratorI = new FileIterator(new File[] { new File(
                CORPUS, "../annots1/" + goldLabel + "/") },
                new TxtFilter(), LAST_DIRECTORY);
        Iterator<Instance> instancesI = c2.getInstancePipe()
                .newIteratorFrom(iteratorI);

        Histogram<String> h = new Histogram<String>();
        while (instancesI.hasNext()) {
            Instance inst = instancesI.next();
            Labeling labeling = c2.classify(inst).getLabeling();
            Label bestLabel = labeling.getBestLabel();
            h.add(bestLabel.toString());
        }
        System.out.println("\nlabel " + goldLabel + "\n" + h);
    }
}