Java Code Examples for org.apache.flink.api.java.operators.IterativeDataSet#closeWith()
The following examples show how to use org.apache.flink.api.java.operators.IterativeDataSet#closeWith().
You can vote up the ones you like or vote down the ones you don't like,
and go to the original project or source file by following the links above each example.
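Before the project examples, here is a minimal, self-contained sketch of the basic closeWith() pattern. It is only an illustration: the class name CloseWithSketch, the increment step, and the element values are placeholders, not taken from any of the projects below.

import org.apache.flink.api.common.functions.MapFunction;
import org.apache.flink.api.java.DataSet;
import org.apache.flink.api.java.ExecutionEnvironment;
import org.apache.flink.api.java.operators.IterativeDataSet;

public class CloseWithSketch {

    public static void main(String[] args) throws Exception {
        ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();

        // start a bulk iteration that runs at most 10 supersteps
        IterativeDataSet<Long> iteration = env.fromElements(0L).iterate(10);

        // the step function: any transformation applied once per superstep
        DataSet<Long> nextPartialSolution = iteration.map(new MapFunction<Long, Long>() {
            @Override
            public Long map(Long value) {
                return value + 1;
            }
        });

        // closeWith() feeds the step result back as the next partial solution
        // and returns the DataSet representing the iteration's final result
        DataSet<Long> result = iteration.closeWith(nextPartialSolution);

        // print() triggers execution; with 10 supersteps this should print 10
        result.print();
    }
}

The two-argument variant closeWith(nextPartialSolution, terminationCriterion), used in the PageRank examples below, additionally stops the iteration as soon as the termination-criterion DataSet becomes empty.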
Example 1
Source File: BranchingPlansCompilerTest.java From flink with Apache License 2.0 | 6 votes |
@Test
public void testBranchAfterIteration() {
    ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();
    env.setParallelism(DEFAULT_PARALLELISM);
    DataSet<Long> sourceA = env.generateSequence(0, 1);

    IterativeDataSet<Long> loopHead = sourceA.iterate(10);
    DataSet<Long> loopTail = loopHead.map(new IdentityMapper<Long>()).name("Mapper");
    DataSet<Long> loopRes = loopHead.closeWith(loopTail);

    loopRes.output(new DiscardingOutputFormat<Long>());
    loopRes.map(new IdentityMapper<Long>())
            .output(new DiscardingOutputFormat<Long>());

    Plan plan = env.createProgramPlan();

    try {
        compileNoStats(plan);
    }
    catch (Exception e) {
        e.printStackTrace();
        Assert.fail(e.getMessage());
    }
}
Example 2
Source File: CollectionExecutionIterationTest.java From flink with Apache License 2.0 | 6 votes |
@Test
public void testBulkIteration() {
    try {
        ExecutionEnvironment env = ExecutionEnvironment.createCollectionsEnvironment();

        IterativeDataSet<Integer> iteration = env.fromElements(1).iterate(10);

        DataSet<Integer> result = iteration.closeWith(iteration.map(new AddSuperstepNumberMapper()));

        List<Integer> collected = new ArrayList<Integer>();
        result.output(new LocalCollectionOutputFormat<Integer>(collected));

        env.execute();

        assertEquals(1, collected.size());
        assertEquals(56, collected.get(0).intValue());
    }
    catch (Exception e) {
        e.printStackTrace();
        fail(e.getMessage());
    }
}
Example 3
Source File: IterationsCompilerTest.java From flink with Apache License 2.0 | 5 votes |
public static DataSet<Tuple2<Long, Long>> doSimpleBulkIteration(DataSet<Tuple2<Long, Long>> vertices, DataSet<Tuple2<Long, Long>> edges) {

    // open a bulk iteration
    IterativeDataSet<Tuple2<Long, Long>> iteration = vertices.iterate(20);

    DataSet<Tuple2<Long, Long>> changes = iteration
            .join(edges).where(0).equalTo(0)
            .flatMap(new FlatMapJoin());

    // close the bulk iteration
    return iteration.closeWith(changes);
}
Example 4
Source File: BranchingPlansCompilerTest.java From flink with Apache License 2.0 | 5 votes |
/**
 * Test to ensure that sourceA is inside as well as outside of the iteration the same
 * node.
 *
 * <pre>
 *       (SRC A)               (SRC B)
 *      /       \             /       \
 *  (SINK 1)   (ITERATION)    |     (SINK 2)
 *             /        \     /
 *         (SINK 3)     (CROSS => NEXT PARTIAL SOLUTION)
 * </pre>
 */
@Test
public void testClosure() {
    ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();
    env.setParallelism(DEFAULT_PARALLELISM);
    DataSet<Long> sourceA = env.generateSequence(0, 1);
    DataSet<Long> sourceB = env.generateSequence(0, 1);

    sourceA.output(new DiscardingOutputFormat<Long>());
    sourceB.output(new DiscardingOutputFormat<Long>());

    IterativeDataSet<Long> loopHead = sourceA.iterate(10).name("Loop");

    DataSet<Long> loopTail = loopHead.cross(sourceB).with(new IdentityCrosser<Long>());
    DataSet<Long> loopRes = loopHead.closeWith(loopTail);

    loopRes.output(new DiscardingOutputFormat<Long>());

    Plan plan = env.createProgramPlan();

    try {
        compileNoStats(plan);
    }
    catch (Exception e) {
        e.printStackTrace();
        Assert.fail(e.getMessage());
    }
}
Example 5
Source File: SuccessAfterNetworkBuffersFailureITCase.java From flink with Apache License 2.0 | 5 votes |
private static void runKMeans(ExecutionEnvironment env) throws Exception {

    env.setParallelism(PARALLELISM);
    env.getConfig().disableSysoutLogging();

    // get input data
    DataSet<KMeans.Point> points = KMeansData.getDefaultPointDataSet(env).rebalance();
    DataSet<KMeans.Centroid> centroids = KMeansData.getDefaultCentroidDataSet(env).rebalance();

    // set number of bulk iterations for KMeans algorithm
    IterativeDataSet<KMeans.Centroid> loop = centroids.iterate(20);

    // add some re-partitions to increase network buffer use
    DataSet<KMeans.Centroid> newCentroids = points
            // compute closest centroid for each point
            .map(new KMeans.SelectNearestCenter()).withBroadcastSet(loop, "centroids")
            .rebalance()
            // count and sum point coordinates for each centroid
            .map(new KMeans.CountAppender())
            .groupBy(0).reduce(new KMeans.CentroidAccumulator())
            // compute new centroids from point counts and coordinate sums
            .rebalance()
            .map(new KMeans.CentroidAverager());

    // feed new centroids back into next iteration
    DataSet<KMeans.Centroid> finalCentroids = loop.closeWith(newCentroids);

    DataSet<Tuple2<Integer, KMeans.Point>> clusteredPoints = points
            // assign points to final clusters
            .map(new KMeans.SelectNearestCenter()).withBroadcastSet(finalCentroids, "centroids");

    clusteredPoints.output(new DiscardingOutputFormat<Tuple2<Integer, KMeans.Point>>());

    env.execute("KMeans Example");
}
Example 6
Source File: IterationIncompleteStaticPathConsumptionITCase.java From flink with Apache License 2.0 | 5 votes |
@Override
protected void testProgram() throws Exception {
    ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();

    // the test data is constructed such that the merge join zig zag
    // has an early out, leaving elements on the static path input unconsumed
    DataSet<Path> edges = env.fromElements(
            new Path(2, 1),
            new Path(4, 1),
            new Path(6, 3),
            new Path(8, 3),
            new Path(10, 1),
            new Path(12, 1),
            new Path(14, 3),
            new Path(16, 3),
            new Path(18, 1),
            new Path(20, 1));

    IterativeDataSet<Path> currentPaths = edges.iterate(10);

    DataSet<Path> newPaths = currentPaths
            .join(edges, JoinHint.REPARTITION_SORT_MERGE).where("to").equalTo("from")
                .with(new PathConnector())
            .union(currentPaths).distinct("from", "to");

    DataSet<Path> result = currentPaths.closeWith(newPaths);

    result.output(new DiscardingOutputFormat<Path>());
    env.execute();
}
Example 7
Source File: NestedIterationsTest.java From Flink-CEPplus with Apache License 2.0 | 5 votes |
@Test
public void testBulkIterationInClosure() {
    try {
        ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();

        DataSet<Long> data1 = env.generateSequence(1, 100);
        DataSet<Long> data2 = env.generateSequence(1, 100);

        IterativeDataSet<Long> firstIteration = data1.iterate(100);

        DataSet<Long> firstResult = firstIteration.closeWith(firstIteration.map(new IdentityMapper<Long>()));

        IterativeDataSet<Long> mainIteration = data2.map(new IdentityMapper<Long>()).iterate(100);

        DataSet<Long> joined = mainIteration.join(firstResult)
                .where(new IdentityKeyExtractor<Long>()).equalTo(new IdentityKeyExtractor<Long>())
                .with(new DummyFlatJoinFunction<Long>());

        DataSet<Long> mainResult = mainIteration.closeWith(joined);

        mainResult.output(new DiscardingOutputFormat<Long>());

        Plan p = env.createProgramPlan();

        // optimizer should be able to translate this
        OptimizedPlan op = compileNoStats(p);

        // job graph generator should be able to translate this
        new JobGraphGenerator().compileJobGraph(op);
    }
    catch (Exception e) {
        e.printStackTrace();
        fail(e.getMessage());
    }
}
Example 8
Source File: DistributedVI.java From toolbox with Apache License 2.0 | 4 votes |
@Override
public double updateModel(DataFlink<DataInstance> dataUpdate) {

    try {
        final ExecutionEnvironment env = dataUpdate.getDataSet().getExecutionEnvironment();

        // get input data
        CompoundVector parameterPrior = this.svb.getNaturalParameterPrior();

        DataSet<CompoundVector> paramSet = env.fromElements(parameterPrior);

        ConvergenceCriterion convergenceELBO;
        if (timeLimit == -1) {
            convergenceELBO = new ConvergenceELBO(this.globalThreshold, System.nanoTime());
        } else {
            convergenceELBO = new ConvergenceELBObyTime(this.timeLimit, System.nanoTime());
            this.setMaximumGlobalIterations(5000);
        }

        // set number of bulk iterations for KMeans algorithm
        IterativeDataSet<CompoundVector> loop = paramSet.iterate(maximumGlobalIterations)
                .registerAggregationConvergenceCriterion("ELBO_" + this.dag.getName(),
                        new DoubleSumAggregator(), convergenceELBO);

        Configuration config = new Configuration();
        config.setString(ParameterLearningAlgorithm.BN_NAME, this.dag.getName());
        config.setBytes(SVB, Serialization.serializeObject(svb));

        // We add an empty batched data set to emit the updated prior.
        DataOnMemory<DataInstance> emtpyBatch = new DataOnMemoryListContainer<DataInstance>(dataUpdate.getAttributes());
        DataSet<DataOnMemory<DataInstance>> unionData = null;

        unionData = dataUpdate.getBatchedDataSet(this.batchSize)
                .union(env.fromCollection(Arrays.asList(emtpyBatch),
                        TypeExtractor.getForClass((Class<DataOnMemory<DataInstance>>) Class.forName("eu.amidst.core.datastream.DataOnMemory"))));

        DataSet<CompoundVector> newparamSet = unionData
                .map(new ParallelVBMap(randomStart, idenitifableModelling))
                .withParameters(config)
                .withBroadcastSet(loop, "VB_PARAMS_" + this.dag.getName())
                .reduce(new ParallelVBReduce());

        // feed new centroids back into next iteration
        DataSet<CompoundVector> finlparamSet = loop.closeWith(newparamSet);

        parameterPrior = finlparamSet.collect().get(0);
        this.svb.updateNaturalParameterPosteriors(parameterPrior);
        this.svb.updateNaturalParameterPrior(parameterPrior);

        if (timeLimit == -1)
            this.globalELBO = ((ConvergenceELBO) loop.getAggregators().getConvergenceCriterion()).getELBO();
        else
            this.globalELBO = ((ConvergenceELBObyTime) loop.getAggregators().getConvergenceCriterion()).getELBO();

        this.svb.applyTransition();

    } catch (Exception ex) {
        throw new RuntimeException(ex.getMessage());
    }

    this.randomStart = false;

    return this.getLogMarginalProbability();
}
Example 9
Source File: IterationsCompilerTest.java From flink with Apache License 2.0 | 4 votes |
@Test
public void testResetPartialSolution() {
    try {
        ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();

        DataSet<Long> width = env.generateSequence(1, 10);
        DataSet<Long> update = env.generateSequence(1, 10);
        DataSet<Long> lastGradient = env.generateSequence(1, 10);

        DataSet<Long> init = width.union(update).union(lastGradient);

        IterativeDataSet<Long> iteration = init.iterate(10);

        width = iteration.filter(new IdFilter<Long>());
        update = iteration.filter(new IdFilter<Long>());
        lastGradient = iteration.filter(new IdFilter<Long>());

        DataSet<Long> gradient = width.map(new IdentityMapper<Long>());

        DataSet<Long> term = gradient.join(lastGradient)
                .where(new IdentityKeyExtractor<Long>())
                .equalTo(new IdentityKeyExtractor<Long>())
                .with(new JoinFunction<Long, Long, Long>() {
                    public Long join(Long first, Long second) {
                        return null;
                    }
                });

        update = update.map(new RichMapFunction<Long, Long>() {
            public Long map(Long value) {
                return null;
            }
        }).withBroadcastSet(term, "some-name");

        DataSet<Long> result = iteration.closeWith(width.union(update).union(lastGradient));

        result.output(new DiscardingOutputFormat<Long>());

        Plan p = env.createProgramPlan();
        OptimizedPlan op = compileNoStats(p);

        new JobGraphGenerator().compileJobGraph(op);
    }
    catch (Exception e) {
        e.printStackTrace();
        fail(e.getMessage());
    }
}
Example 10
Source File: KMeans.java From toolbox with Apache License 2.0 | 4 votes |
public static void main(String[] args) throws Exception {

    if (!parseParameters(args)) {
        return;
    }

    // set up execution environment
    ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();

    // get input data
    DataSet<Point> points = getPointDataSet(env);
    DataSet<Centroid> centroids = getCentroidDataSet(env);

    // set number of bulk iterations for KMeans algorithm
    IterativeDataSet<Centroid> loop = centroids.iterate(numIterations);

    DataSet<Centroid> newCentroids = points
            // compute closest centroid for each point
            .map(new SelectNearestCenter()).withBroadcastSet(loop, "centroids")
            // count and sumNonStateless point coordinates for each centroid
            .map(new CountAppender())
            .groupBy(0).reduce(new CentroidAccumulator())
            // compute new centroids from point counts and coordinate sums
            .map(new CentroidAverager());

    // feed new centroids back into next iteration
    DataSet<Centroid> finalCentroids = loop.closeWith(newCentroids);

    DataSet<Tuple2<Integer, Point>> clusteredPoints = points
            // assign points to final clusters
            .map(new SelectNearestCenter()).withBroadcastSet(finalCentroids, "centroids");

    // emit result
    if (fileOutput) {
        clusteredPoints.writeAsCsv(outputPath, "\n", " ");

        // since file sinks are lazy, we trigger the execution explicitly
        env.execute("KMeans Example");
    } else {
        clusteredPoints.print();
    }
}
Example 11
Source File: GlmUtil.java From Alink with Apache License 2.0 | 4 votes |
/**
 * @param data: input DataSet, format by preProc.
 * @param numFeature: number of features.
 * @param familyLink: family and link function.
 * @param regParam: L2.
 * @param fitIntercept: If true, fit intercept. If false, not fit intercept.
 * @param numIter: number of iter.
 * @param epsilon: epsilon.
 * @return DataSet of WeightedLeastSquaresModel.
 */
public static DataSet<WeightedLeastSquaresModel> train(DataSet<Row> data, int numFeature, FamilyLink familyLink,
                                                       double regParam, boolean fitIntercept, int numIter, double epsilon) {
    String familyName = familyLink.getFamilyName();
    String linkName = familyLink.getLinkName();

    DataSet<WeightedLeastSquaresModel> finalModel = null;
    if (familyName.toLowerCase().equals("gaussian") && linkName.toLowerCase().equals("identity")) {
        finalModel = data.map(new GaussianDataProc(numFeature))
                .mapPartition(new LocalWeightStat(numFeature)).name("init LocalWeightStat")
                .reduce(new GlobalWeightStat()).name("init GlobalWeightStat")
                .mapPartition(new WeightedLeastSquares(fitIntercept, regParam, true, true))
                .setParallelism(1).name("init WeightedLeastSquares");
    } else {
        DataSet<WeightedLeastSquaresModel> initModel = data
                .map(new InitData(familyLink, numFeature))
                .mapPartition(new LocalWeightStat(numFeature)).name("init LocalWeightStat")
                .reduce(new GlobalWeightStat()).name("init GlobalWeightStat")
                .mapPartition(new WeightedLeastSquares(fitIntercept, regParam, true, true))
                .setParallelism(1).name("init WeightedLeastSquares");

        IterativeDataSet<WeightedLeastSquaresModel> loop = initModel.iterate(numIter).name("loop");

        DataSet<WeightedLeastSquaresModel> updateIrlsModel = data
                .map(new UpdateData(familyLink, numFeature + 3)).name("UpdateData")
                .withBroadcastSet(loop, "model")
                .mapPartition(new LocalWeightStat(numFeature)).name("localWeightStat")
                .reduce(new GlobalWeightStat()).name("GlobalWeightStat")
                .mapPartition(new WeightedLeastSquares(fitIntercept, regParam, false, false))
                .setParallelism(1).name("WLS");

        // converge
        DataSet<Tuple2<WeightedLeastSquaresModel, WeightedLeastSquaresModel>> join = loop.map(new ModelAddId())
                .join(updateIrlsModel.map(new ModelAddId()))
                .where(0).equalTo(0).projectFirst(1).projectSecond(1);

        FilterFunction<Tuple2<WeightedLeastSquaresModel, WeightedLeastSquaresModel>> filterCriterion = new IterCriterion(epsilon);
        DataSet<Tuple2<WeightedLeastSquaresModel, WeightedLeastSquaresModel>> criterion = join.filter(filterCriterion);

        finalModel = loop.closeWith(updateIrlsModel, criterion);
    }

    return finalModel;
}
Example 12
Source File: PageRank.java From flink with Apache License 2.0 | 4 votes |
public static void main(String[] args) throws Exception {

    ParameterTool params = ParameterTool.fromArgs(args);

    final int numPages = params.getInt("numPages", PageRankData.getNumberOfPages());
    final int maxIterations = params.getInt("iterations", 10);

    // set up execution environment
    final ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();

    // make the parameters available to the web ui
    env.getConfig().setGlobalJobParameters(params);

    // get input data
    DataSet<Long> pagesInput = getPagesDataSet(env, params);
    DataSet<Tuple2<Long, Long>> linksInput = getLinksDataSet(env, params);

    // assign initial rank to pages
    DataSet<Tuple2<Long, Double>> pagesWithRanks = pagesInput.
            map(new RankAssigner((1.0d / numPages)));

    // build adjacency list from link input
    DataSet<Tuple2<Long, Long[]>> adjacencyListInput =
            linksInput.groupBy(0).reduceGroup(new BuildOutgoingEdgeList());

    // set iterative data set
    IterativeDataSet<Tuple2<Long, Double>> iteration = pagesWithRanks.iterate(maxIterations);

    DataSet<Tuple2<Long, Double>> newRanks = iteration
            // join pages with outgoing edges and distribute rank
            .join(adjacencyListInput).where(0).equalTo(0).flatMap(new JoinVertexWithEdgesMatch())
            // collect and sum ranks
            .groupBy(0).aggregate(SUM, 1)
            // apply dampening factor
            .map(new Dampener(DAMPENING_FACTOR, numPages));

    DataSet<Tuple2<Long, Double>> finalPageRanks = iteration.closeWith(
            newRanks,
            newRanks.join(iteration).where(0).equalTo(0)
            // termination condition
            .filter(new EpsilonFilter()));

    // emit result
    if (params.has("output")) {
        finalPageRanks.writeAsCsv(params.get("output"), "\n", " ");
        // execute program
        env.execute("Basic Page Rank Example");
    } else {
        System.out.println("Printing result to stdout. Use --output to specify output path.");
        finalPageRanks.print();
    }
}
Example 13
Source File: PageRankCompilerTest.java From flink with Apache License 2.0 | 4 votes |
@Test
public void testPageRank() {
    try {
        final ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();

        // get input data
        DataSet<Long> pagesInput = env.fromElements(1L);
        @SuppressWarnings("unchecked")
        DataSet<Tuple2<Long, Long>> linksInput = env.fromElements(new Tuple2<Long, Long>(1L, 2L));

        // assign initial rank to pages
        DataSet<Tuple2<Long, Double>> pagesWithRanks = pagesInput.
                map(new RankAssigner((1.0d / 10)));

        // build adjacency list from link input
        DataSet<Tuple2<Long, Long[]>> adjacencyListInput =
                linksInput.groupBy(0).reduceGroup(new BuildOutgoingEdgeList());

        // set iterative data set
        IterativeDataSet<Tuple2<Long, Double>> iteration = pagesWithRanks.iterate(10);

        Configuration cfg = new Configuration();
        cfg.setString(Optimizer.HINT_LOCAL_STRATEGY, Optimizer.HINT_LOCAL_STRATEGY_HASH_BUILD_SECOND);

        DataSet<Tuple2<Long, Double>> newRanks = iteration
                // join pages with outgoing edges and distribute rank
                .join(adjacencyListInput).where(0).equalTo(0).withParameters(cfg)
                .flatMap(new JoinVertexWithEdgesMatch())
                // collect and sum ranks
                .groupBy(0).aggregate(SUM, 1)
                // apply dampening factor
                .map(new Dampener(0.85, 10));

        DataSet<Tuple2<Long, Double>> finalPageRanks = iteration.closeWith(
                newRanks,
                newRanks.join(iteration).where(0).equalTo(0)
                // termination condition
                .filter(new EpsilonFilter()));

        finalPageRanks.output(new DiscardingOutputFormat<Tuple2<Long, Double>>());

        // get the plan and compile it
        Plan p = env.createProgramPlan();

        OptimizedPlan op = compileNoStats(p);

        SinkPlanNode sinkPlanNode = (SinkPlanNode) op.getDataSinks().iterator().next();
        BulkIterationPlanNode iterPlanNode = (BulkIterationPlanNode) sinkPlanNode.getInput().getSource();

        // check that the partitioning is pushed out of the first loop
        Assert.assertEquals(ShipStrategyType.PARTITION_HASH, iterPlanNode.getInput().getShipStrategy());
        Assert.assertEquals(LocalStrategy.NONE, iterPlanNode.getInput().getLocalStrategy());

        BulkPartialSolutionPlanNode partSolPlanNode = iterPlanNode.getPartialSolutionPlanNode();
        Assert.assertEquals(ShipStrategyType.FORWARD, partSolPlanNode.getOutgoingChannels().get(0).getShipStrategy());
    }
    catch (Exception e) {
        e.printStackTrace();
        fail(e.getMessage());
    }
}
Example 14
Source File: IterationsCompilerTest.java From Flink-CEPplus with Apache License 2.0 | 4 votes |
@Test
public void testResetPartialSolution() {
    try {
        ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();

        DataSet<Long> width = env.generateSequence(1, 10);
        DataSet<Long> update = env.generateSequence(1, 10);
        DataSet<Long> lastGradient = env.generateSequence(1, 10);

        DataSet<Long> init = width.union(update).union(lastGradient);

        IterativeDataSet<Long> iteration = init.iterate(10);

        width = iteration.filter(new IdFilter<Long>());
        update = iteration.filter(new IdFilter<Long>());
        lastGradient = iteration.filter(new IdFilter<Long>());

        DataSet<Long> gradient = width.map(new IdentityMapper<Long>());

        DataSet<Long> term = gradient.join(lastGradient)
                .where(new IdentityKeyExtractor<Long>())
                .equalTo(new IdentityKeyExtractor<Long>())
                .with(new JoinFunction<Long, Long, Long>() {
                    public Long join(Long first, Long second) {
                        return null;
                    }
                });

        update = update.map(new RichMapFunction<Long, Long>() {
            public Long map(Long value) {
                return null;
            }
        }).withBroadcastSet(term, "some-name");

        DataSet<Long> result = iteration.closeWith(width.union(update).union(lastGradient));

        result.output(new DiscardingOutputFormat<Long>());

        Plan p = env.createProgramPlan();
        OptimizedPlan op = compileNoStats(p);

        new JobGraphGenerator().compileJobGraph(op);
    }
    catch (Exception e) {
        e.printStackTrace();
        fail(e.getMessage());
    }
}
Example 15
Source File: LinearRegression.java From flink with Apache License 2.0 | 4 votes |
public static void main(String[] args) throws Exception {

    final ParameterTool params = ParameterTool.fromArgs(args);

    // set up execution environment
    final ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();

    final int iterations = params.getInt("iterations", 10);

    // make parameters available in the web interface
    env.getConfig().setGlobalJobParameters(params);

    // get input x data from elements
    DataSet<Data> data;
    if (params.has("input")) {
        // read data from CSV file
        data = env.readCsvFile(params.get("input"))
                .fieldDelimiter(" ")
                .includeFields(true, true)
                .pojoType(Data.class);
    } else {
        System.out.println("Executing LinearRegression example with default input data set.");
        System.out.println("Use --input to specify file input.");
        data = LinearRegressionData.getDefaultDataDataSet(env);
    }

    // get the parameters from elements
    DataSet<Params> parameters = LinearRegressionData.getDefaultParamsDataSet(env);

    // set number of bulk iterations for SGD linear Regression
    IterativeDataSet<Params> loop = parameters.iterate(iterations);

    DataSet<Params> newParameters = data
            // compute a single step using every sample
            .map(new SubUpdate()).withBroadcastSet(loop, "parameters")
            // sum up all the steps
            .reduce(new UpdateAccumulator())
            // average the steps and update all parameters
            .map(new Update());

    // feed new parameters back into next iteration
    DataSet<Params> result = loop.closeWith(newParameters);

    // emit result
    if (params.has("output")) {
        result.writeAsText(params.get("output"));
        // execute program
        env.execute("Linear Regression example");
    } else {
        System.out.println("Printing result to stdout. Use --output to specify output path.");
        result.print();
    }
}
Example 16
Source File: dVMPv1.java From toolbox with Apache License 2.0 | 4 votes |
public double updateModel(DataFlink<DataInstance> dataUpdate) {

    try {
        final ExecutionEnvironment env = dataUpdate.getDataSet().getExecutionEnvironment();

        // get input data
        CompoundVector parameterPrior = this.svb.getNaturalParameterPrior();

        DataSet<CompoundVector> paramSet = env.fromElements(parameterPrior);

        ConvergenceCriterion convergenceELBO;
        if (timeLimit == -1) {
            convergenceELBO = new ConvergenceELBO(this.globalThreshold, System.nanoTime());
        } else {
            convergenceELBO = new ConvergenceELBObyTime(this.timeLimit, System.nanoTime(),
                    this.idenitifableModelling.getNumberOfEpochs());
            this.setMaximumGlobalIterations(5000);
        }

        // set number of bulk iterations for KMeans algorithm
        IterativeDataSet<CompoundVector> loop = paramSet.iterate(maximumGlobalIterations)
                .registerAggregationConvergenceCriterion("ELBO_" + this.dag.getName(),
                        new DoubleSumAggregator(), convergenceELBO);

        Configuration config = new Configuration();
        config.setString(ParameterLearningAlgorithm.BN_NAME, this.dag.getName());
        config.setBytes(SVB, Serialization.serializeObject(svb));

        // We add an empty batched data set to emit the updated prior.
        DataOnMemory<DataInstance> emtpyBatch = new DataOnMemoryListContainer<DataInstance>(dataUpdate.getAttributes());
        DataSet<DataOnMemory<DataInstance>> unionData = null;

        unionData = dataUpdate.getBatchedDataSet(this.batchSize)
                .union(env.fromCollection(Arrays.asList(emtpyBatch),
                        TypeExtractor.getForClass((Class<DataOnMemory<DataInstance>>) Class.forName("eu.amidst.core.datastream.DataOnMemory"))));

        DataSet<CompoundVector> newparamSet = unionData
                .map(new ParallelVBMap(randomStart, idenitifableModelling))
                .withParameters(config)
                .withBroadcastSet(loop, "VB_PARAMS_" + this.dag.getName())
                .reduce(new ParallelVBReduce());

        // feed new centroids back into next iteration
        DataSet<CompoundVector> finlparamSet = loop.closeWith(newparamSet);

        parameterPrior = finlparamSet.collect().get(0);
        this.svb.updateNaturalParameterPosteriors(parameterPrior);
        this.svb.updateNaturalParameterPrior(parameterPrior);

        if (timeLimit == -1)
            this.globalELBO = ((ConvergenceELBO) loop.getAggregators().getConvergenceCriterion()).getELBO();
        else
            this.globalELBO = ((ConvergenceELBObyTime) loop.getAggregators().getConvergenceCriterion()).getELBO();

        this.svb.applyTransition();

    } catch (Exception ex) {
        System.out.println(ex.getMessage().toString());
        ex.printStackTrace();
        throw new RuntimeException(ex.getMessage());
    }

    this.randomStart = false;

    return this.getLogMarginalProbability();
}
Example 17
Source File: KMeansForTest.java From Flink-CEPplus with Apache License 2.0 | 4 votes |
public static void main(String[] args) throws Exception {
    if (args.length < 3) {
        throw new IllegalArgumentException("Missing parameters");
    }

    final String pointsData = args[0];
    final String centersData = args[1];
    final int numIterations = Integer.parseInt(args[2]);

    ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();
    env.getConfig().disableSysoutLogging();

    // get input data
    DataSet<Point> points = env.fromElements(pointsData.split("\n"))
            .map(new TuplePointConverter());

    DataSet<Centroid> centroids = env.fromElements(centersData.split("\n"))
            .map(new TupleCentroidConverter());

    // set number of bulk iterations for KMeans algorithm
    IterativeDataSet<Centroid> loop = centroids.iterate(numIterations);

    DataSet<Centroid> newCentroids = points
            // compute closest centroid for each point
            .map(new SelectNearestCenter()).withBroadcastSet(loop, "centroids")
            // count and sum point coordinates for each centroid (test pojo return type)
            .map(new CountAppender())
            // !test if key expressions are working!
            .groupBy("field0").reduce(new CentroidAccumulator())
            // compute new centroids from point counts and coordinate sums
            .map(new CentroidAverager());

    // feed new centroids back into next iteration
    DataSet<Centroid> finalCentroids = loop.closeWith(newCentroids);

    // test that custom data type collects are working
    finalCentroids.collect();
}
Example 18
Source File: UnionStaticDynamicIterationITCase.java From flink with Apache License 2.0 | 3 votes |
@Override
protected void testProgram() throws Exception {
    ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();

    DataSet<Long> inputStatic = env.generateSequence(1, 4);
    DataSet<Long> inputIteration = env.generateSequence(1, 4);

    IterativeDataSet<Long> iteration = inputIteration.iterate(3);

    DataSet<Long> result = iteration.closeWith(inputStatic.union(inputStatic).union(iteration.union(iteration)));

    result.output(new LocalCollectionOutputFormat<Long>(this.result));

    env.execute();
}