org.apache.spark.mllib.tree.model.RandomForestModel Java Examples

The following examples show how to use org.apache.spark.mllib.tree.model.RandomForestModel. Each example notes its source file, originating project, and license.
Example #1
Source File: RDFUpdate.java    From oryx with Apache License 2.0
/**
 * @param trainPointData data to run down trees
 * @param model random decision forest model to count on
 * @return map of predictor index to the number of training examples that reached a
 *  node whose decision is based on that feature. The index is among predictors, not all
 *  features, since there are fewer predictors than features. That is, the index will
 *  match the one used in the {@link RandomForestModel}.
 */
private static IntLongHashMap predictorExampleCounts(JavaRDD<? extends LabeledPoint> trainPointData,
                                                     RandomForestModel model) {
  return trainPointData.mapPartitions(data -> {
      IntLongHashMap featureIndexCount = new IntLongHashMap();
      data.forEachRemaining(datum -> {
        double[] featureVector = datum.features().toArray();
        for (DecisionTreeModel tree : model.trees()) {
          org.apache.spark.mllib.tree.model.Node node = tree.topNode();
          // This logic cloned from Node.predict:
          while (!node.isLeaf()) {
            Split split = node.split().get();
            int featureIndex = split.feature();
            // Count feature
            featureIndexCount.addToValue(featureIndex, 1);
            node = nextNode(featureVector, node, split, featureIndex);
          }
        }
      });
      return Collections.singleton(featureIndexCount).iterator();
  }).reduce(RDFUpdate::merge);
}
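
The nextNode helper called above is not shown in this snippet. A minimal sketch, assuming MLlib's Split/Node API (with FeatureType from org.apache.spark.mllib.tree.configuration) and mirroring the descent logic of Node.predict:

// Sketch only: descends one level of a tree, following the same split-evaluation
// logic as org.apache.spark.mllib.tree.model.Node.predict.
private static org.apache.spark.mllib.tree.model.Node nextNode(double[] featureVector,
                                                               org.apache.spark.mllib.tree.model.Node node,
                                                               Split split,
                                                               int featureIndex) {
  double featureValue = featureVector[featureIndex];
  if (split.featureType().equals(FeatureType.Continuous())) {
    // Continuous split: go left iff the value is at or below the threshold
    return featureValue <= split.threshold() ? node.leftNode().get() : node.rightNode().get();
  }
  // Categorical split: go left iff the value's category is in the left branch's category set
  return split.categories().contains(featureValue) ? node.leftNode().get() : node.rightNode().get();
}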
 
Example #2
Source File: RDFUpdate.java    From oryx with Apache License 2.0
/**
 * @param trainPointData data to run down trees
 * @param model random decision forest model to count on
 * @return maps of node IDs to the count of training examples that reached that node, one
 *  per tree in the model
 * @see #predictorExampleCounts(JavaRDD,RandomForestModel)
 */
private static List<IntLongHashMap> treeNodeExampleCounts(JavaRDD<? extends LabeledPoint> trainPointData,
                                                          RandomForestModel model) {
  return trainPointData.mapPartitions(data -> {
      DecisionTreeModel[] trees = model.trees();
      List<IntLongHashMap> treeNodeIDCounts = IntStream.range(0, trees.length).
          mapToObj(i -> new IntLongHashMap()).collect(Collectors.toList());
      data.forEachRemaining(datum -> {
        double[] featureVector = datum.features().toArray();
        for (int i = 0; i < trees.length; i++) {
          DecisionTreeModel tree = trees[i];
          IntLongHashMap nodeIDCount = treeNodeIDCounts.get(i);
          org.apache.spark.mllib.tree.model.Node node = tree.topNode();
          // This logic cloned from Node.predict:
          while (!node.isLeaf()) {
            // Count node ID
            nodeIDCount.addToValue(node.id(), 1);
            Split split = node.split().get();
            int featureIndex = split.feature();
            node = nextNode(featureVector, node, split, featureIndex);
          }
          nodeIDCount.addToValue(node.id(), 1);
        }
      });
      return Collections.singleton(treeNodeIDCounts).iterator();
    }
  ).reduce((a, b) -> {
      Preconditions.checkArgument(a.size() == b.size());
      for (int i = 0; i < a.size(); i++) {
        merge(a.get(i), b.get(i));
      }
      return a;
    });
}
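
The merge helper used in both reduce steps is likewise not shown. A plausible sketch, assuming IntLongHashMap is Eclipse Collections' primitive map (consistent with the addToValue calls above):

// Sketch only (an assumption, not from the source): folds every key/count pair
// of one map into the other and returns the combined map.
private static IntLongHashMap merge(IntLongHashMap into, IntLongHashMap from) {
  from.forEachKeyValue(into::addToValue);
  return into;
}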
 
Example #3
Source File: RandomForestMlib.java    From Java-Data-Science-Cookbook with MIT License
public static void main(String[] args) {

  SparkConf configuration = new SparkConf().setMaster("local[4]").setAppName("Any");
  JavaSparkContext sc = new JavaSparkContext(configuration);

  // Load and parse the data file.
  String input = "data/rf-data.txt";
  JavaRDD<LabeledPoint> data = MLUtils.loadLibSVMFile(sc.sc(), input).toJavaRDD();
  // Split the data into training and test sets (30% held out for testing)
  JavaRDD<LabeledPoint>[] dataSplits = data.randomSplit(new double[]{0.7, 0.3});
  JavaRDD<LabeledPoint> trainingData = dataSplits[0];
  JavaRDD<LabeledPoint> testData = dataSplits[1];

  // Train a RandomForest model.
  Integer numClasses = 2;
  // Empty categoricalFeaturesInfo indicates all features are continuous.
  HashMap<Integer, Integer> categoricalFeaturesInfo = new HashMap<Integer, Integer>();
  Integer numTrees = 3; // Use more in practice.
  String featureSubsetStrategy = "auto"; // Let the algorithm choose.
  String impurity = "gini";
  Integer maxDepth = 5;
  Integer maxBins = 32;
  Integer seed = 12345;

  final RandomForestModel rfModel = RandomForest.trainClassifier(trainingData, numClasses,
      categoricalFeaturesInfo, numTrees, featureSubsetStrategy, impurity, maxDepth, maxBins,
      seed);

  // Evaluate model on test instances and compute test error
  JavaPairRDD<Double, Double> predictionAndLabel =
      testData.mapToPair(new PairFunction<LabeledPoint, Double, Double>() {
        public Tuple2<Double, Double> call(LabeledPoint p) {
          return new Tuple2<Double, Double>(rfModel.predict(p.features()), p.label());
        }
      });

  Double testError =
      1.0 * predictionAndLabel.filter(new Function<Tuple2<Double, Double>, Boolean>() {
        public Boolean call(Tuple2<Double, Double> pl) {
          return !pl._1().equals(pl._2());
        }
      }).count() / testData.count();

  System.out.println("Test Error: " + testError);
  System.out.println("Learned classification forest model:\n" + rfModel.toDebugString());

  sc.stop();
}
 
Example #4
Source File: JavaRandomForestRegressionExample.java    From SparkDemo with MIT License
public static void main(String[] args) {
  // $example on$
  SparkConf sparkConf = new SparkConf().setAppName("JavaRandomForestRegressionExample");
  JavaSparkContext jsc = new JavaSparkContext(sparkConf);
  // Load and parse the data file.
  String datapath = "data/mllib/sample_libsvm_data.txt";
  JavaRDD<LabeledPoint> data = MLUtils.loadLibSVMFile(jsc.sc(), datapath).toJavaRDD();
  // Split the data into training and test sets (30% held out for testing)
  JavaRDD<LabeledPoint>[] splits = data.randomSplit(new double[]{0.7, 0.3});
  JavaRDD<LabeledPoint> trainingData = splits[0];
  JavaRDD<LabeledPoint> testData = splits[1];

  // Set parameters.
  // Empty categoricalFeaturesInfo indicates all features are continuous.
  Map<Integer, Integer> categoricalFeaturesInfo = new HashMap<>();
  Integer numTrees = 3; // Use more in practice.
  String featureSubsetStrategy = "auto"; // Let the algorithm choose.
  String impurity = "variance";
  Integer maxDepth = 4;
  Integer maxBins = 32;
  Integer seed = 12345;
  // Train a RandomForest model.
  final RandomForestModel model = RandomForest.trainRegressor(trainingData,
    categoricalFeaturesInfo, numTrees, featureSubsetStrategy, impurity, maxDepth, maxBins, seed);

  // Evaluate model on test instances and compute test error
  JavaPairRDD<Double, Double> predictionAndLabel =
    testData.mapToPair(new PairFunction<LabeledPoint, Double, Double>() {
      @Override
      public Tuple2<Double, Double> call(LabeledPoint p) {
        return new Tuple2<>(model.predict(p.features()), p.label());
      }
    });
  Double testMSE =
    predictionAndLabel.map(new Function<Tuple2<Double, Double>, Double>() {
      @Override
      public Double call(Tuple2<Double, Double> pl) {
        Double diff = pl._1() - pl._2();
        return diff * diff;
      }
    }).reduce(new Function2<Double, Double, Double>() {
      @Override
      public Double call(Double a, Double b) {
        return a + b;
      }
    }) / testData.count();
  System.out.println("Test Mean Squared Error: " + testMSE);
  System.out.println("Learned regression forest model:\n" + model.toDebugString());

  // Save and load model
  model.save(jsc.sc(), "target/tmp/myRandomForestRegressionModel");
  RandomForestModel sameModel = RandomForestModel.load(jsc.sc(),
    "target/tmp/myRandomForestRegressionModel");
  // $example off$

  jsc.stop();
}
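
The anonymous inner classes above predate Java 8 lambdas. Since the sum-then-divide above is just the mean of the squared errors, the same test MSE can be sketched more compactly (behavior unchanged, assuming Java 8+):

// Test MSE as the mean of squared prediction errors over the test set.
double testMSE = testData
    .mapToDouble(p -> {
      double diff = model.predict(p.features()) - p.label();
      return diff * diff;
    })
    .mean();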
 
Example #5
Source File: JavaRandomForestClassificationExample.java    From SparkDemo with MIT License
public static void main(String[] args) {
  // $example on$
  SparkConf sparkConf = new SparkConf().setAppName("JavaRandomForestClassificationExample");
  JavaSparkContext jsc = new JavaSparkContext(sparkConf);
  // Load and parse the data file.
  String datapath = "data/mllib/sample_libsvm_data.txt";
  JavaRDD<LabeledPoint> data = MLUtils.loadLibSVMFile(jsc.sc(), datapath).toJavaRDD();
  // Split the data into training and test sets (30% held out for testing)
  JavaRDD<LabeledPoint>[] splits = data.randomSplit(new double[]{0.7, 0.3});
  JavaRDD<LabeledPoint> trainingData = splits[0];
  JavaRDD<LabeledPoint> testData = splits[1];

  // Train a RandomForest model.
  // Empty categoricalFeaturesInfo indicates all features are continuous.
  Integer numClasses = 2;
  HashMap<Integer, Integer> categoricalFeaturesInfo = new HashMap<>();
  Integer numTrees = 3; // Use more in practice.
  String featureSubsetStrategy = "auto"; // Let the algorithm choose.
  String impurity = "gini";
  Integer maxDepth = 5;
  Integer maxBins = 32;
  Integer seed = 12345;

  final RandomForestModel model = RandomForest.trainClassifier(trainingData, numClasses,
    categoricalFeaturesInfo, numTrees, featureSubsetStrategy, impurity, maxDepth, maxBins,
    seed);

  // Evaluate model on test instances and compute test error
  JavaPairRDD<Double, Double> predictionAndLabel =
    testData.mapToPair(new PairFunction<LabeledPoint, Double, Double>() {
      @Override
      public Tuple2<Double, Double> call(LabeledPoint p) {
        return new Tuple2<>(model.predict(p.features()), p.label());
      }
    });
  Double testErr =
    1.0 * predictionAndLabel.filter(new Function<Tuple2<Double, Double>, Boolean>() {
      @Override
      public Boolean call(Tuple2<Double, Double> pl) {
        return !pl._1().equals(pl._2());
      }
    }).count() / testData.count();
  System.out.println("Test Error: " + testErr);
  System.out.println("Learned classification forest model:\n" + model.toDebugString());

  // Save and load model
  model.save(jsc.sc(), "target/tmp/myRandomForestClassificationModel");
  RandomForestModel sameModel = RandomForestModel.load(jsc.sc(),
    "target/tmp/myRandomForestClassificationModel");
  // $example off$

  jsc.stop();
}
 
Example #6
Source File: RDFUpdate.java    From oryx with Apache License 2.0
@Override
public PMML buildModel(JavaSparkContext sparkContext,
                       JavaRDD<String> trainData,
                       List<?> hyperParameters,
                       Path candidatePath) {

  int maxSplitCandidates = (Integer) hyperParameters.get(0);
  int maxDepth = (Integer) hyperParameters.get(1);
  String impurity = (String) hyperParameters.get(2);
  Preconditions.checkArgument(maxSplitCandidates >= 2,
                              "max-split-candidates must be at least 2");
  Preconditions.checkArgument(maxDepth > 0,
                              "max-depth must be at least 1");

  JavaRDD<String[]> parsedRDD = trainData.map(MLFunctions.PARSE_FN);
  CategoricalValueEncodings categoricalValueEncodings =
      new CategoricalValueEncodings(getDistinctValues(parsedRDD));
  JavaRDD<LabeledPoint> trainPointData =
      parseToLabeledPointRDD(parsedRDD, categoricalValueEncodings);

  Map<Integer,Integer> categoryInfo = categoricalValueEncodings.getCategoryCounts();
  categoryInfo.remove(inputSchema.getTargetFeatureIndex()); // Don't specify target count
  // Need to translate indices to predictor indices
  Map<Integer,Integer> categoryInfoByPredictor = new HashMap<>(categoryInfo.size());
  categoryInfo.forEach((k, v) -> categoryInfoByPredictor.put(inputSchema.featureToPredictorIndex(k), v));

  int seed = RandomManager.getRandom().nextInt();

  RandomForestModel model;
  if (inputSchema.isClassification()) {
    int numTargetClasses =
        categoricalValueEncodings.getValueCount(inputSchema.getTargetFeatureIndex());
    model = RandomForest.trainClassifier(trainPointData,
                                         numTargetClasses,
                                         categoryInfoByPredictor,
                                         numTrees,
                                         "auto",
                                         impurity,
                                         maxDepth,
                                         maxSplitCandidates,
                                         seed);
  } else {
    model = RandomForest.trainRegressor(trainPointData,
                                        categoryInfoByPredictor,
                                        numTrees,
                                        "auto",
                                        impurity,
                                        maxDepth,
                                        maxSplitCandidates,
                                        seed);
  }

  List<IntLongHashMap> treeNodeIDCounts = treeNodeExampleCounts(trainPointData, model);
  IntLongHashMap predictorIndexCounts = predictorExampleCounts(trainPointData, model);

  return rdfModelToPMML(model,
                        categoricalValueEncodings,
                        maxDepth,
                        maxSplitCandidates,
                        impurity,
                        treeNodeIDCounts,
                        predictorIndexCounts);
}
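
Note that the hyperparameters are read positionally from the list, so callers must supply them in exactly that order. A hypothetical invocation (the values are illustrative only, not from the source):

// index 0 = max-split-candidates, 1 = max-depth, 2 = impurity,
// matching the positional hyperParameters.get(...) reads in buildModel.
PMML pmml = buildModel(sparkContext, trainData, Arrays.asList(32, 10, "gini"), candidatePath);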
 
Example #7
Source File: RDFUpdate.java    From oryx with Apache License 2.0
private PMML rdfModelToPMML(RandomForestModel rfModel,
                            CategoricalValueEncodings categoricalValueEncodings,
                            int maxDepth,
                            int maxSplitCandidates,
                            String impurity,
                            List<? extends IntLongMap> nodeIDCounts,
                            IntLongMap predictorIndexCounts) {

  boolean classificationTask = rfModel.algo().equals(Algo.Classification());
  Preconditions.checkState(classificationTask == inputSchema.isClassification());

  DecisionTreeModel[] trees = rfModel.trees();

  Model model;
  if (trees.length == 1) {
    model = toTreeModel(trees[0], categoricalValueEncodings, nodeIDCounts.get(0));
  } else {
    MiningModel miningModel = new MiningModel();
    model = miningModel;
    Segmentation.MultipleModelMethod multipleModelMethodType = classificationTask ?
        Segmentation.MultipleModelMethod.WEIGHTED_MAJORITY_VOTE :
        Segmentation.MultipleModelMethod.WEIGHTED_AVERAGE;
    List<Segment> segments = new ArrayList<>(trees.length);
    for (int treeID = 0; treeID < trees.length; treeID++) {
      TreeModel treeModel =
          toTreeModel(trees[treeID], categoricalValueEncodings, nodeIDCounts.get(treeID));
      segments.add(new Segment()
           .setId(Integer.toString(treeID))
           .setPredicate(new True())
           .setModel(treeModel)
           .setWeight(1.0)); // No weights in MLlib impl now
    }
    miningModel.setSegmentation(new Segmentation(multipleModelMethodType, segments));
  }

  model.setMiningFunction(classificationTask ?
                          MiningFunction.CLASSIFICATION :
                          MiningFunction.REGRESSION);

  double[] importances = countsToImportances(predictorIndexCounts);
  model.setMiningSchema(AppPMMLUtils.buildMiningSchema(inputSchema, importances));
  DataDictionary dictionary =
      AppPMMLUtils.buildDataDictionary(inputSchema, categoricalValueEncodings);

  PMML pmml = PMMLUtils.buildSkeletonPMML();
  pmml.setDataDictionary(dictionary);
  pmml.addModels(model);

  AppPMMLUtils.addExtension(pmml, "maxDepth", maxDepth);
  AppPMMLUtils.addExtension(pmml, "maxSplitCandidates", maxSplitCandidates);
  AppPMMLUtils.addExtension(pmml, "impurity", impurity);

  return pmml;
}
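
The countsToImportances helper is not shown here. A plausible sketch, assuming the predictor count is available from inputSchema (getNumPredictors() is a hypothetical accessor) and that an importance is simply a predictor's share of the total example count:

// Sketch only: normalize per-predictor example counts into importances summing to 1.
private double[] countsToImportances(IntLongMap predictorIndexCounts) {
  double[] importances = new double[inputSchema.getNumPredictors()]; // hypothetical accessor
  long total = predictorIndexCounts.sum();
  if (total > 0) {
    predictorIndexCounts.forEachKeyValue((predictorIndex, count) ->
        importances[predictorIndex] = (double) count / total);
  }
  return importances;
}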