Java Code Examples for weka.core.Instances#stratify()
The following examples show how to use weka.core.Instances#stratify().
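All of the examples below share the same core pattern: copy or load a dataset, randomize it, call stratify(numFolds) so that each cross-validation fold approximately preserves the overall class distribution, then draw per-fold splits with trainCV() and testCV(). Note that stratify() requires the class index to be set and only reorders the data when the class attribute is nominal. Here is a minimal, self-contained sketch of that pattern; the ARFF file name, seed, and fold count are illustrative placeholders, not taken from any of the projects below.

// Minimal sketch of stratified k-fold splitting with Weka.
// "mydata.arff", the seed, and the fold count are hypothetical placeholders.
import java.io.BufferedReader;
import java.io.FileReader;
import java.util.Random;
import weka.core.Instances;

public class StratifySketch {
  public static void main(String[] args) throws Exception {
    Instances data = new Instances(new BufferedReader(new FileReader("mydata.arff")));
    data.setClassIndex(data.numAttributes() - 1); // stratify() needs the class index set

    int numFolds = 10;
    Random random = new Random(42);
    data.randomize(random);                        // shuffle before stratifying
    if (data.classAttribute().isNominal()) {
      data.stratify(numFolds);                     // reorder so each fold mirrors the class distribution
    }

    for (int fold = 0; fold < numFolds; fold++) {
      Instances train = data.trainCV(numFolds, fold, random);
      Instances test = data.testCV(numFolds, fold);
      // build and evaluate a classifier on train/test here
    }
  }
}

Because stratify() groups the instances by class internally, calling randomize() first is what keeps each fold a random, yet class-balanced, sample.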
Example 1
Source File: AttributeSelection.java (from tsml, GNU General Public License v3.0)
/**
 * Perform a cross validation for attribute selection. With subset
 * evaluators the number of times each attribute is selected over
 * the cross validation is reported. For attribute evaluators, the
 * average merit and average ranking + std deviation is reported for
 * each attribute.
 *
 * @return the results of cross validation as a String
 * @exception Exception if an error occurs during cross validation
 */
public String CrossValidateAttributes() throws Exception {
  Instances cvData = new Instances(m_trainInstances);
  Instances train;

  Random random = new Random(m_seed);
  cvData.randomize(random);

  if (!(m_ASEvaluator instanceof UnsupervisedSubsetEvaluator)
      && !(m_ASEvaluator instanceof UnsupervisedAttributeEvaluator)) {
    if (cvData.classAttribute().isNominal()) {
      cvData.stratify(m_numFolds);
    }
  }

  for (int i = 0; i < m_numFolds; i++) {
    // Perform attribute selection
    train = cvData.trainCV(m_numFolds, i, random);
    selectAttributesCVSplit(train);
  }

  return CVResultsString();
}
Example 2
Source File: EvaluationUtils.java (from tsml, GNU General Public License v3.0)
/**
 * Generate a bunch of predictions ready for processing, by performing a
 * cross-validation on the supplied dataset.
 *
 * @param classifier the Classifier to evaluate
 * @param data the dataset
 * @param numFolds the number of folds in the cross-validation.
 * @exception Exception if an error occurs
 */
public FastVector getCVPredictions(Classifier classifier,
                                   Instances data,
                                   int numFolds) throws Exception {
  FastVector predictions = new FastVector();
  Instances runInstances = new Instances(data);
  Random random = new Random(m_Seed);
  runInstances.randomize(random);
  if (runInstances.classAttribute().isNominal() && (numFolds > 1)) {
    runInstances.stratify(numFolds);
  }
  int inst = 0;
  for (int fold = 0; fold < numFolds; fold++) {
    Instances train = runInstances.trainCV(numFolds, fold, random);
    Instances test = runInstances.testCV(numFolds, fold);
    FastVector foldPred = getTrainTestPredictions(classifier, train, test);
    predictions.appendElements(foldPred);
  }
  return predictions;
}
Example 3
Source File: WekaUtilTester.java (from AILibs, GNU Affero General Public License v3.0)
@Test
public void checkSplit() throws Exception {
  Instances inst = new Instances(new BufferedReader(new FileReader(VOWEL_ARFF)));
  inst.setClassIndex(inst.numAttributes() - 1);
  for (Classifier c : this.portfolio) {

    /* eval for CV */
    inst.stratify(10);

    Instances train = inst.trainCV(10, 0);
    Instances test = inst.testCV(10, 0);
    Assert.assertEquals(train.size() + test.size(), inst.size());
    Evaluation eval = new Evaluation(train);
    eval.crossValidateModel(c, inst, 10, new Random(0));

    c.buildClassifier(train);
    eval.evaluateModel(c, test);
    System.out.println(eval.pctCorrect());
  }
}
Example 4
Source File: ThresholdSelector.java (from tsml, GNU General Public License v3.0)
/**
 * Collects the classifier predictions using the specified evaluation method.
 *
 * @param instances the set of <code>Instances</code> to generate
 * predictions for.
 * @param mode the evaluation mode.
 * @param numFolds the number of folds to use if not evaluating on the
 * full training set.
 * @return a <code>FastVector</code> containing the predictions.
 * @throws Exception if an error occurs generating the predictions.
 */
protected FastVector getPredictions(Instances instances, int mode, int numFolds)
  throws Exception {

  EvaluationUtils eu = new EvaluationUtils();
  eu.setSeed(m_Seed);

  switch (mode) {
  case EVAL_TUNED_SPLIT:
    Instances trainData = null, evalData = null;
    Instances data = new Instances(instances);
    Random random = new Random(m_Seed);
    data.randomize(random);
    data.stratify(numFolds);

    // Make sure that both subsets contain at least one positive instance
    for (int subsetIndex = 0; subsetIndex < numFolds; subsetIndex++) {
      trainData = data.trainCV(numFolds, subsetIndex, random);
      evalData = data.testCV(numFolds, subsetIndex);
      if (checkForInstance(trainData) && checkForInstance(evalData)) {
        break;
      }
    }
    return eu.getTrainTestPredictions(m_Classifier, trainData, evalData);
  case EVAL_TRAINING_SET:
    return eu.getTrainTestPredictions(m_Classifier, instances, instances);
  case EVAL_CROSS_VALIDATION:
    return eu.getCVPredictions(m_Classifier, instances, numFolds);
  default:
    throw new RuntimeException("Unrecognized evaluation mode");
  }
}
Example 5
Source File: Stacking.java (from tsml, GNU General Public License v3.0)
/**
 * Buildclassifier selects a classifier from the set of classifiers
 * by minimising error on the training data.
 *
 * @param data the training data to be used for generating the
 * boosted classifier.
 * @throws Exception if the classifier could not be built successfully
 */
public void buildClassifier(Instances data) throws Exception {
  if (m_MetaClassifier == null) {
    throw new IllegalArgumentException("No meta classifier has been set");
  }

  // can classifier handle the data?
  getCapabilities().testWithFail(data);

  // remove instances with missing class
  Instances newData = new Instances(data);
  m_BaseFormat = new Instances(data, 0);
  newData.deleteWithMissingClass();

  Random random = new Random(m_Seed);
  newData.randomize(random);
  if (newData.classAttribute().isNominal()) {
    newData.stratify(m_NumFolds);
  }

  // Create meta level
  generateMetaLevel(newData, random);

  // restart the executor pool because at the end of processing
  // a set of classifiers it gets shutdown to prevent the program
  // executing as a server
  super.buildClassifier(newData);

  // Rebuild all the base classifiers on the full training data
  buildClassifiers(newData);
}
Example 6
Source File: LogisticBase.java (from tsml, GNU General Public License v3.0)
/**
 * Runs LogitBoost, determining the best number of iterations by cross-validation.
 *
 * @throws Exception if something goes wrong
 */
protected void performBoostingCV() throws Exception {

  // completed iteration keeps track of the number of iterations that have been
  // performed in every fold (some might stop earlier than others).
  // Best iteration is selected only from these.
  int completedIterations = m_maxIterations;

  Instances allData = new Instances(m_train);

  allData.stratify(m_numFoldsBoosting);

  double[] error = new double[m_maxIterations + 1];

  for (int i = 0; i < m_numFoldsBoosting; i++) {
    // split into training/test data in fold
    Instances train = allData.trainCV(m_numFoldsBoosting, i);
    Instances test = allData.testCV(m_numFoldsBoosting, i);

    // initialize LogitBoost
    m_numRegressions = 0;
    m_regressions = initRegressions();

    // run LogitBoost iterations
    int iterations = performBoosting(train, test, error, completedIterations);
    if (iterations < completedIterations) completedIterations = iterations;
  }

  // determine iteration with minimum error over the folds
  int bestIteration = getBestIteration(error, completedIterations);

  // rebuild model on all of the training data
  m_numRegressions = 0;
  performBoosting(bestIteration);
}
Example 7
Source File: Ridor.java (from tsml, GNU General Public License v3.0)
/**
 * Builds a single rule learner with REP dealing with 2 classes.
 * This rule learner always tries to predict the class with label
 * m_Class.
 *
 * @param instances the training data
 * @throws Exception if classifier can't be built successfully
 */
public void buildClassifier(Instances instances) throws Exception {
  m_ClassAttribute = instances.classAttribute();
  if (!m_ClassAttribute.isNominal())
    throw new UnsupportedClassTypeException(" Only nominal class, please.");
  if (instances.numClasses() != 2)
    throw new Exception(" Only 2 classes, please.");

  Instances data = new Instances(instances);
  if (Utils.eq(data.sumOfWeights(), 0))
    throw new Exception(" No training data.");

  data.deleteWithMissingClass();
  if (Utils.eq(data.sumOfWeights(), 0))
    throw new Exception(" The class labels of all the training data are missing.");

  if (data.numInstances() < m_Folds)
    throw new Exception(" Not enough data for REP.");

  m_Antds = new FastVector();

  /* Split data into Grow and Prune */
  m_Random = new Random(m_Seed);
  data.randomize(m_Random);
  data.stratify(m_Folds);

  Instances growData = data.trainCV(m_Folds, m_Folds - 1, m_Random);
  Instances pruneData = data.testCV(m_Folds, m_Folds - 1);

  grow(growData);   // Build this rule
  prune(pruneData); // Prune this rule
}
Example 8
Source File: WekaDeeplearning4jExamples.java (from wekaDeeplearning4j, GNU General Public License v3.0)
private static void dl4jResnet50() throws Exception {
  String folderPath = "src/test/resources/nominal/plant-seedlings-small";
  ImageDirectoryLoader loader = new ImageDirectoryLoader();
  loader.setInputDirectory(new File(folderPath));
  Instances inst = loader.getDataSet();
  inst.setClassIndex(1);

  Dl4jMlpClassifier classifier = new Dl4jMlpClassifier();
  classifier.setNumEpochs(3);

  KerasEfficientNet kerasEfficientNet = new KerasEfficientNet();
  kerasEfficientNet.setVariation(EfficientNet.VARIATION.EFFICIENTNET_B1);
  classifier.setZooModel(kerasEfficientNet);

  ImageInstanceIterator iterator = new ImageInstanceIterator();
  iterator.setImagesLocation(new File(folderPath));

  classifier.setInstanceIterator(iterator);

  // Stratify and split the data
  Random rand = new Random(0);
  inst.randomize(rand);
  inst.stratify(5);
  Instances train = inst.trainCV(5, 0);
  Instances test = inst.testCV(5, 0);

  // Build the classifier on the training data
  classifier.buildClassifier(train);

  // Evaluate the model on test data
  Evaluation eval = new Evaluation(test);
  eval.evaluateModel(classifier, test);

  // Output some summary statistics
  System.out.println(eval.toSummaryString());
  System.out.println(eval.toMatrixString());
}
Example 9
Source File: LearnShapelets.java (from tsml, GNU General Public License v3.0)
public void buildClassifier(Instances trainData) throws Exception {
  long startTime = System.currentTimeMillis();

  if (paraSearch) {
    double[] paramsLambdaW;
    double[] paramsPercentageOfSeriesLength;
    int[] paramsShapeletLengthScale;

    paramsLambdaW = lambdaWRange;
    paramsPercentageOfSeriesLength = percentageOfSeriesLengthRange;
    paramsShapeletLengthScale = shapeletLengthScaleRange;

    int noFolds = 2;
    double bsfAccuracy = 0;
    int[] params = {0, 0, 0};
    double accuracy = 0;

    // randomize and stratify the data prior to cross validation
    trainData.randomize(rand);
    trainData.stratify(noFolds);

    int numHpsCombinations = 1;

    for (int i = 0; i < paramsLambdaW.length; i++) {
      for (int j = 0; j < paramsPercentageOfSeriesLength.length; j++) {
        for (int k = 0; k < paramsShapeletLengthScale.length; k++) {
          percentageOfSeriesLength = paramsPercentageOfSeriesLength[j];
          R = paramsShapeletLengthScale[k];
          lambdaW = paramsLambdaW[i];

          print("HPS Combination #" + numHpsCombinations + ": {R=" + R
              + ", L=" + percentageOfSeriesLength
              + ", lambdaW=" + lambdaW + "}");
          print("--------------------------------------");

          double sumAccuracy = 0;

          // build our test and train sets. for cross-validation.
          for (int l = 0; l < noFolds; l++) {
            Instances trainCV = trainData.trainCV(noFolds, l);
            Instances testCV = trainData.testCV(noFolds, l);

            // fixed hyper-parameters
            eta = 0.1;
            alpha = -30;
            maxIter = 300;

            print("Learn model for Fold-" + l + ":");
            train(trainCV);

            // test on the remaining fold.
            accuracy = utilities.ClassifierTools.accuracy(testCV, this);
            sumAccuracy += accuracy;
            print("Accuracy-Fold-" + l + " = " + accuracy);

            trainCV = null;
            testCV = null;
          }

          sumAccuracy /= noFolds;
          print("Accuracy-CV = " + sumAccuracy);
          print("--------------------------------------");

          if (sumAccuracy > bsfAccuracy) {
            int[] p = {i, j, k};
            params = p;
            bsfAccuracy = sumAccuracy;
          }
          numHpsCombinations++;
        }
      }
    }
    System.gc();

    maxAcc = bsfAccuracy;
    lambdaW = paramsLambdaW[params[0]];
    percentageOfSeriesLength = paramsPercentageOfSeriesLength[params[1]];
    R = paramsShapeletLengthScale[params[2]];

    eta = 0.1;
    alpha = -30;
    maxIter = 600;

    print("Learn final model with best hyper-parameters: R=" + R
        + ", L=" + percentageOfSeriesLength + ", lambdaW=" + lambdaW);
  } else {
    fixParameters();
    print("Fixed parameters: R=" + R
        + ", L=" + percentageOfSeriesLength + ", lambdaW=" + lambdaW);
  }

  train(trainData);
  trainResults.setBuildTime(System.currentTimeMillis() - startTime);
}
Example 10
Source File: CVParameterSelection.java (from tsml, GNU General Public License v3.0)
/**
 * Generates the classifier.
 *
 * @param instances set of instances serving as training data
 * @throws Exception if the classifier has not been generated successfully
 */
public void buildClassifier(Instances instances) throws Exception {
  // can classifier handle the data?
  getCapabilities().testWithFail(instances);

  // remove instances with missing class
  Instances trainData = new Instances(instances);
  trainData.deleteWithMissingClass();

  if (!(m_Classifier instanceof OptionHandler)) {
    throw new IllegalArgumentException("Base classifier should be OptionHandler.");
  }
  m_InitOptions = ((OptionHandler) m_Classifier).getOptions();
  m_BestPerformance = -99;
  m_NumAttributes = trainData.numAttributes();
  Random random = new Random(m_Seed);
  trainData.randomize(random);
  m_TrainFoldSize = trainData.trainCV(m_NumFolds, 0).numInstances();

  // Check whether there are any parameters to optimize
  if (m_CVParams.size() == 0) {
    m_Classifier.buildClassifier(trainData);
    m_BestClassifierOptions = m_InitOptions;
    return;
  }

  if (trainData.classAttribute().isNominal()) {
    trainData.stratify(m_NumFolds);
  }
  m_BestClassifierOptions = null;

  // Set up m_ClassifierOptions -- take getOptions() and remove
  // those being optimised.
  m_ClassifierOptions = ((OptionHandler) m_Classifier).getOptions();
  for (int i = 0; i < m_CVParams.size(); i++) {
    Utils.getOption(((CVParameter) m_CVParams.elementAt(i)).m_ParamChar,
        m_ClassifierOptions);
  }
  findParamsByCrossValidation(0, trainData, random);

  String[] options = (String[]) m_BestClassifierOptions.clone();
  ((OptionHandler) m_Classifier).setOptions(options);
  m_Classifier.buildClassifier(trainData);
}
Example 11
Source File: ConjunctiveRule.java (from tsml, GNU General Public License v3.0)
/**
 * Builds a single rule learner with REP dealing with nominal classes or
 * numeric classes.
 * For nominal classes, this rule learner predicts a distribution on
 * the classes.
 * For numeric classes, this learner predicts a single value.
 *
 * @param instances the training data
 * @throws Exception if classifier can't be built successfully
 */
public void buildClassifier(Instances instances) throws Exception {
  // can classifier handle the data?
  getCapabilities().testWithFail(instances);

  // remove instances with missing class
  Instances data = new Instances(instances);
  data.deleteWithMissingClass();

  if (data.numInstances() < m_Folds)
    throw new Exception("Not enough data for REP.");

  m_ClassAttribute = data.classAttribute();
  if (m_ClassAttribute.isNominal())
    m_NumClasses = m_ClassAttribute.numValues();
  else
    m_NumClasses = 1;

  m_Antds = new FastVector();
  m_DefDstr = new double[m_NumClasses];
  m_Cnsqt = new double[m_NumClasses];
  m_Targets = new FastVector();
  m_Random = new Random(m_Seed);

  if (m_NumAntds != -1) {
    grow(data);
  } else {
    data.randomize(m_Random);

    // Split data into Grow and Prune
    data.stratify(m_Folds);

    Instances growData = data.trainCV(m_Folds, m_Folds - 1, m_Random);
    Instances pruneData = data.testCV(m_Folds, m_Folds - 1);

    grow(growData);   // Build this rule
    prune(pruneData); // Prune this rule
  }

  if (m_ClassAttribute.isNominal()) {
    Utils.normalize(m_Cnsqt);
    if (Utils.gr(Utils.sum(m_DefDstr), 0))
      Utils.normalize(m_DefDstr);
  }
}