Java Code Examples for org.apache.spark.api.java.JavaSparkContext#broadcast()
The following examples show how to use
org.apache.spark.api.java.JavaSparkContext#broadcast().
You can vote up the ones you like or vote down the ones you don't like,
and go to the original project or source file by following the links above each example. You may check out the related API usage on the sidebar.
Example 1
Source File: TextPipelineTest.java From deeplearning4j with Apache License 2.0 | 6 votes |
@Test public void testWordFreqAccNotIdentifyingStopWords() throws Exception { JavaSparkContext sc = getContext(); // word2vec.setRemoveStop(false); JavaRDD<String> corpusRDD = getCorpusRDD(sc); Broadcast<Map<String, Object>> broadcastTokenizerVarMap = sc.broadcast(word2vecNoStop.getTokenizerVarMap()); TextPipeline pipeline = new TextPipeline(corpusRDD, broadcastTokenizerVarMap); JavaRDD<List<String>> tokenizedRDD = pipeline.tokenize(); pipeline.updateAndReturnAccumulatorVal(tokenizedRDD); Counter<String> wordFreqCounter = pipeline.getWordFreqAcc().value(); assertEquals(wordFreqCounter.getCount("is"), 1, 0); assertEquals(wordFreqCounter.getCount("this"), 1, 0); assertEquals(wordFreqCounter.getCount("are"), 1, 0); assertEquals(wordFreqCounter.getCount("a"), 1, 0); assertEquals(wordFreqCounter.getCount("strange"), 2, 0); assertEquals(wordFreqCounter.getCount("flowers"), 1, 0); assertEquals(wordFreqCounter.getCount("world"), 1, 0); assertEquals(wordFreqCounter.getCount("red"), 1, 0); sc.stop(); }
Example 2
Source File: GraknSparkMemory.java From grakn with GNU Affero General Public License v3.0 | 6 votes |
/**
 * Creates the memory: registers the memory compute keys, creates one named accumulator per
 * registered key and starts with a broadcast of an empty map.
 *
 * @param vertexProgram the vertex program whose memory compute keys are registered; may be
 *                      {@code null}, in which case only the map-reduce keys are registered
 * @param mapReducers   the map-reduce jobs; each contributes a key named after its memory key
 * @param sparkContext  the context used to create the accumulators and the broadcast
 */
public GraknSparkMemory(final VertexProgram<?> vertexProgram,
                        final Set<MapReduce> mapReducers,
                        final JavaSparkContext sparkContext) {
    // Keys declared by the vertex program, if there is one.
    if (vertexProgram != null) {
        vertexProgram.getMemoryComputeKeys()
                .forEach(computeKey -> this.memoryComputeKeys.put(computeKey.getKey(), computeKey));
    }

    // One additional key per map-reduce job, created with Operator.assign and both flags off.
    for (final MapReduce job : mapReducers) {
        this.memoryComputeKeys.put(
                job.getMemoryKey(),
                MemoryComputeKey.of(job.getMemoryKey(), Operator.assign, false, false));
    }

    // One named accumulator per registered key, starting from an empty ObjectWritable.
    for (final MemoryComputeKey registeredKey : this.memoryComputeKeys.values()) {
        this.sparkMemory.put(
                registeredKey.getKey(),
                sparkContext.accumulator(
                        ObjectWritable.empty(),
                        registeredKey.getKey(),
                        new MemoryAccumulator<>(registeredKey)));
    }

    this.broadcast = sparkContext.broadcast(Collections.emptyMap());
}
Example 3
Source File: BQSRPipelineSpark.java From gatk with BSD 3-Clause "New" or "Revised" License | 6 votes |
@Override protected void runTool(final JavaSparkContext ctx) { String referenceFileName = addReferenceFilesForSpark(ctx, referenceArguments.getReferencePath()); List<String> localKnownSitesFilePaths = addVCFsForSpark(ctx, knownVariants); //Should this get the getUnfilteredReads? getReads will merge default and command line filters. //but the code below uses other filters for other parts of the pipeline that do not honor //the commandline. final JavaRDD<GATKRead> initialReads = getReads(); // The initial reads have already had the WellformedReadFilter applied to them, which // is all the filtering that ApplyBQSR wants. BQSR itself wants additional filtering // performed, so we do that here. //NOTE: this filter doesn't honor enabled/disabled commandline filters final ReadFilter bqsrReadFilter = ReadFilter.fromList(BaseRecalibrator.getBQSRSpecificReadFilterList(), getHeaderForReads()); final JavaRDD<GATKRead> filteredReadsForBQSR = initialReads.filter(read -> bqsrReadFilter.test(read)); JavaPairRDD<GATKRead, Iterable<GATKVariant>> readsWithVariants = JoinReadsWithVariants.join(filteredReadsForBQSR, localKnownSitesFilePaths); //note: we use the reference dictionary from the reads themselves. final RecalibrationReport bqsrReport = BaseRecalibratorSparkFn.apply(readsWithVariants, getHeaderForReads(), referenceFileName, bqsrArgs); final Broadcast<RecalibrationReport> reportBroadcast = ctx.broadcast(bqsrReport); final JavaRDD<GATKRead> finalReads = ApplyBQSRSparkFn.apply(initialReads, reportBroadcast, getHeaderForReads(), applyBqsrArgs.toApplyBQSRArgumentCollection(bqsrArgs)); writeReads(ctx, output, finalReads); }
Example 4
Source File: SvDiscoveryInputMetaData.java From gatk with BSD 3-Clause "New" or "Revised" License | 6 votes |
public SvDiscoveryInputMetaData(final JavaSparkContext ctx, final DiscoverVariantsFromContigAlignmentsSparkArgumentCollection discoverStageArgs, final String nonCanonicalChromosomeNamesFile, final String outputPath, final ReadMetadata readMetadata, final List<SVInterval> assembledIntervals, final PairedStrandedIntervalTree<EvidenceTargetLink> evidenceTargetLinks, final Broadcast<SVIntervalTree<VariantContext>> cnvCallsBroadcast, final SAMFileHeader headerForReads, final ReferenceMultiSparkSource reference, final Set<VCFHeaderLine> defaultToolVCFHeaderLines, final Logger toolLogger) { final SAMSequenceDictionary sequenceDictionary = headerForReads.getSequenceDictionary(); final Broadcast<Set<String>> canonicalChromosomesBroadcast = ctx.broadcast(SVUtils.getCanonicalChromosomes(nonCanonicalChromosomeNamesFile, sequenceDictionary)); final String sampleId = SVUtils.getSampleId(headerForReads); this.referenceData = new ReferenceData(canonicalChromosomesBroadcast, ctx.broadcast(reference), ctx.broadcast(sequenceDictionary)); this.sampleSpecificData = new SampleSpecificData(sampleId, cnvCallsBroadcast, assembledIntervals, evidenceTargetLinks, readMetadata, ctx.broadcast(headerForReads)); this.discoverStageArgs = discoverStageArgs; this.outputPath = outputPath; this.defaultToolVCFHeaderLines = defaultToolVCFHeaderLines; this.toolLogger = toolLogger; }
Example 5
Source File: ApplyBQSRSpark.java From gatk with BSD 3-Clause "New" or "Revised" License | 5 votes |
/**
 * Loads the recalibration report from {@code bqsrRecalFile}, broadcasts it, applies it to the
 * input reads and writes the recalibrated reads to {@code output}.
 *
 * @param ctx the Spark context used for the broadcast and for writing
 */
@Override
protected void runTool(JavaSparkContext ctx) {
    JavaRDD<GATKRead> inputReads = getReads();

    // NOTE(review): the stream returned by BucketUtils.openFile is not closed in this method;
    // confirm that the RecalibrationReport constructor closes it.
    Broadcast<RecalibrationReport> broadcastReport =
            ctx.broadcast(new RecalibrationReport(BucketUtils.openFile(bqsrRecalFile)));

    final JavaRDD<GATKRead> outputReads =
            ApplyBQSRSparkFn.apply(inputReads, broadcastReport, getHeaderForReads(), applyBQSRArgs);

    writeReads(ctx, output, outputReads);
}
Example 6
Source File: SparkMemory.java From tinkerpop with Apache License 2.0 | 5 votes |
protected void broadcastMemory(final JavaSparkContext sparkContext) { this.broadcast.destroy(true); // do we need to block? final Map<String, Object> toBroadcast = new HashMap<>(); this.sparkMemory.forEach((key, object) -> { if (!object.value().isEmpty() && this.memoryComputeKeys.get(key).isBroadcast()) toBroadcast.put(key, object.value()); }); this.broadcast = sparkContext.broadcast(toBroadcast); }
Example 7
Source File: PSScorerTest.java From gatk with BSD 3-Clause "New" or "Revised" License | 5 votes |
/**
 * Maps a generated read pair to taxonomic hits and checks the resulting hit, once with the
 * alternate alignments assigned to the XA tag and once with the SA tag.
 *
 * <p>The list parameters are forwarded unchanged to {@code generateReadPair}; {@code truthTax}
 * holds the expected taxonomic ids.
 */
@Test(dataProvider = "mapPairs", groups = "spark")
public void testMapGroupedReadsToTax(final int readLength,
                                     final List<Integer> NM1, final List<Integer> NM2,
                                     final List<Integer> clip1, final List<Integer> clip2,
                                     final List<Integer> insert1, final List<Integer> insert2,
                                     final List<Integer> delete1, final List<Integer> delete2,
                                     final List<String> contig1, final List<String> contig2,
                                     final List<Integer> truthTax) {

    final JavaSparkContext ctx = SparkContextFactory.getTestSparkContext();
    final Broadcast<PSTaxonomyDatabase> taxonomyDatabaseBroadcast = ctx.broadcast(taxonomyDatabase);

    // Same scenario for both tags, XA first and then SA.
    for (final String alignmentTag : new String[] {"XA", "SA"}) {
        final List<Iterable<GATKRead>> readGroups = new ArrayList<>();
        readGroups.add(generateReadPair(readLength, NM1, NM2, clip1, clip2, insert1, insert2,
                delete1, delete2, contig1, contig2, alignmentTag));

        final JavaRDD<Iterable<GATKRead>> groupedReads = ctx.parallelize(readGroups);
        final JavaRDD<Tuple2<Iterable<GATKRead>, PSPathogenAlignmentHit>> scored =
                PSScorer.mapGroupedReadsToTax(groupedReads, MIN_IDENT, IDENT_MARGIN, taxonomyDatabaseBroadcast);
        final PSPathogenAlignmentHit hit = scored.first()._2;

        Assert.assertNotNull(hit);
        Assert.assertEquals(hit.taxIDs.size(), truthTax.size());
        Assert.assertTrue(hit.taxIDs.containsAll(truthTax));
        Assert.assertEquals(hit.numMates, 2);
    }
}
Example 8
Source File: TextPipelineTest.java From deeplearning4j with Apache License 2.0 | 5 votes |
/**
 * Tokenizes the test corpus, updates the accumulators, passes the resulting word-frequency
 * counter to {@code filterMinWordAddVocab} and then checks the word and element frequency of
 * the vocabulary entries for "red", "flowers", "world" and "strange".
 *
 * <p>The Spark context is stopped in a {@code finally} block so that a failing assertion does
 * not leave it running.
 *
 * @throws Exception declared by the test signature; propagated from the pipeline calls
 */
@Test
public void testFilterMinWordAddVocab() throws Exception {
    JavaSparkContext sc = getContext();
    try {
        JavaRDD<String> corpusRDD = getCorpusRDD(sc);
        Broadcast<Map<String, Object>> broadcastTokenizerVarMap =
                sc.broadcast(word2vec.getTokenizerVarMap());

        TextPipeline pipeline = new TextPipeline(corpusRDD, broadcastTokenizerVarMap);
        JavaRDD<List<String>> tokenizedRDD = pipeline.tokenize();
        pipeline.updateAndReturnAccumulatorVal(tokenizedRDD);
        Counter<String> wordFreqCounter = pipeline.getWordFreqAcc().value();

        pipeline.filterMinWordAddVocab(wordFreqCounter);
        VocabCache<VocabWord> vocabCache = pipeline.getVocabCache();
        assertTrue(vocabCache != null);

        VocabWord redVocab = vocabCache.tokenFor("red");
        VocabWord flowerVocab = vocabCache.tokenFor("flowers");
        VocabWord worldVocab = vocabCache.tokenFor("world");
        VocabWord strangeVocab = vocabCache.tokenFor("strange");

        // Argument order is (expected, actual[, delta]) so failure messages read correctly.
        assertEquals("red", redVocab.getWord());
        assertEquals(1, redVocab.getElementFrequency(), 0);
        assertEquals("flowers", flowerVocab.getWord());
        assertEquals(1, flowerVocab.getElementFrequency(), 0);
        assertEquals("world", worldVocab.getWord());
        assertEquals(1, worldVocab.getElementFrequency(), 0);
        assertEquals("strange", strangeVocab.getWord());
        assertEquals(2, strangeVocab.getElementFrequency(), 0);
    } finally {
        // Always release the context, even when an assertion above fails.
        sc.stop();
    }
}
Example 9
Source File: BoxClient.java From render with GNU General Public License v2.0 | 5 votes |
public void run(final SparkConf sparkConf) throws IOException { final JavaSparkContext sparkContext = new JavaSparkContext(sparkConf); LogUtilities.logSparkClusterInfo(sparkContext); setupForRun(); boolean foundBoxesRenderedForPriorRun = false; if (parameters.cleanUpPriorRun) { foundBoxesRenderedForPriorRun = cleanUpPriorRun(sparkContext); } final JavaRDD<BoxData> distributedBoxDataRdd = partitionBoxes(sparkContext, foundBoxesRenderedForPriorRun); final Broadcast<BoxGenerator> broadcastBoxGenerator = sparkContext.broadcast(boxGenerator); if (parameters.validateLabelsOnly) { validateLabelBoxes(sparkContext, distributedBoxDataRdd); } else { for (int level = 0; level <= parameters.box.maxLevel; level++) { renderBoxesForLevel(level, distributedBoxDataRdd, broadcastBoxGenerator); } } if (parameters.box.isOverviewNeeded() && (! parameters.explainPlan) && (! parameters.validateLabelsOnly)) { renderOverviewImages(sparkContext, broadcastBoxGenerator); } LogUtilities.logSparkClusterInfo(sparkContext); // log cluster info again here to add run stats to driver log sparkContext.stop(); }
Example 10
Source File: ReadsSparkSink.java From gatk with BSD 3-Clause "New" or "Revised" License | 5 votes |
/**
 * Sets the broadcast header on every read and writes the reads through
 * {@code HtsjdkReadsRddStorage}.
 *
 * @param ctx                    Spark context used for the header broadcast and the storage
 * @param outputFile             destination passed to the storage writer
 * @param referencePathSpecifier reference path; {@code null} means no reference source path
 * @param reads                  reads to write
 * @param header                 header set on each read and used for the written RDD
 * @param sbiIndexGranularity    granularity passed to the storage
 * @param writeOptions           additional options passed to the storage writer
 * @throws IOException declared by the signature; propagated from the write
 */
private static void writeReads(
        final JavaSparkContext ctx,
        final String outputFile,
        final GATKPath referencePathSpecifier,
        final JavaRDD<SAMRecord> reads,
        final SAMFileHeader header,
        final long sbiIndexGranularity,
        final WriteOption... writeOptions) throws IOException {

    Broadcast<SAMFileHeader> broadcastHeader = ctx.broadcast(header);

    // Each record receives the header from the broadcast before it is written.
    final JavaRDD<SAMRecord> readsCarryingHeader = reads.map(record -> {
        record.setHeaderStrict(broadcastHeader.getValue());
        return record;
    });

    HtsjdkReadsRdd rddToWrite = new HtsjdkReadsRdd(header, readsCarryingHeader);

    HtsjdkReadsRddStorage.makeDefault(ctx)
            .referenceSourcePath(
                    referencePathSpecifier != null ? referencePathSpecifier.getRawInputString() : null)
            .sbiIndexGranularity(sbiIndexGranularity)
            .write(rddToWrite, outputFile, writeOptions);
}
Example 11
Source File: FindBreakpointEvidenceSpark.java From gatk with BSD 3-Clause "New" or "Revised" License | 5 votes |
/**
 * Transform all the reads for a supplied set of template names in each interval into FASTQ
 * records for each interval, and do something with the list of FASTQ records for each interval
 * (like write it to a file).
 *
 * <p>The per-interval counts of template names drive the partitioner. The template-name map
 * is broadcast for the duration of the job and destroyed afterwards.
 */
@VisibleForTesting
static List<AlignedAssemblyOrExcuse> handleAssemblies(
        final JavaSparkContext ctx,
        final HopscotchUniqueMultiMap<String, Integer, QNameAndInterval> qNamesMultiMap,
        final JavaRDD<GATKRead> unfilteredReads,
        final SVReadFilter filter,
        final int nIntervals,
        final boolean includeMappingLocation,
        final LocalAssemblyHandler localAssemblyHandler ) {

    // Tally how many template names fall into each interval.
    final int[] namesPerInterval = new int[nIntervals];
    for ( final QNameAndInterval entry : qNamesMultiMap ) {
        namesPerInterval[entry.getIntervalId()]++;
    }
    final ComplexityPartitioner complexityPartitioner = new ComplexityPartitioner(namesPerInterval);

    final Broadcast<HopscotchUniqueMultiMap<String, Integer, QNameAndInterval>> qNamesBroadcast =
            ctx.broadcast(qNamesMultiMap);

    final List<AlignedAssemblyOrExcuse> dispositions =
            unfilteredReads
                    .mapPartitionsToPair(
                            reads -> new ReadsForQNamesFinder(qNamesBroadcast.value(), nIntervals,
                                    includeMappingLocation, reads, filter).iterator(),
                            false)
                    .combineByKey(
                            single -> single,
                            SVUtils::concatenateLists,
                            SVUtils::concatenateLists,
                            complexityPartitioner,
                            false,
                            null)
                    .map(localAssemblyHandler::apply)
                    .collect();

    SparkUtils.destroyBroadcast(qNamesBroadcast, "QNames multi map");
    BwaMemIndexCache.closeAllDistributedInstances(ctx);

    return dispositions;
}
Example 12
Source File: PSScorerTest.java From gatk with BSD 3-Clause "New" or "Revised" License | 5 votes |
/**
 * Maps a generated unpaired read to taxonomic hits and checks the resulting hit, once with the
 * alternate alignments assigned to the XA tag and once with the SA tag.
 *
 * <p>All per-alignment lists must have the same length, otherwise a {@code TestException} is
 * thrown; {@code truthTax} holds the expected taxonomic ids.
 */
@Test(dataProvider = "mapUnpaired", groups = "spark")
public void testMapGroupedReadsToTaxUnpaired(final int readLength,
                                             final List<Integer> NM,
                                             final List<Integer> clip,
                                             final List<Integer> insert,
                                             final List<Integer> delete,
                                             final List<String> contig,
                                             final List<Integer> truthTax) {

    if (NM.size() != clip.size() || NM.size() != insert.size()
            || NM.size() != delete.size() || NM.size() != contig.size()) {
        throw new TestException("Input lists for read must be of uniform length");
    }

    final JavaSparkContext ctx = SparkContextFactory.getTestSparkContext();
    final Broadcast<PSTaxonomyDatabase> taxonomyDatabaseBroadcast = ctx.broadcast(taxonomyDatabase);

    // Same scenario for both tags, XA first and then SA.
    for (final String alignmentTag : new String[] {"XA", "SA"}) {
        final List<Iterable<GATKRead>> readGroups = new ArrayList<>();
        readGroups.add(generateUnpairedRead(readLength, NM, clip, insert, delete, contig, alignmentTag));

        final JavaRDD<Iterable<GATKRead>> groupedReads = ctx.parallelize(readGroups);
        final JavaRDD<Tuple2<Iterable<GATKRead>, PSPathogenAlignmentHit>> scored =
                PSScorer.mapGroupedReadsToTax(groupedReads, MIN_IDENT, IDENT_MARGIN, taxonomyDatabaseBroadcast);
        final PSPathogenAlignmentHit hit = scored.first()._2;

        Assert.assertNotNull(hit);
        Assert.assertEquals(hit.taxIDs.size(), truthTax.size());
        Assert.assertTrue(hit.taxIDs.containsAll(truthTax));
        Assert.assertEquals(hit.numMates, 1);
    }
}
Example 13
Source File: SparkUtils.java From deeplearning4j with Apache License 2.0 | 5 votes |
public static Broadcast<byte[]> asByteArrayBroadcast(JavaSparkContext sc, INDArray array){ ByteArrayOutputStream baos = new ByteArrayOutputStream(); try { Nd4j.write(array, new DataOutputStream(baos)); } catch (IOException e){ throw new RuntimeException(e); //Should never happen } byte[] paramBytes = baos.toByteArray(); //See docs in EvaluationRunner for why we use byte[] instead of INDArray (thread locality etc) return sc.broadcast(paramBytes); }
Example 14
Source File: BroadcastHadoopConfigHolder.java From deeplearning4j with Apache License 2.0 | 5 votes |
/**
 * Returns the cached broadcast of the Hadoop configuration, creating it when there is none,
 * when the cached one is no longer valid, or when it was created for a context with a
 * different start time.
 *
 * <p>The whole check-then-act sequence runs under the class lock and works on a local
 * reference. Previously the invalidation and the fast-path read happened outside the lock, so
 * a concurrent invalidation could make this method return {@code null}.
 *
 * @param sc the Spark context supplying the Hadoop configuration and the start time
 * @return a non-null broadcast of the serializable Hadoop configuration
 */
public static Broadcast<SerializableHadoopConfig> get(JavaSparkContext sc){
    synchronized (BroadcastHadoopConfigHolder.class){
        Broadcast<SerializableHadoopConfig> current = config;

        // Discard a broadcast that is invalid or belongs to a different Spark context.
        if(current != null && (!current.isValid() || sc.startTime() != sparkContextStartTime) ){
            current = null;
        }

        if(current == null){
            current = sc.broadcast(new SerializableHadoopConfig(sc.hadoopConfiguration()));
            sparkContextStartTime = sc.startTime();
        }

        config = current;
        return current;
    }
}
Example 15
Source File: PathSeqPipelineSpark.java From gatk with BSD 3-Clause "New" or "Revised" License | 4 votes |
@Override protected void runTool(final JavaSparkContext ctx) { filterArgs.doReadFilterArgumentWarnings(getCommandLineParser().getPluginDescriptor(GATKReadFilterPluginDescriptor.class), logger); SAMFileHeader header = PSUtils.checkAndClearHeaderSequences(getHeaderForReads(), filterArgs, logger); //Do not allow use of numReducers if (numReducers > 0) { throw new UserException.BadInput("Use --readsPerPartitionOutput instead of --num-reducers."); } //Filter final Tuple2<JavaRDD<GATKRead>, JavaRDD<GATKRead>> filterResult; final PSFilter filter = new PSFilter(ctx, filterArgs, header); try (final PSFilterLogger filterLogger = filterArgs.filterMetricsFileUri != null ? new PSFilterFileLogger(getMetricsFile(), filterArgs.filterMetricsFileUri) : new PSFilterEmptyLogger()) { final JavaRDD<GATKRead> inputReads = getReads(); filterResult = filter.doFilter(inputReads, filterLogger); } JavaRDD<GATKRead> pairedReads = filterResult._1; JavaRDD<GATKRead> unpairedReads = filterResult._2; //Counting forces an action on the RDDs to guarantee we're done with the Bwa image and kmer filter final long numPairedReads = pairedReads.count(); final long numUnpairedReads = unpairedReads.count(); final long numTotalReads = numPairedReads + numUnpairedReads; //Closes Bwa image, kmer filter, and metrics file if used //Note the host Bwa image before must be unloaded before trying to load the pathogen image filter.close(); //Rebalance partitions using the counts final int numPairedPartitions = 1 + (int) (numPairedReads / readsPerPartition); final int numUnpairedPartitions = 1 + (int) (numUnpairedReads / readsPerPartition); pairedReads = repartitionPairedReads(pairedReads, numPairedPartitions, numPairedReads); unpairedReads = unpairedReads.repartition(numUnpairedPartitions); //Bwa pathogen alignment final PSBwaAlignerSpark aligner = new PSBwaAlignerSpark(ctx, bwaArgs); PSBwaUtils.addReferenceSequencesToHeader(header, bwaArgs.microbeDictionary); final Broadcast<SAMFileHeader> headerBroadcast = 
ctx.broadcast(header); JavaRDD<GATKRead> alignedPairedReads = aligner.doBwaAlignment(pairedReads, true, headerBroadcast); JavaRDD<GATKRead> alignedUnpairedReads = aligner.doBwaAlignment(unpairedReads, false, headerBroadcast); //Cache this expensive result. Note serialization significantly reduces memory consumption. alignedPairedReads.persist(StorageLevel.MEMORY_AND_DISK_SER()); alignedUnpairedReads.persist(StorageLevel.MEMORY_AND_DISK_SER()); //Score pathogens final PSScorer scorer = new PSScorer(scoreArgs); final JavaRDD<GATKRead> readsFinal = scorer.scoreReads(ctx, alignedPairedReads, alignedUnpairedReads, header); //Clean up header header = PSBwaUtils.removeUnmappedHeaderSequences(header, readsFinal, logger); //Log read counts if (scoreArgs.scoreMetricsFileUri != null) { try (final PSScoreLogger scoreLogger = new PSScoreFileLogger(getMetricsFile(), scoreArgs.scoreMetricsFileUri)) { scoreLogger.logReadCounts(readsFinal); } } //Write reads to BAM, if specified if (outputPath != null) { try { //Reduce number of partitions since we previously went to ~5K reads per partition, which // is far too small for sharded output. final int numPartitions = Math.max(1, (int) (numTotalReads / readsPerPartitionOutput)); final JavaRDD<GATKRead> readsFinalRepartitioned = readsFinal.coalesce(numPartitions, false); ReadsSparkSink.writeReads(ctx, outputPath, null, readsFinalRepartitioned, header, shardedOutput ? ReadsWriteFormat.SHARDED : ReadsWriteFormat.SINGLE, numPartitions, shardedPartsDir, true, splittingIndexGranularity); } catch (final IOException e) { throw new UserException.CouldNotCreateOutputFile(outputPath, "writing failed", e); } } aligner.close(); }
Example 16
Source File: ConceptMaps.java From bunsen with Apache License 2.0 | 4 votes |
@Override public Broadcast<BroadcastableMappings> broadcast(Map<String,String> conceptMapUriToVersion) { List<ConceptMap> mapsList = getMaps().collectAsList(); Map<String,ConceptMap> mapsToLoad = mapsList .stream() .filter(conceptMap -> conceptMap.getVersion().equals(conceptMapUriToVersion.get(conceptMap.getUrl()))) .collect(Collectors.toMap(ConceptMap::getUrl, Function.identity())); // Expand the concept maps to load and sort them so dependencies are before // their dependents in the list. List<String> sortedMapsToLoad = sortMapsToLoad(conceptMapUriToVersion.keySet(), mapsToLoad); // Since this is used to map from one system to another, we use only targets // that don't introduce inaccurate meanings. (For instance, we can't map // general condition code to a more specific type, since that is not // representative of the source data.) Dataset<Mapping> mappings = getMappings(conceptMapUriToVersion) .filter("equivalence in ('equivalent', 'equals', 'wider', 'subsumes')"); // Group mappings by their concept map URI Map<String, List<Mapping>> groupedMappings = mappings .collectAsList() .stream() .collect(Collectors.groupingBy(Mapping::getConceptMapUri)); Map<String, BroadcastableConceptMap> broadcastableMaps = new HashMap<>(); for (String conceptMapUri: sortedMapsToLoad) { ConceptMap map = mapsToLoad.get(conceptMapUri); Set<String> children = getMapChildren(map); List<BroadcastableConceptMap> childMaps = children.stream() .map(child -> broadcastableMaps.get(child)) .collect(Collectors.toList()); BroadcastableConceptMap broadcastableConceptMap = new BroadcastableConceptMap(conceptMapUri, groupedMappings.getOrDefault(conceptMapUri, Collections.emptyList()), childMaps); broadcastableMaps.put(conceptMapUri, broadcastableConceptMap); } JavaSparkContext ctx = new JavaSparkContext(getMaps() .sparkSession() .sparkContext()); return ctx.broadcast(new BroadcastableMappings(broadcastableMaps)); }
Example 17
Source File: TextPipelineTest.java From deeplearning4j with Apache License 2.0 | 4 votes |
@Test public void testFirstIteration() throws Exception { JavaSparkContext sc = getContext(); JavaRDD<String> corpusRDD = getCorpusRDD(sc); // word2vec.setRemoveStop(false); Broadcast<Map<String, Object>> broadcastTokenizerVarMap = sc.broadcast(word2vec.getTokenizerVarMap()); TextPipeline pipeline = new TextPipeline(corpusRDD, broadcastTokenizerVarMap); pipeline.buildVocabCache(); pipeline.buildVocabWordListRDD(); VocabCache<VocabWord> vocabCache = pipeline.getVocabCache(); /* Huffman huffman = new Huffman(vocabCache.vocabWords()); huffman.build(); huffman.applyIndexes(vocabCache); */ VocabWord token = vocabCache.tokenFor("strange"); VocabWord word = vocabCache.wordFor("strange"); log.info("Strange token: " + token); log.info("Strange word: " + word); // Get total word count and put into word2vec variable map Map<String, Object> word2vecVarMap = word2vec.getWord2vecVarMap(); word2vecVarMap.put("totalWordCount", pipeline.getTotalWordCount()); double[] expTable = word2vec.getExpTable(); JavaRDD<AtomicLong> sentenceCountRDD = pipeline.getSentenceCountRDD(); JavaRDD<List<VocabWord>> vocabWordListRDD = pipeline.getVocabWordListRDD(); CountCumSum countCumSum = new CountCumSum(sentenceCountRDD); JavaRDD<Long> sentenceCountCumSumRDD = countCumSum.buildCumSum(); JavaPairRDD<List<VocabWord>, Long> vocabWordListSentenceCumSumRDD = vocabWordListRDD.zip(sentenceCountCumSumRDD); Broadcast<Map<String, Object>> word2vecVarMapBroadcast = sc.broadcast(word2vecVarMap); Broadcast<double[]> expTableBroadcast = sc.broadcast(expTable); Iterator<Tuple2<List<VocabWord>, Long>> iterator = vocabWordListSentenceCumSumRDD.collect().iterator(); FirstIterationFunction firstIterationFunction = new FirstIterationFunction( word2vecVarMapBroadcast, expTableBroadcast, pipeline.getBroadCastVocabCache()); Iterator<Map.Entry<VocabWord, INDArray>> ret = firstIterationFunction.call(iterator); assertTrue(ret.hasNext()); }
Example 18
Source File: ValueSetUdfs.java From bunsen with Apache License 2.0 | 3 votes |
/**
 * Pushes an "in_valueset" UDF that uses the given {@link BroadcastableValueSets} for its
 * content. The value sets are broadcast first and the broadcast is handed to the overload
 * that accepts it.
 *
 * @param spark the spark session
 * @param valueSets the valuesets to use in the UDF
 */
public static synchronized void pushUdf(SparkSession spark, BroadcastableValueSets valueSets) {

    JavaSparkContext javaContext = new JavaSparkContext(spark.sparkContext());

    Broadcast<BroadcastableValueSets> broadcastValueSets = javaContext.broadcast(valueSets);

    pushUdf(spark, broadcastValueSets);
}
Example 19
Source File: ExtractOriginalAlignmentRecordsByNameSpark.java From gatk with BSD 3-Clause "New" or "Revised" License | 3 votes |
/**
 * Broadcasts the parsed read names, filters the unfiltered reads with the predicate built
 * from them, writes the matching reads to {@code outputSAM} and logs how many records were
 * found.
 *
 * @param ctx the Spark context used for the broadcast and for writing
 */
@Override
protected void runTool( final JavaSparkContext ctx ) {

    final Broadcast<Set<String>> wantedNames = ctx.broadcast(parseReadNames());

    final Function<GATKRead, Boolean> nameFilter = getGatkReadBooleanFunction(wantedNames, invertFilter);

    // Cached because the RDD is used twice: once for writing and once for counting.
    final JavaRDD<GATKRead> matchingReads = getUnfilteredReads().filter(nameFilter).cache();
    writeReads(ctx, outputSAM, matchingReads, getHeaderForReads(), false);

    final long recordCount = matchingReads.count();
    final int uniqueNameCount = wantedNames.getValue().size();
    logger.info("Found " + recordCount + " alignment records for " +
                uniqueNameCount + " unique read names.");
}
Example 20
Source File: AbstractValueSets.java From bunsen with Apache License 2.0 | 3 votes |
/**
 * Returns a dataset with the values for each element in the map of uri to version.
 *
 * <p>The map is broadcast, and a value is kept only when the map contains its value set URI
 * and the mapped version equals the value's value set version.
 *
 * @param uriToVersion a map of value set URI to the version to load
 * @return a dataset of values for the given URIs and versions.
 */
public Dataset<Value> getValues(Map<String,String> uriToVersion) {

    JavaSparkContext javaContext = new JavaSparkContext(this.spark.sparkContext());

    Broadcast<Map<String,String>> requestedVersions = javaContext.broadcast(uriToVersion);

    return this.values.filter((FilterFunction<Value>) candidate -> {

        String requested = requestedVersions.getValue().get(candidate.getValueSetUri());

        // URIs that were not requested are dropped.
        if (requested == null) {
            return false;
        }

        return requested.equals(candidate.getValueSetVersion());
    });
}