Java Code Examples for org.deeplearning4j.models.word2vec.wordstore.VocabCache#tokenFor()
The following examples show how to use
org.deeplearning4j.models.word2vec.wordstore.VocabCache#tokenFor().
Follow the link above each example to view the original project or source file.
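Before the test examples, here is a minimal standalone sketch of the lookup. It assumes AbstractCache as the VocabCache implementation and uses an illustrative token "strange"; the class name TokenForSketch and the frequency value are made up for the example.

import org.deeplearning4j.models.word2vec.VocabWord;
import org.deeplearning4j.models.word2vec.wordstore.VocabCache;
import org.deeplearning4j.models.word2vec.wordstore.inmemory.AbstractCache;

public class TokenForSketch {
    public static void main(String[] args) {
        // In-memory vocabulary; AbstractCache is one concrete VocabCache implementation.
        VocabCache<VocabWord> vocabCache = new AbstractCache.Builder<VocabWord>().build();

        // Register a token with an element frequency of 2.0 (illustrative value).
        vocabCache.addToken(new VocabWord(2.0, "strange"));

        // tokenFor(String) looks the token up by its label.
        VocabWord token = vocabCache.tokenFor("strange");
        System.out.println(token.getWord());              // strange
        System.out.println(token.getElementFrequency());  // 2.0

        // A label that was never added comes back as null in this in-memory cache.
        System.out.println(vocabCache.tokenFor("missing") == null); // true
    }
}

The tests below exercise the same lookup after a vocabulary has been built by Spark's TextPipeline or SparkSequenceVectors.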
Example 1
Source File: TextPipelineTest.java From deeplearning4j with Apache License 2.0
@Test
public void testFilterMinWordAddVocab() throws Exception {
    JavaSparkContext sc = getContext();
    JavaRDD<String> corpusRDD = getCorpusRDD(sc);
    Broadcast<Map<String, Object>> broadcastTokenizerVarMap = sc.broadcast(word2vec.getTokenizerVarMap());

    TextPipeline pipeline = new TextPipeline(corpusRDD, broadcastTokenizerVarMap);
    JavaRDD<List<String>> tokenizedRDD = pipeline.tokenize();
    pipeline.updateAndReturnAccumulatorVal(tokenizedRDD);

    Counter<String> wordFreqCounter = pipeline.getWordFreqAcc().value();
    pipeline.filterMinWordAddVocab(wordFreqCounter);

    VocabCache<VocabWord> vocabCache = pipeline.getVocabCache();
    assertTrue(vocabCache != null);

    // Look up the VocabWord token registered for each label
    VocabWord redVocab = vocabCache.tokenFor("red");
    VocabWord flowerVocab = vocabCache.tokenFor("flowers");
    VocabWord worldVocab = vocabCache.tokenFor("world");
    VocabWord strangeVocab = vocabCache.tokenFor("strange");

    assertEquals(redVocab.getWord(), "red");
    assertEquals(redVocab.getElementFrequency(), 1, 0);
    assertEquals(flowerVocab.getWord(), "flowers");
    assertEquals(flowerVocab.getElementFrequency(), 1, 0);
    assertEquals(worldVocab.getWord(), "world");
    assertEquals(worldVocab.getElementFrequency(), 1, 0);
    assertEquals(strangeVocab.getWord(), "strange");
    assertEquals(strangeVocab.getElementFrequency(), 2, 0);

    sc.stop();
}
Example 2
Source File: TextPipelineTest.java From deeplearning4j with Apache License 2.0
@Test
public void testBuildVocabCache() throws Exception {
    JavaSparkContext sc = getContext();
    JavaRDD<String> corpusRDD = getCorpusRDD(sc);
    Broadcast<Map<String, Object>> broadcastTokenizerVarMap = sc.broadcast(word2vec.getTokenizerVarMap());

    TextPipeline pipeline = new TextPipeline(corpusRDD, broadcastTokenizerVarMap);
    pipeline.buildVocabCache();

    VocabCache<VocabWord> vocabCache = pipeline.getVocabCache();
    assertTrue(vocabCache != null);

    log.info("VocabWords: " + vocabCache.words());
    assertEquals(5, vocabCache.numWords());

    // Tokens looked up by label after the vocabulary has been built
    VocabWord redVocab = vocabCache.tokenFor("red");
    VocabWord flowerVocab = vocabCache.tokenFor("flowers");
    VocabWord worldVocab = vocabCache.tokenFor("world");
    VocabWord strangeVocab = vocabCache.tokenFor("strange");

    log.info("Red word: " + redVocab);
    log.info("Flower word: " + flowerVocab);
    log.info("World word: " + worldVocab);
    log.info("Strange word: " + strangeVocab);

    assertEquals(redVocab.getWord(), "red");
    assertEquals(redVocab.getElementFrequency(), 1, 0);
    assertEquals(flowerVocab.getWord(), "flowers");
    assertEquals(flowerVocab.getElementFrequency(), 1, 0);
    assertEquals(worldVocab.getWord(), "world");
    assertEquals(worldVocab.getElementFrequency(), 1, 0);
    assertEquals(strangeVocab.getWord(), "strange");
    assertEquals(strangeVocab.getElementFrequency(), 2, 0);

    sc.stop();
}
Example 3
Source File: TextPipelineTest.java From deeplearning4j with Apache License 2.0
@Test
public void testFirstIteration() throws Exception {
    JavaSparkContext sc = getContext();
    JavaRDD<String> corpusRDD = getCorpusRDD(sc);
    // word2vec.setRemoveStop(false);
    Broadcast<Map<String, Object>> broadcastTokenizerVarMap = sc.broadcast(word2vec.getTokenizerVarMap());

    TextPipeline pipeline = new TextPipeline(corpusRDD, broadcastTokenizerVarMap);
    pipeline.buildVocabCache();
    pipeline.buildVocabWordListRDD();

    VocabCache<VocabWord> vocabCache = pipeline.getVocabCache();

    /*
    Huffman huffman = new Huffman(vocabCache.vocabWords());
    huffman.build();
    huffman.applyIndexes(vocabCache);
    */

    // Look up the element both as a token and as a vocab word
    VocabWord token = vocabCache.tokenFor("strange");
    VocabWord word = vocabCache.wordFor("strange");
    log.info("Strange token: " + token);
    log.info("Strange word: " + word);

    // Get total word count and put into word2vec variable map
    Map<String, Object> word2vecVarMap = word2vec.getWord2vecVarMap();
    word2vecVarMap.put("totalWordCount", pipeline.getTotalWordCount());
    double[] expTable = word2vec.getExpTable();

    JavaRDD<AtomicLong> sentenceCountRDD = pipeline.getSentenceCountRDD();
    JavaRDD<List<VocabWord>> vocabWordListRDD = pipeline.getVocabWordListRDD();

    CountCumSum countCumSum = new CountCumSum(sentenceCountRDD);
    JavaRDD<Long> sentenceCountCumSumRDD = countCumSum.buildCumSum();

    JavaPairRDD<List<VocabWord>, Long> vocabWordListSentenceCumSumRDD =
            vocabWordListRDD.zip(sentenceCountCumSumRDD);

    Broadcast<Map<String, Object>> word2vecVarMapBroadcast = sc.broadcast(word2vecVarMap);
    Broadcast<double[]> expTableBroadcast = sc.broadcast(expTable);

    Iterator<Tuple2<List<VocabWord>, Long>> iterator = vocabWordListSentenceCumSumRDD.collect().iterator();

    FirstIterationFunction firstIterationFunction = new FirstIterationFunction(
            word2vecVarMapBroadcast, expTableBroadcast, pipeline.getBroadCastVocabCache());
    Iterator<Map.Entry<VocabWord, INDArray>> ret = firstIterationFunction.call(iterator);
    assertTrue(ret.hasNext());
}
Example 4
Source File: SparkSequenceVectorsTest.java From deeplearning4j with Apache License 2.0
@Test
public void testFrequenciesCount() throws Exception {
    JavaRDD<Sequence<VocabWord>> sequences = sc.parallelize(sequencesCyclic);

    SparkSequenceVectors<VocabWord> seqVec = new SparkSequenceVectors<>();

    seqVec.getConfiguration().setTokenizerFactory(DefaultTokenizerFactory.class.getCanonicalName());
    seqVec.getConfiguration().setElementsLearningAlgorithm("org.deeplearning4j.spark.models.sequencevectors.learning.elements.SparkSkipGram");
    seqVec.setExporter(new SparkModelExporter<VocabWord>() {
        @Override
        public void export(JavaRDD<ExportContainer<VocabWord>> rdd) {
            rdd.foreach(new SparkWord2VecTest.TestFn());
        }
    });

    seqVec.fitSequences(sequences);

    Counter<Long> counter = seqVec.getCounter();

    // element "0" should have frequency of 20
    assertEquals(20, counter.getCount(0L), 1e-5);

    // elements 1 - 9 should have frequencies of 10
    for (int e = 1; e < sequencesCyclic.get(0).getElements().size() - 1; e++) {
        assertEquals(10, counter.getCount(sequencesCyclic.get(0).getElementByIndex(e).getStorageId()), 1e-5);
    }

    VocabCache<ShallowSequenceElement> shallowVocab = seqVec.getShallowVocabCache();
    assertEquals(10, shallowVocab.numWords());

    // tokenFor(long) looks elements up by their storage id
    ShallowSequenceElement zero = shallowVocab.tokenFor(0L);
    ShallowSequenceElement first = shallowVocab.tokenFor(1L);

    assertNotEquals(null, zero);
    assertEquals(20.0, zero.getElementFrequency(), 1e-5);
    assertEquals(0, zero.getIndex());

    assertEquals(10.0, first.getElementFrequency(), 1e-5);
}