smile.data.NominalAttribute Java Examples
The following examples show how to use
smile.data.NominalAttribute.
You can vote each example up or down, and you can open the original project or source file
by following the link above each example. Related API usage is listed in the sidebar.
Example #1
Source file: SmileRandomForest.java from kogito-runtimes (Apache License 2.0) | 5 votes
/**
 * Maps a generic attribute type onto the matching SMILE attribute implementation.
 *
 * <p>Nominal and boolean types become a {@link NominalAttribute}, the numeric type becomes a
 * {@link NumericAttribute}, and every other value (including {@code null}) falls back to a
 * {@link StringAttribute}.
 *
 * @param name the name given to the created attribute
 * @param type the source attribute type to translate
 * @return a new SMILE attribute named {@code name}
 */
protected Attribute createAttribute(String name, AttributeType type) {
    boolean categorical = type == AttributeType.NOMINAL || type == AttributeType.BOOLEAN;
    if (categorical) {
        return new NominalAttribute(name);
    }
    if (type == AttributeType.NUMERIC) {
        return new NumericAttribute(name);
    }
    // Anything unrecognised is treated as free text.
    return new StringAttribute(name);
}
Example #2
Source file: LinkClassifierTrainer.java from ache (Apache License 2.0) | 5 votes
/** * Converts the input instances into an AttributeDataset object that can be used to train a * SMILE classifier. * * @param attributes * @param instances * @param wrapper * @param dataset * @throws IOException */ private AttributeDataset createDataset(List<Sampler<LinkNeighborhood>> instances, String[] features, List<String> classValues, LinkNeighborhoodWrapper wrapper) { List<Attribute> attributes = new ArrayList<>(); for(String featureName : features) { NumericAttribute attribute = new NumericAttribute(featureName); attributes.add(attribute); } Attribute[] attributesArray = (Attribute[]) attributes.toArray(new Attribute[attributes.size()]); String[] classValuesArray = (String[]) classValues.toArray(new String[classValues.size()]); String description = "If link leads to relevant page or not."; Attribute response = new NominalAttribute("y", description, classValuesArray); AttributeDataset dataset = new AttributeDataset("link_classifier", attributesArray, response); for (int level = 0; level < instances.size(); level++) { Sampler<LinkNeighborhood> levelSamples = instances.get(level); for (LinkNeighborhood ln : levelSamples.getSamples()) { Instance instance; try { instance = wrapper.extractToInstance(ln, features); } catch (MalformedURLException e) { logger.warn("Failed to process intance: "+ln.getLink().toString(), e); continue; } double[] values = instance.getValues(); // the instance's feature vector int y = level; // the class we're trying to predict dataset.add(values, y); } } return dataset; }
Example #3
Source file: DecisionTreeTest.java from incubator-hivemall (Apache License 2.0) | 4 votes
@Test public void testTitanicPruning() throws IOException, ParseException { String datasetUrl = "https://gist.githubusercontent.com/myui/7cd82c443db84ba7e7add1523d0247a9/raw/f2d3e3051b0292577e8c01a1759edabaa95c5781/titanic_train.tsv"; URL url = new URL(datasetUrl); InputStream is = new BufferedInputStream(url.openStream()); DelimitedTextParser parser = new DelimitedTextParser(); parser.setColumnNames(true); parser.setDelimiter(","); parser.setResponseIndex(new NominalAttribute("survived"), 0); AttributeDataset train = parser.parse("titanic train", is); double[][] x_ = train.toArray(new double[train.size()][]); int[] y = train.toArray(new int[train.size()]); // pclass, name, sex, age, sibsp, parch, ticket, fare, cabin, embarked // C,C,C,Q,Q,Q,C,Q,C,C RoaringBitmap nominalAttrs = new RoaringBitmap(); nominalAttrs.add(0); nominalAttrs.add(1); nominalAttrs.add(2); nominalAttrs.add(6); nominalAttrs.add(8); nominalAttrs.add(9); int columns = x_[0].length; Matrix x = new RowMajorDenseMatrix2d(x_, columns); int numVars = (int) Math.ceil(Math.sqrt(columns)); int maxDepth = Integer.MAX_VALUE; int maxLeafs = Integer.MAX_VALUE; int minSplits = 2; int minLeafSize = 1; int[] samples = null; PRNG rand = RandomNumberGeneratorFactory.createPRNG(43L); final String[] featureNames = new String[] {"pclass", "name", "sex", "age", "sibsp", "parch", "ticket", "fare", "cabin", "embarked"}; final String[] classNames = new String[] {"yes", "no"}; DecisionTree tree = new DecisionTree(nominalAttrs, x, y, numVars, maxDepth, maxLeafs, minSplits, minLeafSize, samples, SplitRule.GINI, rand) { @Override public String toString() { return predictJsCodegen(featureNames, classNames); } }; tree.toString(); }