Java Code Examples for org.apache.spark.api.java.JavaRDD#cache()
The following examples show how to use org.apache.spark.api.java.JavaRDD#cache().
They are drawn from open-source projects and sorted by user votes; the source file, originating project, and license are noted above each example.
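Before the project examples, here is a minimal, self-contained sketch of the basic cache() pattern: mark an RDD for persistence when more than one action will reuse it, then unpersist it when finished. The input path and the line-counting logic are illustrative assumptions and are not taken from any of the projects below.

import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;

public class JavaRDDCacheSketch {
    public static void main(String[] args) {
        SparkConf conf = new SparkConf().setAppName("JavaRDDCacheSketch").setMaster("local[*]");
        JavaSparkContext jsc = new JavaSparkContext(conf);

        // Hypothetical input file; cache() marks the RDD for storage at the default
        // level (MEMORY_ONLY) but materializes it lazily, on the first action.
        JavaRDD<String> lines = jsc.textFile("data/input.txt").cache();

        // The first action computes and caches the partitions; the second reuses them
        // instead of re-reading the file.
        long total = lines.count();
        long nonEmpty = lines.filter(line -> !line.trim().isEmpty()).count();
        System.out.println(total + " lines, " + nonEmpty + " non-empty");

        lines.unpersist(); // release the cached blocks once they are no longer needed
        jsc.stop();
    }
}

Note that cache() is shorthand for persist(StorageLevel.MEMORY_ONLY()) and returns the RDD itself, which is why several of the examples below chain it, as in rdd.cache() or .filter(...).cache().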
Example 1
Source File: SparkCacheOperator.java From rheem with Apache License 2.0 | 6 votes |
@Override
public Tuple<Collection<ExecutionLineageNode>, Collection<ChannelInstance>> evaluate(
        ChannelInstance[] inputs,
        ChannelInstance[] outputs,
        SparkExecutor sparkExecutor,
        OptimizationContext.OperatorContext operatorContext) {
    RddChannel.Instance input = (RddChannel.Instance) inputs[0];
    final JavaRDD<Object> rdd = input.provideRdd();
    final JavaRDD<Object> cachedRdd = rdd.cache();
    cachedRdd.foreachPartition(iterator -> { });

    RddChannel.Instance output = (RddChannel.Instance) outputs[0];
    output.accept(cachedRdd, sparkExecutor);

    return ExecutionOperator.modelQuasiEagerExecution(inputs, outputs, operatorContext);
}
Example 2
Source File: AnalyzeSpark.java From deeplearning4j with Apache License 2.0 | 6 votes |
public static DataAnalysis analyze(Schema schema, JavaRDD<List<Writable>> data, int maxHistogramBuckets) {
    data.cache();
    /*
     * TODO: Some care should be given to add histogramBuckets and histogramBucketCounts to this in the future
     */

    List<ColumnType> columnTypes = schema.getColumnTypes();

    List<AnalysisCounter> counters =
            data.aggregate(null, new AnalysisAddFunction(schema), new AnalysisCombineFunction());

    double[][] minsMaxes = new double[counters.size()][2];
    List<ColumnAnalysis> list = DataVecAnalysisUtils.convertCounters(counters, minsMaxes, columnTypes);

    List<HistogramCounter> histogramCounters =
            data.aggregate(null, new HistogramAddFunction(maxHistogramBuckets, schema, minsMaxes),
                    new HistogramCombineFunction());

    DataVecAnalysisUtils.mergeCounters(list, histogramCounters);
    return new DataAnalysis(schema, list);
}
Example 3
Source File: AssemblyContigAlignmentsConfigPicker.java From gatk with BSD 3-Clause "New" or "Revised" License | 6 votes |
/**
 * Parses input alignments into custom {@link AlignmentInterval} format, and
 * performs a primitive filtering implemented in
 * {@link #notDiscardForBadMQ(AlignedContig)} that
 * gets rid of contigs with no good alignments.
 *
 * It's important to remember that this step doesn't select alignments,
 * but only parses alignments and either keeps the whole contig or drops it completely.
 */
private static JavaRDD<AlignedContig> convertRawAlignmentsToAlignedContigAndFilterByQuality(final JavaRDD<GATKRead> assemblyAlignments,
                                                                                            final SAMFileHeader header,
                                                                                            final Logger toolLogger) {
    assemblyAlignments.cache();
    toolLogger.info( "Processing " + assemblyAlignments.count() + " raw alignments from " +
                     assemblyAlignments.map(GATKRead::getName).distinct().count() + " contigs.");

    final JavaRDD<AlignedContig> parsedContigAlignments =
            new SvDiscoverFromLocalAssemblyContigAlignmentsSpark.SAMFormattedContigAlignmentParser(assemblyAlignments, header, false)
                    .getAlignedContigs()
                    .filter(AssemblyContigAlignmentsConfigPicker::notDiscardForBadMQ).cache();
    assemblyAlignments.unpersist();
    toolLogger.info( "Filtering on MQ left " + parsedContigAlignments.count() + " contigs.");
    return parsedContigAlignments;
}
Example 4
Source File: MLLibUtil.java From deeplearning4j with Apache License 2.0 | 5 votes |
/**
 * Convert an rdd of data set in to labeled point.
 * @param data the dataset to convert
 * @param preCache boolean pre-cache rdd before operation
 * @return an rdd of labeled point
 */
public static JavaRDD<LabeledPoint> fromDataSet(JavaRDD<DataSet> data, boolean preCache) {
    if (preCache && !data.getStorageLevel().useMemory()) {
        data.cache();
    }
    return data.map(new Function<DataSet, LabeledPoint>() {
        @Override
        public LabeledPoint call(DataSet dataSet) {
            return toLabeledPoint(dataSet);
        }
    });
}
Example 5
Source File: JavaGaussianMixtureExample.java From SparkDemo with MIT License | 5 votes |
public static void main(String[] args) {
    SparkConf conf = new SparkConf().setAppName("JavaGaussianMixtureExample");
    JavaSparkContext jsc = new JavaSparkContext(conf);

    // $example on$
    // Load and parse data
    String path = "data/mllib/gmm_data.txt";
    JavaRDD<String> data = jsc.textFile(path);
    JavaRDD<Vector> parsedData = data.map(
        new Function<String, Vector>() {
            public Vector call(String s) {
                String[] sarray = s.trim().split(" ");
                double[] values = new double[sarray.length];
                for (int i = 0; i < sarray.length; i++) {
                    values[i] = Double.parseDouble(sarray[i]);
                }
                return Vectors.dense(values);
            }
        }
    );
    parsedData.cache();

    // Cluster the data into two classes using GaussianMixture
    GaussianMixtureModel gmm = new GaussianMixture().setK(2).run(parsedData.rdd());

    // Save and load GaussianMixtureModel (load from the same path the model was saved to)
    gmm.save(jsc.sc(), "target/org/apache/spark/JavaGaussianMixtureExample/GaussianMixtureModel");
    GaussianMixtureModel sameModel = GaussianMixtureModel.load(jsc.sc(),
        "target/org/apache/spark/JavaGaussianMixtureExample/GaussianMixtureModel");

    // Output the parameters of the mixture model
    for (int j = 0; j < gmm.k(); j++) {
        System.out.printf("weight=%f\nmu=%s\nsigma=\n%s\n",
            gmm.weights()[j], gmm.gaussians()[j].mu(), gmm.gaussians()[j].sigma());
    }
    // $example off$

    jsc.stop();
}
Example 6
Source File: MLLibUtil.java From deeplearning4j with Apache License 2.0 | 5 votes |
/**
 * Converts JavaRDD labeled points to JavaRDD DataSets.
 * @param data JavaRDD LabeledPoints
 * @param numPossibleLabels number of possible labels
 * @param preCache boolean pre-cache rdd before operation
 * @return
 */
public static JavaRDD<DataSet> fromLabeledPoint(JavaRDD<LabeledPoint> data, final long numPossibleLabels,
                boolean preCache) {
    if (preCache && !data.getStorageLevel().useMemory()) {
        data.cache();
    }
    return data.map(new Function<LabeledPoint, DataSet>() {
        @Override
        public DataSet call(LabeledPoint lp) {
            return fromLabeledPoint(lp, numPossibleLabels);
        }
    });
}
Example 7
Source File: MLLibUtil.java From deeplearning4j with Apache License 2.0 | 5 votes |
/**
 * Converts a continuous JavaRDD LabeledPoint to a JavaRDD DataSet.
 * @param data JavaRdd LabeledPoint
 * @param preCache boolean pre-cache rdd before operation
 * @return
 */
public static JavaRDD<DataSet> fromContinuousLabeledPoint(JavaRDD<LabeledPoint> data, boolean preCache) {
    if (preCache && !data.getStorageLevel().useMemory()) {
        data.cache();
    }
    return data.map(new Function<LabeledPoint, DataSet>() {
        @Override
        public DataSet call(LabeledPoint lp) {
            return convertToDataset(lp);
        }
    });
}
Example 8
Source File: CollectMultipleMetricsSpark.java From gatk with BSD 3-Clause "New" or "Revised" License | 5 votes |
@Override
protected void runTool( final JavaSparkContext ctx ) {
    final JavaRDD<GATKRead> unFilteredReads = getUnfilteredReads();
    List<SparkCollectorProvider> collectorsToRun = getCollectorsToRun();

    if (collectorsToRun.size() > 1) {
        // if there is more than one collector to run, cache the
        // unfiltered RDD so we don't recompute it
        unFilteredReads.cache();
    }

    for (final SparkCollectorProvider provider : collectorsToRun) {
        MetricsCollectorSpark<? extends MetricsArgumentCollection> metricsCollector =
                provider.createCollector(
                    outputBaseName,
                    metricAccumulationLevel.accumulationLevels,
                    getDefaultHeaders(),
                    getHeaderForReads()
                );
        validateCollector(metricsCollector, collectorsToRun.get(collectorsToRun.indexOf(provider)).getClass().getName());

        // Execute the collector's lifecycle
        //Bypass the framework merging of command line filters and just apply the default
        //ones specified by the collector
        ReadFilter readFilter = ReadFilter.fromList(metricsCollector.getDefaultReadFilters(), getHeaderForReads());

        metricsCollector.collectMetrics(
            unFilteredReads.filter(r -> readFilter.test(r)),
            getHeaderForReads()
        );
        metricsCollector.saveMetrics(getReadSourceName());
    }
}
Example 9
Source File: SparkSharder.java From gatk with BSD 3-Clause "New" or "Revised" License | 4 votes |
private static <L extends Locatable, I extends Locatable, T> JavaRDD<T> joinOverlapping(JavaSparkContext ctx,
        JavaRDD<L> locatables, Class<L> locatableClass, SAMSequenceDictionary sequenceDictionary, JavaRDD<I> intervals,
        int maxLocatableLength, FlatMapFunction2<Iterator<L>, Iterator<I>, T> f) {

    List<PartitionLocatable<SimpleInterval>> partitionReadExtents =
            computePartitionReadExtents(locatables, sequenceDictionary, maxLocatableLength);
    List<SimpleInterval> firstLocatablesList =
            partitionReadExtents.stream().map(PartitionLocatable::getLocatable).collect(Collectors.toList());
    Broadcast<List<SimpleInterval>> firstLocatablesBroadcast = ctx.broadcast(firstLocatablesList);

    // For each interval find which partition it starts and ends in.
    // An interval is processed in the partition it starts in. However, we need to make sure that
    // subsequent partitions are coalesced if needed, so for each partition p find the latest subsequent
    // partition that is needed to read all of the intervals that start in p.
    OverlapDetector<PartitionLocatable<SimpleInterval>> overlapDetector =
            OverlapDetector.create(partitionReadExtents);
    Broadcast<OverlapDetector<PartitionLocatable<SimpleInterval>>> overlapDetectorBroadcast =
            ctx.broadcast(overlapDetector);
    JavaRDD<PartitionLocatable<I>> indexedIntervals = intervals.map(interval -> {
        int[] partitionIndexes = overlapDetectorBroadcast.getValue().getOverlaps(interval).stream()
                .mapToInt(PartitionLocatable::getPartitionIndex).toArray();
        if (partitionIndexes.length == 0) {
            final List<SimpleInterval> firstLocatables = firstLocatablesBroadcast.getValue();
            // interval does not overlap any partition - add it to the one after the interval start
            int i = Collections.binarySearch(firstLocatables, new SimpleInterval(interval),
                    (o1, o2) -> IntervalUtils.compareLocatables(o1, o2, sequenceDictionary));
            if (i >= 0) {
                throw new IllegalStateException(); // TODO: no overlaps, yet start of interval matches a partition read extent start
            }
            int insertionPoint = -i - 1;
            if (insertionPoint == firstLocatables.size()) {
                insertionPoint = firstLocatables.size() - 1;
            }
            return new PartitionLocatable<>(insertionPoint, interval);
        }
        Arrays.sort(partitionIndexes);
        int startIndex = partitionIndexes[0];
        int endIndex = partitionIndexes[partitionIndexes.length - 1];
        return new PartitionLocatable<>(startIndex, endIndex, interval);
    });

    // Create an RDD of intervals with the same number of partitions as the locatables, and where each interval
    // is in its start partition. Within each partition, intervals are sorted by IntervalUtils#compareLocatables.
    JavaRDD<PartitionLocatable<I>> indexedIntervalsRepartitioned = indexedIntervals
            .mapToPair(interval -> new Tuple2<>(interval, (Void) null))
            .repartitionAndSortWithinPartitions(new PartitionLocatablePartitioner(locatables.getNumPartitions()),
                    new PartitionLocatableComparator<I>(sequenceDictionary))
            .keys();
    indexedIntervalsRepartitioned.cache(); // cache since we need to do two calculations on the intervals

    // Find the end partition index for each partition.
    Map<Integer, Integer> maxEndPartitionIndexesMap = indexedIntervalsRepartitioned
            .mapToPair((PairFunction<PartitionLocatable<I>, Integer, Integer>) partitionLocatable ->
                    new Tuple2<>(partitionLocatable.getPartitionIndex(), partitionLocatable.getEndPartitionIndex()))
            .reduceByKey((Function2<Integer, Integer, Integer>) Math::max)
            .collectAsMap();
    List<Integer> maxEndPartitionIndexes =
            IntStream.range(0, locatables.getNumPartitions()).boxed().collect(Collectors.toList());
    maxEndPartitionIndexesMap.forEach((startIndex, endIndex) -> {
        if (endIndex > maxEndPartitionIndexes.get(startIndex)) {
            maxEndPartitionIndexes.set(startIndex, endIndex);
        }
    });
    JavaRDD<L> coalescedRdd = coalesce(locatables, locatableClass, new RangePartitionCoalescer(maxEndPartitionIndexes));

    // zipPartitions on coalesced locatable partitions and intervals, and apply the function f
    return coalescedRdd.zipPartitions(indexedIntervalsRepartitioned.map(PartitionLocatable::getLocatable), f);
}
Example 10
Source File: DataSparkFromRDD.java From toolbox with Apache License 2.0 | 4 votes |
public DataSparkFromRDD(JavaRDD<DataInstance> input, Attributes atts) {
    // FIXME: is this a good idea?
    amidstRDD = input.cache();
    attributes = atts;
}
Example 11
Source File: ALSUpdate.java From oryx with Apache License 2.0 | 4 votes |
@Override
public PMML buildModel(JavaSparkContext sparkContext,
                       JavaRDD<String> trainData,
                       List<?> hyperParameters,
                       Path candidatePath) {
    int features = (Integer) hyperParameters.get(0);
    double lambda = (Double) hyperParameters.get(1);
    double alpha = (Double) hyperParameters.get(2);
    double epsilon = Double.NaN;
    if (logStrength) {
        epsilon = (Double) hyperParameters.get(3);
    }
    Preconditions.checkArgument(features > 0);
    Preconditions.checkArgument(lambda >= 0.0);
    Preconditions.checkArgument(alpha > 0.0);
    if (logStrength) {
        Preconditions.checkArgument(epsilon > 0.0);
    }

    JavaRDD<String[]> parsedRDD = trainData.map(MLFunctions.PARSE_FN);
    parsedRDD.cache();

    Map<String,Integer> userIDIndexMap = buildIDIndexMapping(parsedRDD, true);
    Map<String,Integer> itemIDIndexMap = buildIDIndexMapping(parsedRDD, false);

    log.info("Broadcasting ID-index mappings for {} users, {} items",
             userIDIndexMap.size(), itemIDIndexMap.size());

    Broadcast<Map<String,Integer>> bUserIDToIndex = sparkContext.broadcast(userIDIndexMap);
    Broadcast<Map<String,Integer>> bItemIDToIndex = sparkContext.broadcast(itemIDIndexMap);

    JavaRDD<Rating> trainRatingData = parsedToRatingRDD(parsedRDD, bUserIDToIndex, bItemIDToIndex);
    trainRatingData = aggregateScores(trainRatingData, epsilon);
    ALS als = new ALS()
        .setRank(features)
        .setIterations(iterations)
        .setLambda(lambda)
        .setCheckpointInterval(5);
    if (implicit) {
        als = als.setImplicitPrefs(true).setAlpha(alpha);
    }

    RDD<Rating> trainingRatingDataRDD = trainRatingData.rdd();
    trainingRatingDataRDD.cache();
    MatrixFactorizationModel model = als.run(trainingRatingDataRDD);
    trainingRatingDataRDD.unpersist(false);

    bUserIDToIndex.unpersist();
    bItemIDToIndex.unpersist();
    parsedRDD.unpersist();

    Broadcast<Map<Integer,String>> bUserIndexToID = sparkContext.broadcast(invertMap(userIDIndexMap));
    Broadcast<Map<Integer,String>> bItemIndexToID = sparkContext.broadcast(invertMap(itemIDIndexMap));
    PMML pmml = mfModelToPMML(model, features, lambda, alpha, epsilon, implicit, logStrength,
                              candidatePath, bUserIndexToID, bItemIndexToID);
    unpersist(model);

    bUserIndexToID.unpersist();
    bItemIndexToID.unpersist();

    return pmml;
}
Example 12
Source File: ALSUpdate.java From oryx with Apache License 2.0 | 4 votes |
@Override
public double evaluate(JavaSparkContext sparkContext,
                       PMML model,
                       Path modelParentPath,
                       JavaRDD<String> testData,
                       JavaRDD<String> trainData) {
    JavaRDD<String[]> parsedTestRDD = testData.map(MLFunctions.PARSE_FN);
    parsedTestRDD.cache();

    Map<String,Integer> userIDToIndex = buildIDIndexOneWayMap(model, parsedTestRDD, true);
    Map<String,Integer> itemIDToIndex = buildIDIndexOneWayMap(model, parsedTestRDD, false);

    log.info("Broadcasting ID-index mappings for {} users, {} items",
             userIDToIndex.size(), itemIDToIndex.size());

    Broadcast<Map<String,Integer>> bUserIDToIndex = sparkContext.broadcast(userIDToIndex);
    Broadcast<Map<String,Integer>> bItemIDToIndex = sparkContext.broadcast(itemIDToIndex);

    JavaRDD<Rating> testRatingData = parsedToRatingRDD(parsedTestRDD, bUserIDToIndex, bItemIDToIndex);
    double epsilon = Double.NaN;
    if (logStrength) {
        epsilon = Double.parseDouble(AppPMMLUtils.getExtensionValue(model, "epsilon"));
    }
    testRatingData = aggregateScores(testRatingData, epsilon);

    MatrixFactorizationModel mfModel =
        pmmlToMFModel(sparkContext, model, modelParentPath, bUserIDToIndex, bItemIDToIndex);

    parsedTestRDD.unpersist();

    double eval;
    if (implicit) {
        double auc = Evaluation.areaUnderCurve(sparkContext, mfModel, testRatingData);
        log.info("AUC: {}", auc);
        eval = auc;
    } else {
        double rmse = Evaluation.rmse(mfModel, testRatingData);
        log.info("RMSE: {}", rmse);
        eval = -rmse;
    }
    unpersist(mfModel);

    bUserIDToIndex.unpersist();
    bItemIDToIndex.unpersist();

    return eval;
}
Example 13
Source File: MLUpdate.java From oryx with Apache License 2.0 | 4 votes |
@Override
public void runUpdate(JavaSparkContext sparkContext,
                      long timestamp,
                      JavaPairRDD<Object,M> newKeyMessageData,
                      JavaPairRDD<Object,M> pastKeyMessageData,
                      String modelDirString,
                      TopicProducer<String,String> modelUpdateTopic)
    throws IOException, InterruptedException {

    Objects.requireNonNull(newKeyMessageData);

    JavaRDD<M> newData = newKeyMessageData.values();
    JavaRDD<M> pastData = pastKeyMessageData == null ? null : pastKeyMessageData.values();

    if (newData != null) {
        newData.cache();
        // This forces caching of the RDD. This shouldn't be necessary but we see some freezes
        // when many workers try to materialize the RDDs at once. Hence the workaround.
        newData.foreachPartition(p -> {});
    }
    if (pastData != null) {
        pastData.cache();
        pastData.foreachPartition(p -> {});
    }

    List<List<?>> hyperParameterCombos = HyperParams.chooseHyperParameterCombos(
        getHyperParameterValues(), hyperParamSearch, candidates);

    Path modelDir = new Path(modelDirString);
    Path tempModelPath = new Path(modelDir, ".temporary");
    Path candidatesPath = new Path(tempModelPath, Long.toString(System.currentTimeMillis()));

    FileSystem fs = FileSystem.get(modelDir.toUri(), sparkContext.hadoopConfiguration());
    fs.mkdirs(candidatesPath);

    Path bestCandidatePath = findBestCandidatePath(
        sparkContext, newData, pastData, hyperParameterCombos, candidatesPath);

    Path finalPath = new Path(modelDir, Long.toString(System.currentTimeMillis()));
    if (bestCandidatePath == null) {
        log.info("Unable to build any model");
    } else {
        // Move best model into place
        fs.rename(bestCandidatePath, finalPath);
    }
    // Then delete everything else
    fs.delete(candidatesPath, true);

    if (modelUpdateTopic == null) {
        log.info("No update topic configured, not publishing models to a topic");
    } else {
        // Push PMML model onto update topic, if it exists
        Path bestModelPath = new Path(finalPath, MODEL_FILE_NAME);
        if (fs.exists(bestModelPath)) {
            FileStatus bestModelPathFS = fs.getFileStatus(bestModelPath);
            PMML bestModel = null;
            boolean modelNeededForUpdates = canPublishAdditionalModelData();
            boolean modelNotTooLarge = bestModelPathFS.getLen() <= maxMessageSize;
            if (modelNeededForUpdates || modelNotTooLarge) {
                // Either the model is required for publishAdditionalModelData, or required because it's going to
                // be serialized to Kafka
                try (InputStream in = fs.open(bestModelPath)) {
                    bestModel = PMMLUtils.read(in);
                }
            }

            if (modelNotTooLarge) {
                modelUpdateTopic.send("MODEL", PMMLUtils.toString(bestModel));
            } else {
                modelUpdateTopic.send("MODEL-REF", fs.makeQualified(bestModelPath).toString());
            }

            if (modelNeededForUpdates) {
                publishAdditionalModelData(
                    sparkContext, bestModel, newData, pastData, finalPath, modelUpdateTopic);
            }
        }
    }

    if (newData != null) {
        newData.unpersist();
    }
    if (pastData != null) {
        pastData.unpersist();
    }
}
Example 14
Source File: AnalyzeSpark.java From DataVec with Apache License 2.0 | 4 votes |
/**
 *
 * @param schema
 * @param data
 * @return
 */
public static DataQualityAnalysis analyzeQuality(final Schema schema, final JavaRDD<List<Writable>> data) {
    data.cache();
    int nColumns = schema.numColumns();

    List<ColumnType> columnTypes = schema.getColumnTypes();
    List<QualityAnalysisState> states = data.aggregate(null, new QualityAnalysisAddFunction(schema),
                    new QualityAnalysisCombineFunction());

    List<ColumnQuality> list = new ArrayList<>(nColumns);
    for (QualityAnalysisState qualityState : states) {
        list.add(qualityState.getColumnQuality());
    }
    return new DataQualityAnalysis(schema, list);
}
Example 15
Source File: JavaLinearRegressionWithSGDExample.java From SparkDemo with MIT License | 4 votes |
public static void main(String[] args) {
    SparkConf conf = new SparkConf().setAppName("JavaLinearRegressionWithSGDExample");
    JavaSparkContext sc = new JavaSparkContext(conf);

    // $example on$
    // Load and parse the data
    String path = "data/mllib/ridge-data/lpsa.data";
    JavaRDD<String> data = sc.textFile(path);
    JavaRDD<LabeledPoint> parsedData = data.map(
        new Function<String, LabeledPoint>() {
            public LabeledPoint call(String line) {
                String[] parts = line.split(",");
                String[] features = parts[1].split(" ");
                double[] v = new double[features.length];
                // parse every feature value (looping to features.length - 1 would silently drop the last one)
                for (int i = 0; i < features.length; i++) {
                    v[i] = Double.parseDouble(features[i]);
                }
                return new LabeledPoint(Double.parseDouble(parts[0]), Vectors.dense(v));
            }
        }
    );
    parsedData.cache();

    // Building the model
    int numIterations = 100;
    double stepSize = 0.00000001;
    final LinearRegressionModel model =
        LinearRegressionWithSGD.train(JavaRDD.toRDD(parsedData), numIterations, stepSize);

    // Evaluate model on training examples and compute training error
    JavaRDD<Tuple2<Double, Double>> valuesAndPreds = parsedData.map(
        new Function<LabeledPoint, Tuple2<Double, Double>>() {
            public Tuple2<Double, Double> call(LabeledPoint point) {
                double prediction = model.predict(point.features());
                return new Tuple2<>(prediction, point.label());
            }
        }
    );
    double MSE = new JavaDoubleRDD(valuesAndPreds.map(
        new Function<Tuple2<Double, Double>, Object>() {
            public Object call(Tuple2<Double, Double> pair) {
                return Math.pow(pair._1() - pair._2(), 2.0);
            }
        }
    ).rdd()).mean();
    System.out.println("training Mean Squared Error = " + MSE);

    // Save and load model
    model.save(sc.sc(), "target/tmp/javaLinearRegressionWithSGDModel");
    LinearRegressionModel sameModel = LinearRegressionModel.load(sc.sc(),
        "target/tmp/javaLinearRegressionWithSGDModel");
    // $example off$

    sc.stop();
}
Example 16
Source File: JavaKMeansExample.java From SparkDemo with MIT License | 4 votes |
public static void main(String[] args) {
    SparkConf conf = new SparkConf().setAppName("JavaKMeansExample");
    JavaSparkContext jsc = new JavaSparkContext(conf);

    // $example on$
    // Load and parse data
    String path = "data/mllib/kmeans_data.txt";
    JavaRDD<String> data = jsc.textFile(path);
    JavaRDD<Vector> parsedData = data.map(
        new Function<String, Vector>() {
            public Vector call(String s) {
                String[] sarray = s.split(" ");
                double[] values = new double[sarray.length];
                for (int i = 0; i < sarray.length; i++) {
                    values[i] = Double.parseDouble(sarray[i]);
                }
                return Vectors.dense(values);
            }
        }
    );
    parsedData.cache();

    // Cluster the data into two classes using KMeans
    int numClusters = 2;
    int numIterations = 20;
    KMeansModel clusters = KMeans.train(parsedData.rdd(), numClusters, numIterations);

    System.out.println("Cluster centers:");
    for (Vector center: clusters.clusterCenters()) {
        System.out.println(" " + center);
    }
    double cost = clusters.computeCost(parsedData.rdd());
    System.out.println("Cost: " + cost);

    // Evaluate clustering by computing Within Set Sum of Squared Errors
    double WSSSE = clusters.computeCost(parsedData.rdd());
    System.out.println("Within Set Sum of Squared Errors = " + WSSSE);

    // Save and load model
    clusters.save(jsc.sc(), "target/org/apache/spark/JavaKMeansExample/KMeansModel");
    KMeansModel sameModel = KMeansModel.load(jsc.sc(),
        "target/org/apache/spark/JavaKMeansExample/KMeansModel");
    // $example off$

    jsc.stop();
}
Example 17
Source File: JavaSVMWithSGDExample.java From SparkDemo with MIT License | 4 votes |
public static void main(String[] args) {
    SparkConf conf = new SparkConf().setAppName("JavaSVMWithSGDExample");
    SparkContext sc = new SparkContext(conf);
    // $example on$
    String path = "data/mllib/sample_libsvm_data.txt";
    JavaRDD<LabeledPoint> data = MLUtils.loadLibSVMFile(sc, path).toJavaRDD();

    // Split initial RDD into two... [60% training data, 40% testing data].
    JavaRDD<LabeledPoint> training = data.sample(false, 0.6, 11L);
    training.cache();
    JavaRDD<LabeledPoint> test = data.subtract(training);

    // Run training algorithm to build the model.
    int numIterations = 100;
    final SVMModel model = SVMWithSGD.train(training.rdd(), numIterations);

    // Clear the default threshold.
    model.clearThreshold();

    // Compute raw scores on the test set.
    JavaRDD<Tuple2<Object, Object>> scoreAndLabels = test.map(
        new Function<LabeledPoint, Tuple2<Object, Object>>() {
            public Tuple2<Object, Object> call(LabeledPoint p) {
                Double score = model.predict(p.features());
                return new Tuple2<Object, Object>(score, p.label());
            }
        }
    );

    // Get evaluation metrics.
    BinaryClassificationMetrics metrics =
        new BinaryClassificationMetrics(JavaRDD.toRDD(scoreAndLabels));
    double auROC = metrics.areaUnderROC();

    System.out.println("Area under ROC = " + auROC);

    // Save and load model
    model.save(sc, "target/tmp/javaSVMWithSGDModel");
    SVMModel sameModel = SVMModel.load(sc, "target/tmp/javaSVMWithSGDModel");
    // $example off$

    sc.stop();
}