smile.data.AttributeDataset Java Examples
The following examples show how to use smile.data.AttributeDataset.
They are drawn from open-source projects; the source file and originating project are noted above each example, so you can trace each snippet back to its original context and related API usage.
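Most of the examples below share one basic pattern: parse an ARFF file with ArffParser, mark the response (label) column with setResponseIndex, and unpack the resulting AttributeDataset into a feature matrix and a label vector via toArray. Here is a minimal, self-contained sketch of that pattern, assuming the Smile 1.x API (smile.data.parser.ArffParser) that these examples are written against; the iris.arff URL and response index 4 are taken from the examples themselves, while the wrapper class and the printed summary are purely illustrative.

import java.io.BufferedInputStream;
import java.io.InputStream;
import java.net.URL;

import smile.data.AttributeDataset;
import smile.data.parser.ArffParser;

public class AttributeDatasetExample {
    public static void main(String[] args) throws Exception {
        // Column 4 of iris.arff is the class label (the "response").
        URL url = new URL(
            "https://gist.githubusercontent.com/myui/143fa9d05bd6e7db0114/raw/500f178316b802f1cade6e3bf8dc814a96e84b1e/iris.arff");
        try (InputStream is = new BufferedInputStream(url.openStream())) {
            ArffParser parser = new ArffParser();
            parser.setResponseIndex(4);
            AttributeDataset iris = parser.parse(is);

            // Unpack the dataset into a dense feature matrix x and a label vector y.
            double[][] x = iris.toArray(new double[iris.size()][]);
            int[] y = iris.toArray(new int[iris.size()]);

            System.out.println("rows=" + iris.size()
                + ", attributes=" + iris.attributes().length
                + ", first label=" + y[0]);
        }
    }
}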
Example #1
Source File: TreePredictUDFTest.java From incubator-hivemall with Apache License 2.0 | 6 votes |
/**
 * Test of learn method, of class DecisionTree.
 */
@Test
public void testIris() throws IOException, ParseException, HiveException {
    URL url = new URL(
        "https://gist.githubusercontent.com/myui/143fa9d05bd6e7db0114/raw/500f178316b802f1cade6e3bf8dc814a96e84b1e/iris.arff");
    InputStream is = new BufferedInputStream(url.openStream());

    ArffParser arffParser = new ArffParser();
    arffParser.setResponseIndex(4);

    AttributeDataset iris = arffParser.parse(is);
    double[][] x = iris.toArray(new double[iris.size()][]);
    int[] y = iris.toArray(new int[iris.size()]);

    int n = x.length;
    LOOCV loocv = new LOOCV(n);
    for (int i = 0; i < n; i++) {
        double[][] trainx = Math.slice(x, loocv.train[i]);
        int[] trainy = Math.slice(y, loocv.train[i]);

        RoaringBitmap attrs = SmileExtUtils.convertAttributeTypes(iris.attributes());
        DecisionTree tree = new DecisionTree(attrs,
            new RowMajorDenseMatrix2d(trainx, x[0].length), trainy, 4);
        Assert.assertEquals(tree.predict(x[loocv.test[i]]), evalPredict(tree, x[loocv.test[i]]));
    }
}
Example #2
Source File: LinkClassifierTrainer.java From ache with Apache License 2.0 | 6 votes |
private SoftClassifier<double[]> trainClassifier(AttributeDataset data, int numberOfClasses) {
    final double c = 1.0;
    SVM<double[]> classifier;
    if (numberOfClasses > 2) {
        classifier = new SVM<double[]>(new LinearKernel(), c, numberOfClasses, Multiclass.ONE_VS_ALL);
    } else {
        classifier = new SVM<double[]>(new LinearKernel(), c);
    }
    int[] y = data.labels();
    double[][] x = data.x();
    classifier.learn(x, y);
    classifier.finish();
    classifier.trainPlattScaling(x, y);
    return classifier;
}
Example #3
Source File: RandomForestClassifierUDTFTest.java From incubator-hivemall with Apache License 2.0 | 6 votes |
@Test
public void testIrisSparseDenseEquals() throws IOException, ParseException, HiveException {
    String urlString =
        "https://gist.githubusercontent.com/myui/143fa9d05bd6e7db0114/raw/500f178316b802f1cade6e3bf8dc814a96e84b1e/iris.arff";
    DecisionTree.Node denseNode = getDecisionTreeFromDenseInput(urlString);
    DecisionTree.Node sparseNode = getDecisionTreeFromSparseInput(urlString);

    URL url = new URL(urlString);
    InputStream is = new BufferedInputStream(url.openStream());

    ArffParser arffParser = new ArffParser();
    arffParser.setResponseIndex(4);

    AttributeDataset iris = arffParser.parse(is);
    int size = iris.size();
    double[][] x = iris.toArray(new double[size][]);

    int diff = 0;
    for (int i = 0; i < size; i++) {
        if (denseNode.predict(x[i]) != sparseNode.predict(x[i])) {
            diff++;
        }
    }

    Assert.assertTrue("large diff " + diff + " between two predictions", diff < 10);
}
Example #4
Source File: TreePredictUDFv1Test.java From incubator-hivemall with Apache License 2.0 | 6 votes |
/**
 * Test of learn method, of class DecisionTree.
 */
@Test
public void testIris() throws IOException, ParseException, HiveException {
    URL url = new URL(
        "https://gist.githubusercontent.com/myui/143fa9d05bd6e7db0114/raw/500f178316b802f1cade6e3bf8dc814a96e84b1e/iris.arff");
    InputStream is = new BufferedInputStream(url.openStream());

    ArffParser arffParser = new ArffParser();
    arffParser.setResponseIndex(4);

    AttributeDataset iris = arffParser.parse(is);
    double[][] x = iris.toArray(new double[iris.size()][]);
    int[] y = iris.toArray(new int[iris.size()]);

    int n = x.length;
    LOOCV loocv = new LOOCV(n);
    for (int i = 0; i < n; i++) {
        double[][] trainx = Math.slice(x, loocv.train[i]);
        int[] trainy = Math.slice(y, loocv.train[i]);

        RoaringBitmap attrs = SmileExtUtils.convertAttributeTypes(iris.attributes());
        DecisionTree tree = new DecisionTree(attrs,
            new RowMajorDenseMatrix2d(trainx, x[0].length), trainy, 4);
        assertEquals(tree.predict(x[loocv.test[i]]), evalPredict(tree, x[loocv.test[i]]));
    }
}
Example #5
Source File: DecisionTreeTest.java From incubator-hivemall with Apache License 2.0 | 6 votes |
private static String graphvizOutput(String datasetUrl, int responseIndex, int numLeafs,
        boolean dense, String[] featureNames, String[] classNames, String outputName)
        throws IOException, HiveException, ParseException {
    URL url = new URL(datasetUrl);
    InputStream is = new BufferedInputStream(url.openStream());

    ArffParser arffParser = new ArffParser();
    arffParser.setResponseIndex(responseIndex);

    AttributeDataset ds = arffParser.parse(is);
    double[][] x = ds.toArray(new double[ds.size()][]);
    int[] y = ds.toArray(new int[ds.size()]);

    RoaringBitmap attrs = SmileExtUtils.convertAttributeTypes(ds.attributes());
    DecisionTree tree = new DecisionTree(attrs, matrix(x, dense), y, numLeafs,
        RandomNumberGeneratorFactory.createPRNG(31));

    Text model = new Text(Base91.encode(tree.serialize(true)));

    Evaluator eval = new Evaluator(OutputType.graphviz, outputName, false);
    Text exported = eval.export(model, featureNames, classNames);

    return exported.toString();
}
Example #6
Source File: SmileRandomForest.java From kogito-runtimes with Apache License 2.0 | 6 votes |
public SmileRandomForest(Map<String, AttributeType> inputFeatures,
        String outputFeatureName,
        AttributeType outputFeatureType,
        double confidenceThreshold,
        int numberTrees) {
    super(inputFeatures, outputFeatureName, outputFeatureType, confidenceThreshold);
    this.numberTrees = numberTrees;
    smileAttributes = new HashMap<>();
    for (Entry<String, AttributeType> inputFeature : inputFeatures.entrySet()) {
        final String name = inputFeature.getKey();
        final AttributeType type = inputFeature.getValue();
        smileAttributes.put(name, createAttribute(name, type));
        attributeNames.add(name);
    }
    numAttributes = smileAttributes.size();
    outcomeAttribute = createAttribute(outputFeatureName, outputFeatureType);
    outcomeAttributeType = outputFeatureType;
    dataset = new AttributeDataset("dataset",
        smileAttributes.values().toArray(new Attribute[numAttributes]), outcomeAttribute);
}
Example #7
Source File: DecisionTreeTest.java From incubator-hivemall with Apache License 2.0 | 6 votes |
@Test
public void testIrisSerializedObj() throws IOException, ParseException, HiveException {
    URL url = new URL(
        "https://gist.githubusercontent.com/myui/143fa9d05bd6e7db0114/raw/500f178316b802f1cade6e3bf8dc814a96e84b1e/iris.arff");
    InputStream is = new BufferedInputStream(url.openStream());

    ArffParser arffParser = new ArffParser();
    arffParser.setResponseIndex(4);

    AttributeDataset iris = arffParser.parse(is);
    double[][] x = iris.toArray(new double[iris.size()][]);
    int[] y = iris.toArray(new int[iris.size()]);

    int n = x.length;
    LOOCV loocv = new LOOCV(n);
    for (int i = 0; i < n; i++) {
        double[][] trainx = Math.slice(x, loocv.train[i]);
        int[] trainy = Math.slice(y, loocv.train[i]);

        RoaringBitmap attrs = SmileExtUtils.convertAttributeTypes(iris.attributes());
        DecisionTree tree = new DecisionTree(attrs, matrix(trainx, true), trainy, 4);

        byte[] b = tree.serialize(false);
        Node node = DecisionTree.deserialize(b, b.length, false);
        assertEquals(tree.predict(x[loocv.test[i]]), node.predict(x[loocv.test[i]]));
    }
}
Example #8
Source File: DecisionTreeTest.java From incubator-hivemall with Apache License 2.0 | 5 votes |
private static void runAndCompareSparseAndDense(String datasetUrl, int responseIndex,
        int numLeafs) throws IOException, ParseException {
    URL url = new URL(datasetUrl);
    InputStream is = new BufferedInputStream(url.openStream());

    ArffParser arffParser = new ArffParser();
    arffParser.setResponseIndex(responseIndex);

    AttributeDataset ds = arffParser.parse(is);
    double[][] x = ds.toArray(new double[ds.size()][]);
    int[] y = ds.toArray(new int[ds.size()]);

    int n = x.length;
    LOOCV loocv = new LOOCV(n);
    for (int i = 0; i < n; i++) {
        double[][] trainx = Math.slice(x, loocv.train[i]);
        int[] trainy = Math.slice(y, loocv.train[i]);

        RoaringBitmap attrs = SmileExtUtils.convertAttributeTypes(ds.attributes());
        DecisionTree dtree = new DecisionTree(attrs, matrix(trainx, true), trainy, numLeafs,
            RandomNumberGeneratorFactory.createPRNG(i));
        DecisionTree stree = new DecisionTree(attrs, matrix(trainx, false), trainy, numLeafs,
            RandomNumberGeneratorFactory.createPRNG(i));

        Assert.assertEquals(dtree.predict(x[loocv.test[i]]), stree.predict(x[loocv.test[i]]));
        Assert.assertEquals(dtree.toString(), stree.toString());
    }
}
Example #9
Source File: TreePredictUDFv1Test.java From incubator-hivemall with Apache License 2.0 | 5 votes |
@Test
public void testCpu2() throws IOException, ParseException, HiveException {
    URL url = new URL(
        "https://gist.githubusercontent.com/myui/ef17aabecf0c0c5bcb69/raw/aac0575b4d43072c6f3c82d9072fdefb61892694/cpu.arff");
    InputStream is = new BufferedInputStream(url.openStream());

    ArffParser arffParser = new ArffParser();
    arffParser.setResponseIndex(6);

    AttributeDataset data = arffParser.parse(is);
    double[] datay = data.toArray(new double[data.size()]);
    double[][] datax = data.toArray(new double[data.size()][]);

    int n = datax.length;
    int m = 3 * n / 4;
    int[] index = Math.permutate(n);

    double[][] trainx = new double[m][];
    double[] trainy = new double[m];
    for (int i = 0; i < m; i++) {
        trainx[i] = datax[index[i]];
        trainy[i] = datay[index[i]];
    }

    double[][] testx = new double[n - m][];
    double[] testy = new double[n - m];
    for (int i = m; i < n; i++) {
        testx[i - m] = datax[index[i]];
        testy[i - m] = datay[index[i]];
    }

    RoaringBitmap attrs = SmileExtUtils.convertAttributeTypes(data.attributes());
    RegressionTree tree = new RegressionTree(attrs,
        new RowMajorDenseMatrix2d(trainx, trainx[0].length), trainy, 20);
    debugPrint(String.format("RMSE = %.4f\n", rmse(tree, testx, testy)));

    for (int i = m; i < n; i++) {
        assertEquals(tree.predict(testx[i - m]), evalPredict(tree, testx[i - m]), 1.0);
    }
}
Example #10
Source File: TreePredictUDFv1Test.java From incubator-hivemall with Apache License 2.0 | 5 votes |
@Test
public void testCpu() throws IOException, ParseException, HiveException {
    URL url = new URL(
        "https://gist.githubusercontent.com/myui/ef17aabecf0c0c5bcb69/raw/aac0575b4d43072c6f3c82d9072fdefb61892694/cpu.arff");
    InputStream is = new BufferedInputStream(url.openStream());

    ArffParser arffParser = new ArffParser();
    arffParser.setResponseIndex(6);

    AttributeDataset data = arffParser.parse(is);
    double[] datay = data.toArray(new double[data.size()]);
    double[][] datax = data.toArray(new double[data.size()][]);

    int n = datax.length;
    int k = 10;

    CrossValidation cv = new CrossValidation(n, k);
    for (int i = 0; i < k; i++) {
        double[][] trainx = Math.slice(datax, cv.train[i]);
        double[] trainy = Math.slice(datay, cv.train[i]);
        double[][] testx = Math.slice(datax, cv.test[i]);

        RoaringBitmap attrs = SmileExtUtils.convertAttributeTypes(data.attributes());
        RegressionTree tree = new RegressionTree(attrs,
            new RowMajorDenseMatrix2d(trainx, trainx[0].length), trainy, 20);

        for (int j = 0; j < testx.length; j++) {
            assertEquals(tree.predict(testx[j]), evalPredict(tree, testx[j]), 1.0);
        }
    }
}
Example #11
Source File: GradientTreeBoostingClassifierUDTFTest.java From incubator-hivemall with Apache License 2.0 | 5 votes |
@Test
public void testSerialization() throws HiveException, IOException, ParseException {
    URL url = new URL(
        "https://gist.githubusercontent.com/myui/143fa9d05bd6e7db0114/raw/500f178316b802f1cade6e3bf8dc814a96e84b1e/iris.arff");
    InputStream is = new BufferedInputStream(url.openStream());

    ArffParser arffParser = new ArffParser();
    arffParser.setResponseIndex(4);

    AttributeDataset iris = arffParser.parse(is);
    int size = iris.size();
    double[][] x = iris.toArray(new double[size][]);
    int[] y = iris.toArray(new int[size]);

    final Object[][] rows = new Object[size][2];
    for (int i = 0; i < size; i++) {
        double[] row = x[i];
        final List<String> xi = new ArrayList<String>(x[0].length);
        for (int j = 0; j < row.length; j++) {
            xi.add(j + ":" + row[j]);
        }
        rows[i][0] = xi;
        rows[i][1] = y[i];
    }

    TestUtils.testGenericUDTFSerialization(GradientTreeBoostingClassifierUDTF.class,
        new ObjectInspector[] {
                ObjectInspectorFactory.getStandardListObjectInspector(
                    PrimitiveObjectInspectorFactory.javaStringObjectInspector),
                PrimitiveObjectInspectorFactory.javaIntObjectInspector,
                ObjectInspectorUtils.getConstantObjectInspector(
                    PrimitiveObjectInspectorFactory.javaStringObjectInspector, "-trees 490")},
        rows);
}
Example #12
Source File: TreePredictUDFTest.java From incubator-hivemall with Apache License 2.0 | 5 votes |
@Test
public void testCpu() throws IOException, ParseException, HiveException {
    URL url = new URL(
        "https://gist.githubusercontent.com/myui/ef17aabecf0c0c5bcb69/raw/aac0575b4d43072c6f3c82d9072fdefb61892694/cpu.arff");
    InputStream is = new BufferedInputStream(url.openStream());

    ArffParser arffParser = new ArffParser();
    arffParser.setResponseIndex(6);

    AttributeDataset data = arffParser.parse(is);
    double[] datay = data.toArray(new double[data.size()]);
    double[][] datax = data.toArray(new double[data.size()][]);

    int n = datax.length;
    int k = 10;

    CrossValidation cv = new CrossValidation(n, k);
    for (int i = 0; i < k; i++) {
        double[][] trainx = Math.slice(datax, cv.train[i]);
        double[] trainy = Math.slice(datay, cv.train[i]);
        double[][] testx = Math.slice(datax, cv.test[i]);

        RoaringBitmap attrs = SmileExtUtils.convertAttributeTypes(data.attributes());
        RegressionTree tree = new RegressionTree(attrs,
            new RowMajorDenseMatrix2d(trainx, trainx[0].length), trainy, 20);

        for (int j = 0; j < testx.length; j++) {
            Assert.assertEquals(tree.predict(testx[j]), evalPredict(tree, testx[j]), 1.0);
        }
    }
}
Example #13
Source File: TreePredictUDFTest.java From incubator-hivemall with Apache License 2.0 | 5 votes |
@Test
public void testCpu2() throws IOException, ParseException, HiveException {
    URL url = new URL(
        "https://gist.githubusercontent.com/myui/ef17aabecf0c0c5bcb69/raw/aac0575b4d43072c6f3c82d9072fdefb61892694/cpu.arff");
    InputStream is = new BufferedInputStream(url.openStream());

    ArffParser arffParser = new ArffParser();
    arffParser.setResponseIndex(6);

    AttributeDataset data = arffParser.parse(is);
    double[] datay = data.toArray(new double[data.size()]);
    double[][] datax = data.toArray(new double[data.size()][]);

    int n = datax.length;
    int m = 3 * n / 4;
    int[] index = Math.permutate(n);

    double[][] trainx = new double[m][];
    double[] trainy = new double[m];
    for (int i = 0; i < m; i++) {
        trainx[i] = datax[index[i]];
        trainy[i] = datay[index[i]];
    }

    double[][] testx = new double[n - m][];
    double[] testy = new double[n - m];
    for (int i = m; i < n; i++) {
        testx[i - m] = datax[index[i]];
        testy[i - m] = datay[index[i]];
    }

    RoaringBitmap attrs = SmileExtUtils.convertAttributeTypes(data.attributes());
    RegressionTree tree = new RegressionTree(attrs,
        new RowMajorDenseMatrix2d(trainx, trainx[0].length), trainy, 20);
    debugPrint(String.format("RMSE = %.4f\n", rmse(tree, testx, testy)));

    for (int i = m; i < n; i++) {
        Assert.assertEquals(tree.predict(testx[i - m]), evalPredict(tree, testx[i - m]), 1.0);
    }
}
Example #14
Source File: DecisionTreeTest.java From incubator-hivemall with Apache License 2.0 | 5 votes |
@Test
public void testIrisSerializeObjCompressed() throws IOException, ParseException, HiveException {
    URL url = new URL(
        "https://gist.githubusercontent.com/myui/143fa9d05bd6e7db0114/raw/500f178316b802f1cade6e3bf8dc814a96e84b1e/iris.arff");
    InputStream is = new BufferedInputStream(url.openStream());

    ArffParser arffParser = new ArffParser();
    arffParser.setResponseIndex(4);

    AttributeDataset iris = arffParser.parse(is);
    double[][] x = iris.toArray(new double[iris.size()][]);
    int[] y = iris.toArray(new int[iris.size()]);

    int n = x.length;
    LOOCV loocv = new LOOCV(n);
    for (int i = 0; i < n; i++) {
        double[][] trainx = Math.slice(x, loocv.train[i]);
        int[] trainy = Math.slice(y, loocv.train[i]);

        RoaringBitmap attrs = SmileExtUtils.convertAttributeTypes(iris.attributes());
        DecisionTree tree = new DecisionTree(attrs, matrix(trainx, true), trainy, 4);

        byte[] b1 = tree.serialize(true);
        byte[] b2 = tree.serialize(false);
        Assert.assertTrue("b1.length = " + b1.length + ", b2.length = " + b2.length,
            b1.length < b2.length);

        Node node = DecisionTree.deserialize(b1, b1.length, true);
        assertEquals(tree.predict(x[loocv.test[i]]), node.predict(x[loocv.test[i]]));
    }
}
Example #15
Source File: DecisionTreeTest.java From incubator-hivemall with Apache License 2.0 | 5 votes |
private static int run(String datasetUrl, int responseIndex, int numLeafs, boolean dense)
        throws IOException, ParseException {
    URL url = new URL(datasetUrl);
    InputStream is = new BufferedInputStream(url.openStream());

    ArffParser arffParser = new ArffParser();
    arffParser.setResponseIndex(responseIndex);

    AttributeDataset ds = arffParser.parse(is);
    double[][] x = ds.toArray(new double[ds.size()][]);
    int[] y = ds.toArray(new int[ds.size()]);

    int n = x.length;
    LOOCV loocv = new LOOCV(n);
    int error = 0;
    for (int i = 0; i < n; i++) {
        double[][] trainx = Math.slice(x, loocv.train[i]);
        int[] trainy = Math.slice(y, loocv.train[i]);

        RoaringBitmap attrs = SmileExtUtils.convertAttributeTypes(ds.attributes());
        DecisionTree tree = new DecisionTree(attrs, matrix(trainx, dense), trainy, numLeafs,
            RandomNumberGeneratorFactory.createPRNG(i));

        if (y[loocv.test[i]] != tree.predict(x[loocv.test[i]])) {
            error++;
        }
    }

    debugPrint("Decision Tree error = " + error);
    return error;
}
Example #16
Source File: RandomForestClassifierUDTFTest.java From incubator-hivemall with Apache License 2.0 | 5 votes |
@Test
public void testSerialization() throws HiveException, IOException, ParseException {
    URL url = new URL(
        "https://gist.githubusercontent.com/myui/143fa9d05bd6e7db0114/raw/500f178316b802f1cade6e3bf8dc814a96e84b1e/iris.arff");
    InputStream is = new BufferedInputStream(url.openStream());

    ArffParser arffParser = new ArffParser();
    arffParser.setResponseIndex(4);

    AttributeDataset iris = arffParser.parse(is);
    int size = iris.size();
    double[][] x = iris.toArray(new double[size][]);
    int[] y = iris.toArray(new int[size]);

    final Object[][] rows = new Object[size][2];
    for (int i = 0; i < size; i++) {
        double[] row = x[i];
        final List<String> xi = new ArrayList<String>(x[0].length);
        for (int j = 0; j < row.length; j++) {
            xi.add(j + ":" + row[j]);
        }
        rows[i][0] = xi;
        rows[i][1] = y[i];
    }

    TestUtils.testGenericUDTFSerialization(RandomForestClassifierUDTF.class,
        new ObjectInspector[] {
                ObjectInspectorFactory.getStandardListObjectInspector(
                    PrimitiveObjectInspectorFactory.javaStringObjectInspector),
                PrimitiveObjectInspectorFactory.javaIntObjectInspector,
                ObjectInspectorUtils.getConstantObjectInspector(
                    PrimitiveObjectInspectorFactory.javaStringObjectInspector, "-trees 49")},
        rows);
}
Example #17
Source File: Test.java From java_in_examples with Apache License 2.0 | 5 votes |
private void test() throws IOException, ParseException {
    ArffParser arffParser = new ArffParser();
    arffParser.setResponseIndex(4);

    AttributeDataset weather = arffParser.parse(
        this.getClass().getResourceAsStream("/smile/data/weka/weather.nominal.arff"));
    double[][] x = weather.toArray(new double[weather.size()][]);
    int[] y = weather.toArray(new int[weather.size()]);
}
Example #18
Source File: LinkClassifierTrainer.java From ache with Apache License 2.0 | 5 votes |
private LNClassifier createLNClassifier(List<Sampler<LinkNeighborhood>> instances,
        List<String> classValues, Features bestFeatures) {
    AttributeDataset inputDataset = createSmileInput(instances, bestFeatures);
    SoftClassifier<double[]> classifier = trainClassifier(inputDataset, instances.size());
    String[] classValuesArray = (String[]) classValues.toArray(new String[classValues.size()]);
    LinkNeighborhoodWrapper wrapper = new LinkNeighborhoodWrapper(bestFeatures.features, stoplist);
    return new LNClassifier(classifier, wrapper, bestFeatures.features, classValuesArray);
}
Example #19
Source File: LinkClassifierTrainer.java From ache with Apache License 2.0 | 5 votes |
private AttributeDataset createSmileInput(List<Sampler<LinkNeighborhood>> instances,
        Features bestFeatures) {
    LinkNeighborhoodWrapper wrapper = new LinkNeighborhoodWrapper(this.stoplist);
    wrapper.setFeatures(bestFeatures.fieldWords);
    List<String> classValues = new ArrayList<>();
    for (int i = 0; i < instances.size(); i++) {
        classValues.add(String.valueOf(i));
    }
    return createDataset(instances, bestFeatures.features, classValues, wrapper);
}
Example #20
Source File: LinkClassifierTrainer.java From ache with Apache License 2.0 | 5 votes |
/**
 * Converts the input instances into an AttributeDataset object that can be used to train a
 * SMILE classifier.
 *
 * @param instances samples of each link class, indexed by level
 * @param features names of the features to extract
 * @param classValues the class labels
 * @param wrapper extracts a feature vector from each LinkNeighborhood
 * @return the assembled training dataset
 */
private AttributeDataset createDataset(List<Sampler<LinkNeighborhood>> instances,
        String[] features, List<String> classValues, LinkNeighborhoodWrapper wrapper) {
    List<Attribute> attributes = new ArrayList<>();
    for (String featureName : features) {
        NumericAttribute attribute = new NumericAttribute(featureName);
        attributes.add(attribute);
    }
    Attribute[] attributesArray = (Attribute[]) attributes.toArray(new Attribute[attributes.size()]);
    String[] classValuesArray = (String[]) classValues.toArray(new String[classValues.size()]);

    String description = "If link leads to relevant page or not.";
    Attribute response = new NominalAttribute("y", description, classValuesArray);
    AttributeDataset dataset = new AttributeDataset("link_classifier", attributesArray, response);

    for (int level = 0; level < instances.size(); level++) {
        Sampler<LinkNeighborhood> levelSamples = instances.get(level);
        for (LinkNeighborhood ln : levelSamples.getSamples()) {
            Instance instance;
            try {
                instance = wrapper.extractToInstance(ln, features);
            } catch (MalformedURLException e) {
                logger.warn("Failed to process instance: " + ln.getLink().toString(), e);
                continue;
            }
            double[] values = instance.getValues(); // the instance's feature vector
            int y = level; // the class we're trying to predict
            dataset.add(values, y);
        }
    }
    return dataset;
}
Example #21
Source File: SmileTargetClassifierBuilder.java From ache with Apache License 2.0 | 5 votes |
public static void trainModel(String trainingPath, String outputPath, String learner,
        int responseIndex, boolean skipCrossValidation) throws Exception {
    if (learner == null) {
        learner = "SVM";
    }
    System.out.println("Learning algorithm: " + learner);

    String modelFilePath = Paths.get(outputPath, "pageclassifier.model").toString();

    ArffParser arffParser = new ArffParser();
    arffParser.setResponseIndex(responseIndex);

    Path arffFilePath = Paths.get(trainingPath, "/smile_input.arff");
    FileInputStream fis = new FileInputStream(arffFilePath.toFile());
    System.out.println("Writting temporarily data file to: " + arffFilePath.toString());

    AttributeDataset trainingData = arffParser.parse(fis);
    double[][] x = trainingData.toArray(new double[trainingData.size()][]);
    int[] y = trainingData.toArray(new int[trainingData.size()]);

    SoftClassifier<double[]> finalModel = null;
    if (skipCrossValidation) {
        System.out.println("Starting model training on whole dataset...");
        finalModel = trainClassifierNoCV(learner, x, y);
    } else {
        System.out.println("Starting cross-validation...");
        finalModel = trainModelCV(learner, x, y);
    }

    System.out.println("Writing model to file: " + modelFilePath);
    SmileUtil.writeSmileClassifier(modelFilePath, finalModel);
}
Example #22
Source File: TreePredictUDFTest.java From incubator-hivemall with Apache License 2.0 | 4 votes |
@Test
public void testSerialization() throws HiveException, IOException, ParseException {
    URL url = new URL(
        "https://gist.githubusercontent.com/myui/ef17aabecf0c0c5bcb69/raw/aac0575b4d43072c6f3c82d9072fdefb61892694/cpu.arff");
    InputStream is = new BufferedInputStream(url.openStream());

    ArffParser arffParser = new ArffParser();
    arffParser.setResponseIndex(6);

    AttributeDataset data = arffParser.parse(is);
    double[] datay = data.toArray(new double[data.size()]);
    double[][] datax = data.toArray(new double[data.size()][]);

    int n = datax.length;
    int m = 3 * n / 4;
    int[] index = Math.permutate(n);

    double[][] trainx = new double[m][];
    double[] trainy = new double[m];
    for (int i = 0; i < m; i++) {
        trainx[i] = datax[index[i]];
        trainy[i] = datay[index[i]];
    }

    double[][] testx = new double[n - m][];
    double[] testy = new double[n - m];
    for (int i = m; i < n; i++) {
        testx[i - m] = datax[index[i]];
        testy[i - m] = datay[index[i]];
    }

    RoaringBitmap attrs = SmileExtUtils.convertAttributeTypes(data.attributes());
    RegressionTree tree = new RegressionTree(attrs,
        new RowMajorDenseMatrix2d(trainx, trainx[0].length), trainy, 20);

    byte[] b = tree.serialize(true);
    byte[] encoded = Base91.encode(b);
    Text model = new Text(encoded);

    TestUtils.testGenericUDFSerialization(TreePredictUDF.class,
        new ObjectInspector[] {PrimitiveObjectInspectorFactory.javaStringObjectInspector,
                PrimitiveObjectInspectorFactory.writableStringObjectInspector,
                ObjectInspectorFactory.getStandardListObjectInspector(
                    PrimitiveObjectInspectorFactory.javaDoubleObjectInspector),
                ObjectInspectorUtils.getConstantObjectInspector(
                    PrimitiveObjectInspectorFactory.javaBooleanObjectInspector, false)},
        new Object[] {"model_id#1", model, ArrayUtils.toList(testx[0])});
}
Example #23
Source File: GradientTreeBoostingClassifierUDTFTest.java From incubator-hivemall with Apache License 2.0 | 4 votes |
@Test
public void testIrisDense() throws IOException, ParseException, HiveException {
    URL url = new URL(
        "https://gist.githubusercontent.com/myui/143fa9d05bd6e7db0114/raw/500f178316b802f1cade6e3bf8dc814a96e84b1e/iris.arff");
    InputStream is = new BufferedInputStream(url.openStream());

    ArffParser arffParser = new ArffParser();
    arffParser.setResponseIndex(4);

    AttributeDataset iris = arffParser.parse(is);
    int size = iris.size();
    double[][] x = iris.toArray(new double[size][]);
    int[] y = iris.toArray(new int[size]);

    GradientTreeBoostingClassifierUDTF udtf = new GradientTreeBoostingClassifierUDTF();
    ObjectInspector param = ObjectInspectorUtils.getConstantObjectInspector(
        PrimitiveObjectInspectorFactory.javaStringObjectInspector, "-trees 490");
    udtf.initialize(new ObjectInspector[] {
            ObjectInspectorFactory.getStandardListObjectInspector(
                PrimitiveObjectInspectorFactory.javaDoubleObjectInspector),
            PrimitiveObjectInspectorFactory.javaIntObjectInspector, param});

    final List<Double> xi = new ArrayList<Double>(x[0].length);
    for (int i = 0; i < size; i++) {
        for (int j = 0; j < x[i].length; j++) {
            xi.add(j, x[i][j]);
        }
        udtf.process(new Object[] {xi, y[i]});
        xi.clear();
    }

    final MutableInt count = new MutableInt(0);
    Collector collector = new Collector() {
        public void collect(Object input) throws HiveException {
            count.addValue(1);
        }
    };

    udtf.setCollector(collector);
    udtf.close();

    Assert.assertEquals(490, count.getValue());
}
Example #24
Source File: TreePredictUDFv1Test.java From incubator-hivemall with Apache License 2.0 | 4 votes |
@Test
public void testSerialization() throws HiveException, IOException, ParseException {
    URL url = new URL(
        "https://gist.githubusercontent.com/myui/ef17aabecf0c0c5bcb69/raw/aac0575b4d43072c6f3c82d9072fdefb61892694/cpu.arff");
    InputStream is = new BufferedInputStream(url.openStream());

    ArffParser arffParser = new ArffParser();
    arffParser.setResponseIndex(6);

    AttributeDataset data = arffParser.parse(is);
    double[] datay = data.toArray(new double[data.size()]);
    double[][] datax = data.toArray(new double[data.size()][]);

    int n = datax.length;
    int m = 3 * n / 4;
    int[] index = Math.permutate(n);

    double[][] trainx = new double[m][];
    double[] trainy = new double[m];
    for (int i = 0; i < m; i++) {
        trainx[i] = datax[index[i]];
        trainy[i] = datay[index[i]];
    }

    double[][] testx = new double[n - m][];
    double[] testy = new double[n - m];
    for (int i = m; i < n; i++) {
        testx[i - m] = datax[index[i]];
        testy[i - m] = datay[index[i]];
    }

    RoaringBitmap attrs = SmileExtUtils.convertAttributeTypes(data.attributes());
    RegressionTree tree = new RegressionTree(attrs,
        new RowMajorDenseMatrix2d(trainx, trainx[0].length), trainy, 20);

    String opScript = tree.predictOpCodegen(StackMachine.SEP);

    TestUtils.testGenericUDFSerialization(TreePredictUDFv1.class,
        new ObjectInspector[] {PrimitiveObjectInspectorFactory.javaStringObjectInspector,
                PrimitiveObjectInspectorFactory.javaIntObjectInspector,
                PrimitiveObjectInspectorFactory.javaStringObjectInspector,
                ObjectInspectorFactory.getStandardListObjectInspector(
                    PrimitiveObjectInspectorFactory.javaDoubleObjectInspector),
                ObjectInspectorUtils.getConstantObjectInspector(
                    PrimitiveObjectInspectorFactory.javaBooleanObjectInspector, false)},
        new Object[] {"model_id#1", ModelType.opscode.getId(), opScript,
                ArrayUtils.toList(testx[0])});
}
Example #25
Source File: GradientTreeBoostingClassifierUDTFTest.java From incubator-hivemall with Apache License 2.0 | 4 votes |
@Test
public void testIrisSparse() throws IOException, ParseException, HiveException {
    URL url = new URL(
        "https://gist.githubusercontent.com/myui/143fa9d05bd6e7db0114/raw/500f178316b802f1cade6e3bf8dc814a96e84b1e/iris.arff");
    InputStream is = new BufferedInputStream(url.openStream());

    ArffParser arffParser = new ArffParser();
    arffParser.setResponseIndex(4);

    AttributeDataset iris = arffParser.parse(is);
    int size = iris.size();
    double[][] x = iris.toArray(new double[size][]);
    int[] y = iris.toArray(new int[size]);

    GradientTreeBoostingClassifierUDTF udtf = new GradientTreeBoostingClassifierUDTF();
    ObjectInspector param = ObjectInspectorUtils.getConstantObjectInspector(
        PrimitiveObjectInspectorFactory.javaStringObjectInspector, "-trees 490");
    udtf.initialize(new ObjectInspector[] {
            ObjectInspectorFactory.getStandardListObjectInspector(
                PrimitiveObjectInspectorFactory.javaStringObjectInspector),
            PrimitiveObjectInspectorFactory.javaIntObjectInspector, param});

    final List<String> xi = new ArrayList<String>(x[0].length);
    for (int i = 0; i < size; i++) {
        double[] row = x[i];
        for (int j = 0; j < row.length; j++) {
            xi.add(j + ":" + row[j]);
        }
        udtf.process(new Object[] {xi, y[i]});
        xi.clear();
    }

    final MutableInt count = new MutableInt(0);
    Collector collector = new Collector() {
        public void collect(Object input) throws HiveException {
            count.addValue(1);
        }
    };

    udtf.setCollector(collector);
    udtf.close();

    Assert.assertEquals(490, count.getValue());
}
Example #26
Source File: DecisionTreeTest.java From incubator-hivemall with Apache License 2.0 | 4 votes |
@Test
public void testTitanicPruning() throws IOException, ParseException {
    String datasetUrl =
        "https://gist.githubusercontent.com/myui/7cd82c443db84ba7e7add1523d0247a9/raw/f2d3e3051b0292577e8c01a1759edabaa95c5781/titanic_train.tsv";
    URL url = new URL(datasetUrl);
    InputStream is = new BufferedInputStream(url.openStream());

    DelimitedTextParser parser = new DelimitedTextParser();
    parser.setColumnNames(true);
    parser.setDelimiter(",");
    parser.setResponseIndex(new NominalAttribute("survived"), 0);

    AttributeDataset train = parser.parse("titanic train", is);
    double[][] x_ = train.toArray(new double[train.size()][]);
    int[] y = train.toArray(new int[train.size()]);

    // pclass, name, sex, age, sibsp, parch, ticket, fare, cabin, embarked
    // C,C,C,Q,Q,Q,C,Q,C,C
    RoaringBitmap nominalAttrs = new RoaringBitmap();
    nominalAttrs.add(0);
    nominalAttrs.add(1);
    nominalAttrs.add(2);
    nominalAttrs.add(6);
    nominalAttrs.add(8);
    nominalAttrs.add(9);

    int columns = x_[0].length;
    Matrix x = new RowMajorDenseMatrix2d(x_, columns);

    int numVars = (int) Math.ceil(Math.sqrt(columns));
    int maxDepth = Integer.MAX_VALUE;
    int maxLeafs = Integer.MAX_VALUE;
    int minSplits = 2;
    int minLeafSize = 1;
    int[] samples = null;
    PRNG rand = RandomNumberGeneratorFactory.createPRNG(43L);

    final String[] featureNames = new String[] {"pclass", "name", "sex", "age", "sibsp",
            "parch", "ticket", "fare", "cabin", "embarked"};
    final String[] classNames = new String[] {"yes", "no"};

    DecisionTree tree = new DecisionTree(nominalAttrs, x, y, numVars, maxDepth, maxLeafs,
            minSplits, minLeafSize, samples, SplitRule.GINI, rand) {
        @Override
        public String toString() {
            return predictJsCodegen(featureNames, classNames);
        }
    };
    tree.toString();
}
Example #27
Source File: DecisionTreeTest.java From incubator-hivemall with Apache License 2.0 | 4 votes |
private static void runTracePredict(String datasetUrl, int responseIndex, int numLeafs)
        throws IOException, ParseException {
    URL url = new URL(datasetUrl);
    InputStream is = new BufferedInputStream(url.openStream());

    ArffParser arffParser = new ArffParser();
    arffParser.setResponseIndex(responseIndex);

    AttributeDataset ds = arffParser.parse(is);
    final Attribute[] attrs = ds.attributes();
    final Attribute targetAttr = ds.response();

    double[][] x = ds.toArray(new double[ds.size()][]);
    int[] y = ds.toArray(new int[ds.size()]);

    Random rnd = new Random(43L);
    int numTrain = (int) (x.length * 0.7);
    int[] index = ArrayUtils.shuffle(MathUtils.permutation(x.length), rnd);
    int[] cvTrain = Arrays.copyOf(index, numTrain);
    int[] cvTest = Arrays.copyOfRange(index, numTrain, index.length);

    double[][] trainx = Math.slice(x, cvTrain);
    int[] trainy = Math.slice(y, cvTrain);
    double[][] testx = Math.slice(x, cvTest);

    DecisionTree tree = new DecisionTree(SmileExtUtils.convertAttributeTypes(attrs),
        matrix(trainx, false), trainy, numLeafs, RandomNumberGeneratorFactory.createPRNG(43L));

    final LinkedHashMap<String, Double> map = new LinkedHashMap<>();
    final StringBuilder buf = new StringBuilder();

    for (int i = 0; i < testx.length; i++) {
        final DenseVector test = new DenseVector(testx[i]);
        tree.predict(test, new PredictionHandler() {
            @Override
            public void visitBranch(Operator op, int splitFeatureIndex, double splitFeature,
                    double splitValue) {
                buf.append(attrs[splitFeatureIndex].name);
                buf.append(" [" + splitFeature + "] ");
                buf.append(op);
                buf.append(' ');
                buf.append(splitValue);
                buf.append('\n');
                map.put(attrs[splitFeatureIndex].name + " [" + splitFeature + "] " + op,
                    splitValue);
            }

            @Override
            public void visitLeaf(int output, double[] posteriori) {
                buf.append(targetAttr.toString(output));
            }
        });

        Assert.assertTrue(buf.length() > 0);
        Assert.assertFalse(map.isEmpty());

        StringUtils.clear(buf);
        map.clear();
    }
}
Example #28
Source File: RandomForestClassifierUDTFTest.java From incubator-hivemall with Apache License 2.0 | 4 votes |
private static DecisionTree.Node getDecisionTreeFromSparseInput(String urlString)
        throws IOException, ParseException, HiveException {
    URL url = new URL(urlString);
    InputStream is = new BufferedInputStream(url.openStream());

    ArffParser arffParser = new ArffParser();
    arffParser.setResponseIndex(4);

    AttributeDataset iris = arffParser.parse(is);
    int size = iris.size();
    double[][] x = iris.toArray(new double[size][]);
    int[] y = iris.toArray(new int[size]);

    RandomForestClassifierUDTF udtf = new RandomForestClassifierUDTF();
    ObjectInspector param = ObjectInspectorUtils.getConstantObjectInspector(
        PrimitiveObjectInspectorFactory.javaStringObjectInspector, "-trees 1 -seed 71");
    udtf.initialize(new ObjectInspector[] {
            ObjectInspectorFactory.getStandardListObjectInspector(
                PrimitiveObjectInspectorFactory.javaStringObjectInspector),
            PrimitiveObjectInspectorFactory.javaIntObjectInspector, param});

    final List<String> xi = new ArrayList<String>(x[0].length);
    for (int i = 0; i < size; i++) {
        final double[] row = x[i];
        for (int j = 0; j < row.length; j++) {
            xi.add(j + ":" + row[j]);
        }
        udtf.process(new Object[] {xi, y[i]});
        xi.clear();
    }

    final Text[] placeholder = new Text[1];
    Collector collector = new Collector() {
        public void collect(Object input) throws HiveException {
            Object[] forward = (Object[]) input;
            placeholder[0] = (Text) forward[2];
        }
    };

    udtf.setCollector(collector);
    udtf.close();

    Text modelTxt = placeholder[0];
    Assert.assertNotNull(modelTxt);

    byte[] b = Base91.decode(modelTxt.getBytes(), 0, modelTxt.getLength());
    DecisionTree.Node node = DecisionTree.deserialize(b, b.length, true);
    return node;
}
Example #29
Source File: RandomForestClassifierUDTFTest.java From incubator-hivemall with Apache License 2.0 | 4 votes |
private static DecisionTree.Node getDecisionTreeFromDenseInput(String urlString)
        throws IOException, ParseException, HiveException {
    URL url = new URL(urlString);
    InputStream is = new BufferedInputStream(url.openStream());

    ArffParser arffParser = new ArffParser();
    arffParser.setResponseIndex(4);

    AttributeDataset iris = arffParser.parse(is);
    int size = iris.size();
    double[][] x = iris.toArray(new double[size][]);
    int[] y = iris.toArray(new int[size]);

    RandomForestClassifierUDTF udtf = new RandomForestClassifierUDTF();
    ObjectInspector param = ObjectInspectorUtils.getConstantObjectInspector(
        PrimitiveObjectInspectorFactory.javaStringObjectInspector, "-trees 1 -seed 71");
    udtf.initialize(new ObjectInspector[] {
            ObjectInspectorFactory.getStandardListObjectInspector(
                PrimitiveObjectInspectorFactory.javaDoubleObjectInspector),
            PrimitiveObjectInspectorFactory.javaIntObjectInspector, param});

    final List<Double> xi = new ArrayList<Double>(x[0].length);
    for (int i = 0; i < size; i++) {
        for (int j = 0; j < x[i].length; j++) {
            xi.add(j, x[i][j]);
        }
        udtf.process(new Object[] {xi, y[i]});
        xi.clear();
    }

    final Text[] placeholder = new Text[1];
    Collector collector = new Collector() {
        public void collect(Object input) throws HiveException {
            Object[] forward = (Object[]) input;
            placeholder[0] = (Text) forward[2];
        }
    };

    udtf.setCollector(collector);
    udtf.close();

    Text modelTxt = placeholder[0];
    Assert.assertNotNull(modelTxt);

    byte[] b = Base91.decode(modelTxt.getBytes(), 0, modelTxt.getLength());
    DecisionTree.Node node = DecisionTree.deserialize(b, b.length, true);
    return node;
}
Example #30
Source File: RandomForestClassifierUDTFTest.java From incubator-hivemall with Apache License 2.0 | 4 votes |
@Test
public void testIrisSparse() throws IOException, ParseException, HiveException {
    URL url = new URL(
        "https://gist.githubusercontent.com/myui/143fa9d05bd6e7db0114/raw/500f178316b802f1cade6e3bf8dc814a96e84b1e/iris.arff");
    InputStream is = new BufferedInputStream(url.openStream());

    ArffParser arffParser = new ArffParser();
    arffParser.setResponseIndex(4);

    AttributeDataset iris = arffParser.parse(is);
    int size = iris.size();
    double[][] x = iris.toArray(new double[size][]);
    int[] y = iris.toArray(new int[size]);

    RandomForestClassifierUDTF udtf = new RandomForestClassifierUDTF();
    ObjectInspector param = ObjectInspectorUtils.getConstantObjectInspector(
        PrimitiveObjectInspectorFactory.javaStringObjectInspector, "-trees 49");
    udtf.initialize(new ObjectInspector[] {
            ObjectInspectorFactory.getStandardListObjectInspector(
                PrimitiveObjectInspectorFactory.javaStringObjectInspector),
            PrimitiveObjectInspectorFactory.javaIntObjectInspector, param});

    final List<String> xi = new ArrayList<String>(x[0].length);
    for (int i = 0; i < size; i++) {
        double[] row = x[i];
        for (int j = 0; j < row.length; j++) {
            xi.add(j + ":" + row[j]);
        }
        udtf.process(new Object[] {xi, y[i]});
        xi.clear();
    }

    final MutableInt count = new MutableInt(0);
    Collector collector = new Collector() {
        public void collect(Object input) throws HiveException {
            count.addValue(1);
        }
    };

    udtf.setCollector(collector);
    udtf.close();

    Assert.assertEquals(49, count.getValue());
}