org.apache.spark.mllib.tree.RandomForest Java Examples

The following examples show how to use org.apache.spark.mllib.tree.RandomForest. You can vote up the ones you like or vote down the ones you don't like, and go to the original project or source file by following the links above each example. You may check out the related API usage on the sidebar.
Example #1
Source file: RandomForestMlib.java — from Java-Data-Science-Cookbook (MIT License), 4 votes
/**
 * Trains a random-forest classifier on a LibSVM-formatted file and prints its test
 * error together with the learned trees.
 *
 * <p>About 70% of {@code data/rf-data.txt} is used for training and the remainder
 * for evaluation. The Spark context is always stopped before returning, even when
 * loading, training or evaluation throws.
 *
 * @param args ignored
 */
public static void main(String[] args) {
	SparkConf configuration = new SparkConf().setMaster("local[4]").setAppName("Any");
	JavaSparkContext sc = new JavaSparkContext(configuration);
	try {
		// Load and parse the data file.
		String input = "data/rf-data.txt";
		JavaRDD<LabeledPoint> data = MLUtils.loadLibSVMFile(sc.sc(), input).toJavaRDD();
		// Split the data into training and test sets (30% held out for testing).
		JavaRDD<LabeledPoint>[] dataSplits = data.randomSplit(new double[]{0.7, 0.3});
		JavaRDD<LabeledPoint> trainingData = dataSplits[0];
		JavaRDD<LabeledPoint> testData = dataSplits[1];

		// Train a RandomForest model.
		Integer numClasses = 2;
		// Empty categoricalFeaturesInfo indicates all features are continuous.
		HashMap<Integer, Integer> categoricalFeaturesInfo = new HashMap<Integer, Integer>();
		Integer numTrees = 3; // Use more in practice.
		String featureSubsetStrategy = "auto"; // Let the algorithm choose.
		String impurity = "gini";
		Integer maxDepth = 5;
		Integer maxBins = 32;
		Integer seed = 12345;

		final RandomForestModel rfModel = RandomForest.trainClassifier(trainingData, numClasses,
				categoricalFeaturesInfo, numTrees, featureSubsetStrategy, impurity, maxDepth, maxBins,
				seed);

		// Pair each test point's prediction with its true label.
		JavaPairRDD<Double, Double> predictionAndLabel =
				testData.mapToPair(new PairFunction<LabeledPoint, Double, Double>() {
					@Override
					public Tuple2<Double, Double> call(LabeledPoint p) {
						return new Tuple2<Double, Double>(rfModel.predict(p.features()), p.label());
					}
				});

		// Fraction of test points whose prediction differs from the label.
		// Multiplying by 1.0 forces floating-point division; an empty test split gives NaN.
		Double testError =
				1.0 * predictionAndLabel.filter(new Function<Tuple2<Double, Double>, Boolean>() {
					@Override
					public Boolean call(Tuple2<Double, Double> pl) {
						return !pl._1().equals(pl._2());
					}
				}).count() / testData.count();

		System.out.println("Test Error: " + testError);
		System.out.println("Learned classification forest model:\n" + rfModel.toDebugString());
	} finally {
		// Release the local Spark context (executors, web UI) even if anything above failed.
		sc.stop();
	}
}
 
Example #2
Source file: JavaRandomForestRegressionExample.java — from SparkDemo (MIT License), 4 votes
/**
 * Spark MLlib example: trains a random-forest regressor (variance impurity) on the
 * sample LibSVM data, prints the test mean squared error and the learned forest,
 * then saves the model and loads it back.
 *
 * <p>The Spark context is stopped in a {@code finally} block, so it is released even
 * when loading, training, evaluation or saving throws.
 *
 * @param args command-line arguments (not used)
 */
public static void main(String[] args) {
  // $example on$
  SparkConf sparkConf = new SparkConf().setAppName("JavaRandomForestRegressionExample");
  JavaSparkContext jsc = new JavaSparkContext(sparkConf);
  try {
    // Load and parse the data file.
    String datapath = "data/mllib/sample_libsvm_data.txt";
    JavaRDD<LabeledPoint> data = MLUtils.loadLibSVMFile(jsc.sc(), datapath).toJavaRDD();
    // Split the data into training and test sets (30% held out for testing)
    JavaRDD<LabeledPoint>[] splits = data.randomSplit(new double[]{0.7, 0.3});
    JavaRDD<LabeledPoint> trainingData = splits[0];
    JavaRDD<LabeledPoint> testData = splits[1];

    // Set parameters.
    // Empty categoricalFeaturesInfo indicates all features are continuous.
    Map<Integer, Integer> categoricalFeaturesInfo = new HashMap<>();
    Integer numTrees = 3; // Use more in practice.
    String featureSubsetStrategy = "auto"; // Let the algorithm choose.
    String impurity = "variance";
    Integer maxDepth = 4;
    Integer maxBins = 32;
    Integer seed = 12345;
    // Train a RandomForest model.
    final RandomForestModel model = RandomForest.trainRegressor(trainingData,
      categoricalFeaturesInfo, numTrees, featureSubsetStrategy, impurity, maxDepth, maxBins, seed);

    // Evaluate model on test instances and compute test error
    JavaPairRDD<Double, Double> predictionAndLabel =
      testData.mapToPair(new PairFunction<LabeledPoint, Double, Double>() {
        @Override
        public Tuple2<Double, Double> call(LabeledPoint p) {
          return new Tuple2<>(model.predict(p.features()), p.label());
        }
      });
    // Sum the squared errors with fold rather than reduce: reduce throws on an empty RDD
    // (the random split may leave the test set empty), while fold starts from 0.0, so an
    // empty test set yields NaN (0.0 / 0) instead of an exception.
    Double testMSE =
      predictionAndLabel.map(new Function<Tuple2<Double, Double>, Double>() {
        @Override
        public Double call(Tuple2<Double, Double> pl) {
          double diff = pl._1() - pl._2();
          return diff * diff;
        }
      }).fold(0.0, new Function2<Double, Double, Double>() {
        @Override
        public Double call(Double a, Double b) {
          return a + b;
        }
      }) / testData.count();
    System.out.println("Test Mean Squared Error: " + testMSE);
    System.out.println("Learned regression forest model:\n" + model.toDebugString());

    // Save and load model
    model.save(jsc.sc(), "target/tmp/myRandomForestRegressionModel");
    RandomForestModel sameModel = RandomForestModel.load(jsc.sc(),
      "target/tmp/myRandomForestRegressionModel");
    // $example off$
  } finally {
    // Always release the Spark context, including on failure.
    jsc.stop();
  }
}
 
Example #3
Source file: JavaRandomForestClassificationExample.java — from SparkDemo (MIT License), 4 votes
/**
 * Spark MLlib example: fits a three-tree random-forest classifier (Gini impurity) to
 * the bundled LibSVM sample data, prints the held-out misclassification rate and the
 * forest itself, then round-trips the model through save/load and stops Spark.
 *
 * @param args command-line arguments (not used)
 */
public static void main(String[] args) {
  // $example on$
  JavaSparkContext jsc =
    new JavaSparkContext(new SparkConf().setAppName("JavaRandomForestClassificationExample"));

  // Read the LibSVM file and carve off roughly 30% of it as a test set.
  JavaRDD<LabeledPoint> data =
    MLUtils.loadLibSVMFile(jsc.sc(), "data/mllib/sample_libsvm_data.txt").toJavaRDD();
  JavaRDD<LabeledPoint>[] parts = data.randomSplit(new double[]{0.7, 0.3});
  JavaRDD<LabeledPoint> train = parts[0];
  JavaRDD<LabeledPoint> test = parts[1];

  // No categorical entries: every feature is treated as continuous.
  HashMap<Integer, Integer> categoricalFeaturesInfo = new HashMap<>();
  final RandomForestModel model = RandomForest.trainClassifier(
    train,
    2,                        // numClasses
    categoricalFeaturesInfo,
    3,                        // numTrees -- use more in practice
    "auto",                   // featureSubsetStrategy: let the algorithm choose
    "gini",                   // impurity
    5,                        // maxDepth
    32,                       // maxBins
    12345);                   // seed

  // (prediction, label) for every held-out point.
  JavaPairRDD<Double, Double> predictionAndLabel =
    test.mapToPair(p -> new Tuple2<>(model.predict(p.features()), p.label()));

  // Misclassification rate = mismatching pairs / test-set size.
  long misclassified = predictionAndLabel.filter(pl -> !pl._1().equals(pl._2())).count();
  Double testErr = 1.0 * misclassified / test.count();
  System.out.println("Test Error: " + testErr);
  System.out.println("Learned classification forest model:\n" + model.toDebugString());

  // Persist the forest, then read it back from the same location.
  model.save(jsc.sc(), "target/tmp/myRandomForestClassificationModel");
  RandomForestModel sameModel =
    RandomForestModel.load(jsc.sc(), "target/tmp/myRandomForestClassificationModel");
  // $example off$

  jsc.stop();
}
 
Example #4
Source file: RDFUpdate.java — from oryx (Apache License 2.0), 4 votes
/**
 * Trains a random-forest model on the parsed training lines and converts it to PMML.
 *
 * <p>Hyperparameters are read positionally from {@code hyperParameters}: index 0 is the
 * maximum number of split candidates (passed to MLlib as maxBins), index 1 the maximum
 * tree depth, and index 2 the impurity name. A classifier is trained when
 * {@code inputSchema.isClassification()} is true; otherwise a regressor is trained.
 *
 * @throws IllegalArgumentException if fewer than 2 split candidates or a depth below 1
 *     is requested (via {@code Preconditions.checkArgument})
 */
@Override
public PMML buildModel(JavaSparkContext sparkContext,
                       JavaRDD<String> trainData,
                       List<?> hyperParameters,
                       Path candidatePath) {

  int splitCandidates = (Integer) hyperParameters.get(0);
  int depthLimit = (Integer) hyperParameters.get(1);
  String impurityName = (String) hyperParameters.get(2);
  Preconditions.checkArgument(splitCandidates >= 2, "max-split-candidates must be at least 2");
  Preconditions.checkArgument(depthLimit > 0, "max-depth must be at least 1");

  // Tokenize the raw lines, learn categorical encodings from their distinct values,
  // and encode every row as a LabeledPoint.
  JavaRDD<String[]> tokenized = trainData.map(MLFunctions.PARSE_FN);
  CategoricalValueEncodings encodings =
      new CategoricalValueEncodings(getDistinctValues(tokenized));
  // NOTE(review): points is traversed by training and by both example-count helpers
  // below without being cached; consider caching if recomputation is costly.
  JavaRDD<LabeledPoint> points = parseToLabeledPointRDD(tokenized, encodings);

  // Category counts keyed by feature index; the target's count is dropped because it
  // must not be passed as predictor metadata.
  // NOTE(review): remove() mutates whatever map getCategoryCounts() returns; this is
  // only safe if it returns a fresh copy -- confirm in CategoricalValueEncodings.
  Map<Integer,Integer> countsByFeature = encodings.getCategoryCounts();
  countsByFeature.remove(inputSchema.getTargetFeatureIndex());
  // MLlib needs the counts keyed by predictor index, not feature index.
  Map<Integer,Integer> countsByPredictor = new HashMap<>(countsByFeature.size());
  for (Map.Entry<Integer,Integer> entry : countsByFeature.entrySet()) {
    countsByPredictor.put(inputSchema.featureToPredictorIndex(entry.getKey()), entry.getValue());
  }

  int seed = RandomManager.getRandom().nextInt();

  RandomForestModel model = inputSchema.isClassification()
      ? RandomForest.trainClassifier(points,
                                     encodings.getValueCount(inputSchema.getTargetFeatureIndex()),
                                     countsByPredictor,
                                     numTrees,
                                     "auto",
                                     impurityName,
                                     depthLimit,
                                     splitCandidates,
                                     seed)
      : RandomForest.trainRegressor(points,
                                    countsByPredictor,
                                    numTrees,
                                    "auto",
                                    impurityName,
                                    depthLimit,
                                    splitCandidates,
                                    seed);

  // Example counts per tree node and per predictor are passed to the PMML conversion.
  List<IntLongHashMap> nodeCounts = treeNodeExampleCounts(points, model);
  IntLongHashMap predictorCounts = predictorExampleCounts(points, model);

  return rdfModelToPMML(model,
                        encodings,
                        depthLimit,
                        splitCandidates,
                        impurityName,
                        nodeCounts,
                        predictorCounts);
}