org.deeplearning4j.models.word2vec.wordstore.VocabCache Java Examples
The following examples show how to use org.deeplearning4j.models.word2vec.wordstore.VocabCache. The source project and license are noted above each example.
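Before the examples, here is a minimal, self-contained sketch of the core VocabCache operations (adding tokens, assigning indexes, and querying), using the AbstractCache implementation that appears throughout the examples below; the class name VocabCacheBasics is purely illustrative.

import org.deeplearning4j.models.word2vec.VocabWord;
import org.deeplearning4j.models.word2vec.wordstore.VocabCache;
import org.deeplearning4j.models.word2vec.wordstore.inmemory.AbstractCache;

public class VocabCacheBasics {
    public static void main(String[] args) {
        // AbstractCache is the in-memory VocabCache implementation used by most examples below
        VocabCache<VocabWord> cache = new AbstractCache.Builder<VocabWord>().build();

        // register a token with an element frequency, then assign it an index
        cache.addToken(new VocabWord(1.0, "hello"));
        cache.addWordToIndex(0, "hello");

        System.out.println(cache.containsWord("hello")); // true
        System.out.println(cache.wordAtIndex(0));        // hello
        System.out.println(cache.numWords());            // 1
    }
}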
Example #1
Source File: SparkWord2VecTest.java From deeplearning4j with Apache License 2.0
@Test
@Ignore("AB 2019/05/21 - Failing - Issue #7657")
public void testStringsTokenization1() throws Exception {
    JavaRDD<String> rddSentences = sc.parallelize(sentences);

    SparkWord2Vec word2Vec = new SparkWord2Vec();
    word2Vec.getConfiguration().setTokenizerFactory(DefaultTokenizerFactory.class.getCanonicalName());
    word2Vec.getConfiguration().setElementsLearningAlgorithm("org.deeplearning4j.spark.models.sequencevectors.learning.elements.SparkSkipGram");
    word2Vec.setExporter(new SparkModelExporter<VocabWord>() {
        @Override
        public void export(JavaRDD<ExportContainer<VocabWord>> rdd) {
            rdd.foreach(new TestFn());
        }
    });

    word2Vec.fitSentences(rddSentences);

    VocabCache<ShallowSequenceElement> vocabCache = word2Vec.getShallowVocabCache();
    assertNotEquals(null, vocabCache);
    assertEquals(9, vocabCache.numWords());
    assertEquals(2.0, vocabCache.wordFor(SequenceElement.getLongHash("one")).getElementFrequency(), 1e-5);
    assertEquals(1.0, vocabCache.wordFor(SequenceElement.getLongHash("two")).getElementFrequency(), 1e-5);
}
Example #2
Source File: SparkSequenceVectors.java From deeplearning4j with Apache License 2.0
/**
 * This method builds the shallow vocabulary and the Huffman tree.
 *
 * @param counter
 * @return
 */
protected VocabCache<ShallowSequenceElement> buildShallowVocabCache(Counter<Long> counter) {
    // TODO: need simplified cache here, that will operate on Long instead of string labels
    VocabCache<ShallowSequenceElement> vocabCache = new AbstractCache<>();
    for (Long id : counter.keySet()) {
        ShallowSequenceElement shallowElement = new ShallowSequenceElement(counter.getCount(id), id);
        vocabCache.addToken(shallowElement);
    }

    // building the Huffman tree
    Huffman huffman = new Huffman(vocabCache.vocabWords());
    huffman.build();
    huffman.applyIndexes(vocabCache);

    return vocabCache;
}
Example #3
Source File: VocabHolder.java From deeplearning4j with Apache License 2.0
public INDArray getSyn0Vector(Integer wordIndex, VocabCache<VocabWord> vocabCache) {
    if (!workers.contains(Thread.currentThread().getId()))
        workers.add(Thread.currentThread().getId());

    VocabWord word = vocabCache.elementAtIndex(wordIndex);

    if (!indexSyn0VecMap.containsKey(word)) {
        synchronized (this) {
            if (!indexSyn0VecMap.containsKey(word)) {
                indexSyn0VecMap.put(word, getRandomSyn0Vec(vectorLength.get(), wordIndex));
            }
        }
    }

    return indexSyn0VecMap.get(word);
}
Example #4
Source File: FastText.java From deeplearning4j with Apache License 2.0
@Override
public VocabCache vocab() {
    if (modelVectorsLoaded) {
        vocabCache = word2Vec.vocab();
    } else {
        if (!modelLoaded)
            throw new IllegalStateException("Load model before calling vocab()");

        if (vocabCache == null) {
            vocabCache = new AbstractCache();
        }
        List<String> words = fastTextImpl.getWords();
        for (int i = 0; i < words.size(); ++i) {
            vocabCache.addWordToIndex(i, words.get(i));
            VocabWord word = new VocabWord();
            word.setWord(words.get(i));
            vocabCache.addToken(word);
        }
    }
    return vocabCache;
}
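Note that vocab() above is lazy: it builds the cache from the native model's word list on first call and requires a loaded model. A hedged usage sketch follows; the no-arg constructor, the loadBinaryModel(String) loader, and the model path are assumptions about this wrapper, not confirmed by the example above.

FastText fastText = new FastText();             // assumed no-arg constructor
fastText.loadBinaryModel("/path/to/model.bin"); // assumed loader; path is hypothetical

// with the model loaded, vocab() builds the cache from the native word list
VocabCache vocab = fastText.vocab();
System.out.println("FastText vocabulary size: " + vocab.numWords());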
Example #5
Source File: DM.java From deeplearning4j with Apache License 2.0
@Override
public void configure(@NonNull VocabCache<T> vocabCache, @NonNull WeightLookupTable<T> lookupTable,
                @NonNull VectorsConfiguration configuration) {
    this.vocabCache = vocabCache;
    this.lookupTable = lookupTable;
    this.configuration = configuration;

    cbow.configure(vocabCache, lookupTable, configuration);

    this.window = configuration.getWindow();
    this.useAdaGrad = configuration.isUseAdaGrad();
    this.negative = configuration.getNegative();
    this.sampling = configuration.getSampling();

    this.syn0 = ((InMemoryLookupTable<T>) lookupTable).getSyn0();
    this.syn1 = ((InMemoryLookupTable<T>) lookupTable).getSyn1();
    this.syn1Neg = ((InMemoryLookupTable<T>) lookupTable).getSyn1Neg();
    this.expTable = ((InMemoryLookupTable<T>) lookupTable).getExpTable();
    this.table = ((InMemoryLookupTable<T>) lookupTable).getTable();
}
Example #6
Source File: WordVectorSerializer.java From deeplearning4j with Apache License 2.0
/**
 * This method saves the specified SequenceVectors model to the target OutputStream.
 *
 * @param vectors SequenceVectors model
 * @param factory SequenceElementFactory implementation for your objects
 * @param stream  Target output stream
 * @param <T>
 */
public static <T extends SequenceElement> void writeSequenceVectors(@NonNull SequenceVectors<T> vectors,
                @NonNull SequenceElementFactory<T> factory, @NonNull OutputStream stream) throws IOException {
    WeightLookupTable<T> lookupTable = vectors.getLookupTable();
    VocabCache<T> vocabCache = vectors.getVocab();
    try (PrintWriter writer = new PrintWriter(new BufferedWriter(new OutputStreamWriter(stream, StandardCharsets.UTF_8)))) {
        // the first line holds the VectorsConfiguration
        writer.write(vectors.getConfiguration().toEncodedJson());

        // then the elements are written one by one
        for (int x = 0; x < vocabCache.numWords(); x++) {
            T element = vocabCache.elementAtIndex(x);
            String json = factory.serialize(element);
            double[] vector = lookupTable.vector(element.getLabel()).dup().data().asDouble();

            ElementPair pair = new ElementPair(json, vector);
            writer.println(pair.toEncodedJson());
            writer.flush();
        }
    }
}
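A minimal round-trip sketch for the method above, assuming readSequenceVectors(factory, stream) is the matching deserializer in WordVectorSerializer; the file name and factory choice are illustrative:

try (OutputStream os = new FileOutputStream("vectors.seq")) {
    WordVectorSerializer.writeSequenceVectors(vectors, new VocabWordFactory(), os);
}
try (InputStream is = new FileInputStream("vectors.seq")) {
    // assumed read-side counterpart to writeSequenceVectors
    SequenceVectors<VocabWord> restored = WordVectorSerializer.readSequenceVectors(new VocabWordFactory(), is);
    System.out.println(restored.getVocab().numWords());
}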
Example #7
Source File: TSNEVisualizationExample.java From Java-Deep-Learning-Cookbook with MIT License
public static void main(String[] args) throws IOException {
    Nd4j.setDataType(DataBuffer.Type.DOUBLE);
    List<String> cacheList = new ArrayList<>();
    File file = new File("words.txt");
    String outputFile = "tsne-standard-coords.csv";

    Pair<InMemoryLookupTable, VocabCache> vectors = WordVectorSerializer.loadTxt(file);
    VocabCache cache = vectors.getSecond();
    INDArray weights = vectors.getFirst().getSyn0();

    for (int i = 0; i < cache.numWords(); i++) {
        cacheList.add(cache.wordAtIndex(i));
    }

    BarnesHutTsne tsne = new BarnesHutTsne.Builder()
            .setMaxIter(100)
            .theta(0.5)
            .normalize(false)
            .learningRate(500)
            .useAdaGrad(false)
            .build();

    tsne.fit(weights);
    tsne.saveAsFile(cacheList, outputFile);
}
Example #8
Source File: VocabHolder.java From deeplearning4j with Apache License 2.0
public Iterable<Map.Entry<VocabWord, INDArray>> getSplit(VocabCache<VocabWord> vocabCache) {
    Set<Map.Entry<VocabWord, INDArray>> set = new HashSet<>();
    int cnt = 0;
    for (Map.Entry<VocabWord, INDArray> entry : indexSyn0VecMap.entrySet()) {
        set.add(entry);
        cnt++;
        if (cnt > 10)
            break;
    }

    System.out.println("Returning set: " + set.size());

    return set;
}
Example #9
Source File: CBOW.java From deeplearning4j with Apache License 2.0
@Override
public void configure(@NonNull VocabCache<T> vocabCache, @NonNull WeightLookupTable<T> lookupTable,
                @NonNull VectorsConfiguration configuration) {
    this.vocabCache = vocabCache;
    this.lookupTable = lookupTable;
    this.configuration = configuration;

    this.window = configuration.getWindow();
    this.useAdaGrad = configuration.isUseAdaGrad();
    this.negative = configuration.getNegative();
    this.sampling = configuration.getSampling();

    if (configuration.getNegative() > 0) {
        if (((InMemoryLookupTable<T>) lookupTable).getSyn1Neg() == null) {
            logger.info("Initializing syn1Neg...");
            ((InMemoryLookupTable<T>) lookupTable).setUseHS(configuration.isUseHierarchicSoftmax());
            ((InMemoryLookupTable<T>) lookupTable).setNegative(configuration.getNegative());
            ((InMemoryLookupTable<T>) lookupTable).resetWeights(false);
        }
    }

    this.syn0 = new DeviceLocalNDArray(((InMemoryLookupTable<T>) lookupTable).getSyn0());
    this.syn1 = new DeviceLocalNDArray(((InMemoryLookupTable<T>) lookupTable).getSyn1());
    this.syn1Neg = new DeviceLocalNDArray(((InMemoryLookupTable<T>) lookupTable).getSyn1Neg());
    this.expTable = new DeviceLocalNDArray(Nd4j.create(((InMemoryLookupTable<T>) lookupTable).getExpTable(),
            new long[]{((InMemoryLookupTable<T>) lookupTable).getExpTable().length}, syn0.get().dataType()));
    this.table = new DeviceLocalNDArray(((InMemoryLookupTable<T>) lookupTable).getTable());

    this.variableWindows = configuration.getVariableWindows();
}
Example #10
Source File: NegativeHolder.java From deeplearning4j with Apache License 2.0
public synchronized void initHolder(@NonNull VocabCache<VocabWord> vocabCache, double[] expTable, int layerSize) {
    if (!wasInit.get()) {
        this.vocab = vocabCache;
        this.syn1Neg = Nd4j.zeros(vocabCache.numWords(), layerSize);

        makeTable(Math.max(expTable.length, 100000), 0.75);
        wasInit.set(true);
    }
}
Example #11
Source File: SecondIterationFunction.java From deeplearning4j with Apache License 2.0
public SecondIterationFunction(Broadcast<Map<String, Object>> word2vecVarMapBroadcast,
                Broadcast<double[]> expTableBroadcast, Broadcast<VocabCache<VocabWord>> vocabCacheBroadcast) {
    Map<String, Object> word2vecVarMap = word2vecVarMapBroadcast.getValue();
    this.expTable = expTableBroadcast.getValue();
    this.vectorLength = (int) word2vecVarMap.get("vectorLength");
    this.useAdaGrad = (boolean) word2vecVarMap.get("useAdaGrad");
    this.negative = (double) word2vecVarMap.get("negative");
    this.window = (int) word2vecVarMap.get("window");
    this.alpha = (double) word2vecVarMap.get("alpha");
    this.minAlpha = (double) word2vecVarMap.get("minAlpha");
    this.totalWordCount = (long) word2vecVarMap.get("totalWordCount");
    this.seed = (long) word2vecVarMap.get("seed");
    this.maxExp = (int) word2vecVarMap.get("maxExp");
    this.iterations = (int) word2vecVarMap.get("iterations");
    this.batchSize = (int) word2vecVarMap.get("batchSize");

    this.vocab = vocabCacheBroadcast.getValue();

    if (this.vocab == null)
        throw new RuntimeException("VocabCache is null");
}
Example #12
Source File: InMemoryVocabStoreTests.java From deeplearning4j with Apache License 2.0
@Test
public void testStorePut() {
    VocabCache<VocabWord> cache = new InMemoryLookupCache();
    assertFalse(cache.containsWord("hello"));
    cache.addWordToIndex(0, "hello");
    assertTrue(cache.containsWord("hello"));
    assertEquals(1, cache.numWords());
    assertEquals("hello", cache.wordAtIndex(0));
}
Example #13
Source File: InMemoryLookupCache.java From deeplearning4j with Apache License 2.0
@Override
public void importVocabulary(VocabCache<VocabWord> vocabCache) {
    for (VocabWord word : vocabCache.vocabWords()) {
        if (vocabs.containsKey(word.getLabel())) {
            wordFrequencies.incrementCount(word.getLabel(), (float) word.getElementFrequency());
        } else {
            tokens.put(word.getLabel(), word);
            vocabs.put(word.getLabel(), word);
            wordFrequencies.incrementCount(word.getLabel(), (float) word.getElementFrequency());
        }
        totalWordOccurrences.addAndGet((long) word.getElementFrequency());
    }
}
Example #14
Source File: AbstractCache.java From deeplearning4j with Apache License 2.0
/**
 * This method imports all elements from the VocabCache passed as an argument;
 * elements that already exist in this cache are not duplicated.
 *
 * @param vocabCache
 */
public void importVocabulary(@NonNull VocabCache<T> vocabCache) {
    AtomicBoolean added = new AtomicBoolean(false);
    for (T element : vocabCache.vocabWords()) {
        if (this.addToken(element))
            added.set(true);
    }

    // the documents counter is only updated if at least one new element was added
    if (added.get())
        this.documentsCounter.addAndGet(vocabCache.totalNumberOfDocs());
}
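A short sketch of merging one cache into another with importVocabulary; the words and counts are illustrative:

AbstractCache<VocabWord> first = new AbstractCache.Builder<VocabWord>().build();
first.addToken(new VocabWord(2.0, "red"));

AbstractCache<VocabWord> second = new AbstractCache.Builder<VocabWord>().build();
second.addToken(new VocabWord(1.0, "blue"));

// pull everything from `second` into `first`; existing labels are not duplicated
first.importVocabulary(second);
System.out.println(first.numWords()); // 2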
Example #15
Source File: Huffman.java From deeplearning4j with Apache License 2.0
/**
 * This method updates the VocabCache and all its elements with Huffman indexes.
 * Please note: it should be the same VocabCache that was used for Huffman tree initialization.
 *
 * @param cache VocabCache to be updated
 */
public void applyIndexes(VocabCache<? extends SequenceElement> cache) {
    if (!buildTrigger)
        build();

    for (int a = 0; a < words.size(); a++) {
        if (words.get(a).getLabel() != null) {
            cache.addWordToIndex(a, words.get(a).getLabel());
        } else {
            cache.addWordToIndex(a, words.get(a).getStorageId());
        }

        words.get(a).setIndex(a);
    }
}
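Example #2 above shows this flow inside SparkSequenceVectors; in isolation, building a Huffman tree over a populated cache and writing the indexes back looks roughly like this (words and frequencies are illustrative):

VocabCache<VocabWord> cache = new AbstractCache.Builder<VocabWord>().build();
cache.addToken(new VocabWord(3.0, "the"));
cache.addToken(new VocabWord(1.0, "rare"));

// build the tree over the cache's elements, then assign Huffman indexes back into the same cache
Huffman huffman = new Huffman(cache.vocabWords());
huffman.build();
huffman.applyIndexes(cache);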
Example #16
Source File: WordVectorSerializer.java From deeplearning4j with Apache License 2.0
/**
 * This method saves paragraph vectors to the given output stream.
 *
 * @deprecated Use {@link #writeParagraphVectors(ParagraphVectors, OutputStream)}
 */
@Deprecated
public static void writeWordVectors(ParagraphVectors vectors, OutputStream stream) {
    try (BufferedWriter writer = new BufferedWriter(new OutputStreamWriter(stream, StandardCharsets.UTF_8))) {
        // this method acts similarly to w2v csv serialization, except for an additional tag for labels
        VocabCache<VocabWord> vocabCache = vectors.getVocab();
        for (VocabWord word : vocabCache.vocabWords()) {
            StringBuilder builder = new StringBuilder();

            builder.append(word.isLabel() ? "L" : "E").append(" ");
            builder.append(word.getLabel().replaceAll(" ", WHITESPACE_REPLACEMENT)).append(" ");

            INDArray vector = vectors.getWordVectorMatrix(word.getLabel());
            for (int j = 0; j < vector.length(); j++) {
                builder.append(vector.getDouble(j));
                if (j < vector.length() - 1) {
                    builder.append(" ");
                }
            }

            writer.write(builder.append("\n").toString());
        }
    } catch (Exception e) {
        throw new RuntimeException(e);
    }
}
Example #17
Source File: WordVectorSerializer.java From deeplearning4j with Apache License 2.0
/**
 * This method reads a vocab cache from the provided InputStream.
 * Please note: it reads only the vocab content, so it's suitable mostly for BagOfWords/TF-IDF vectorizers.
 *
 * @param stream
 * @return
 * @throws IOException
 */
public static VocabCache<VocabWord> readVocabCache(@NonNull InputStream stream) throws IOException {
    val vocabCache = new AbstractCache.Builder<VocabWord>().build();
    val factory = new VocabWordFactory();
    boolean firstLine = true;
    long totalWordOcc = -1L;

    try (BufferedReader reader = new BufferedReader(new InputStreamReader(stream, StandardCharsets.UTF_8))) {
        String line;
        while ((line = reader.readLine()) != null) {
            // try to treat the first line as a header with 3 numeric fields
            if (firstLine) {
                firstLine = false;
                val split = line.split("\\ ");

                if (split.length != 3)
                    continue;

                try {
                    vocabCache.setTotalDocCount(Long.valueOf(split[1]));
                    totalWordOcc = Long.valueOf(split[2]);
                    continue;
                } catch (NumberFormatException e) {
                    // no-op
                }
            }

            val word = factory.deserialize(line);

            vocabCache.addToken(word);
            vocabCache.addWordToIndex(word.getIndex(), word.getLabel());
        }
    }

    if (totalWordOcc >= 0)
        vocabCache.setTotalWordOccurences(totalWordOcc);

    return vocabCache;
}
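A hedged round-trip sketch for the reader above; writeVocabCache is assumed to be the serializer's write-side counterpart, and the file name is illustrative:

VocabCache<VocabWord> original = new AbstractCache.Builder<VocabWord>().build();
original.addToken(new VocabWord(1.0, "hello"));

try (OutputStream os = new FileOutputStream("vocab.txt")) {
    WordVectorSerializer.writeVocabCache(original, os); // assumed counterpart to readVocabCache
}
try (InputStream is = new FileInputStream("vocab.txt")) {
    VocabCache<VocabWord> restored = WordVectorSerializer.readVocabCache(is);
    System.out.println(restored.numWords());
}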
Example #18
Source File: WordVectorSerializer.java From deeplearning4j with Apache License 2.0
/**
 * This method loads a Word2Vec model from a csv file
 *
 * @param inputStream input stream
 * @return Word2Vec model
 */
public static Word2Vec readAsCsv(@NonNull InputStream inputStream) {
    VectorsConfiguration configuration = new VectorsConfiguration();

    // let's try to load this file as a csv file
    try {
        log.debug("Trying CSV model restoration...");

        Pair<InMemoryLookupTable, VocabCache> pair = loadTxt(inputStream);
        Word2Vec.Builder builder = new Word2Vec.Builder()
                .lookupTable(pair.getFirst())
                .useAdaGrad(false)
                .vocabCache(pair.getSecond())
                .layerSize(pair.getFirst().layerSize())
                // we don't use hs here, because the model is incomplete
                .useHierarchicSoftmax(false)
                .resetModel(false);

        TokenizerFactory factory = getTokenizerFactory(configuration);
        if (factory != null) {
            builder.tokenizerFactory(factory);
        }

        return builder.build();
    } catch (Exception ex) {
        throw new RuntimeException("Unable to load model in CSV format");
    }
}
Example #19
Source File: FirstIterationFunction.java From deeplearning4j with Apache License 2.0
public FirstIterationFunction(Broadcast<Map<String, Object>> word2vecVarMapBroadcast,
                Broadcast<double[]> expTableBroadcast, Broadcast<VocabCache<VocabWord>> vocabCacheBroadcast) {
    Map<String, Object> word2vecVarMap = word2vecVarMapBroadcast.getValue();
    this.expTable = expTableBroadcast.getValue();
    this.vectorLength = (int) word2vecVarMap.get("vectorLength");
    this.useAdaGrad = (boolean) word2vecVarMap.get("useAdaGrad");
    this.negative = (double) word2vecVarMap.get("negative");
    this.window = (int) word2vecVarMap.get("window");
    this.alpha = (double) word2vecVarMap.get("alpha");
    this.minAlpha = (double) word2vecVarMap.get("minAlpha");
    this.totalWordCount = (long) word2vecVarMap.get("totalWordCount");
    this.seed = (long) word2vecVarMap.get("seed");
    this.maxExp = (int) word2vecVarMap.get("maxExp");
    this.iterations = (int) word2vecVarMap.get("iterations");
    this.batchSize = (int) word2vecVarMap.get("batchSize");
    this.indexSyn0VecMap = new HashMap<>();
    this.pointSyn1VecMap = new HashMap<>();

    this.vocab = vocabCacheBroadcast.getValue();

    if (this.vocab == null)
        throw new RuntimeException("VocabCache is null");

    if (negative > 0) {
        negativeHolder = NegativeHolder.getInstance();
        negativeHolder.initHolder(vocab, expTable, this.vectorLength);
    }
}
Example #20
Source File: TextPipeline.java From deeplearning4j with Apache License 2.0
public Broadcast<VocabCache<VocabWord>> getBroadCastVocabCache() throws IllegalStateException {
    if (vocabCache.numWords() > 0) {
        return vocabCacheBroadcast;
    } else {
        throw new IllegalStateException("IllegalStateException: VocabCache not set at TextPipline.");
    }
}
Example #21
Source File: TextPipeline.java From deeplearning4j with Apache License 2.0
public VocabCache<VocabWord> getVocabCache() throws IllegalStateException {
    if (vocabCache != null && vocabCache.numWords() > 0) {
        return vocabCache;
    } else {
        throw new IllegalStateException("IllegalStateException: VocabCache not set at TextPipline.");
    }
}
Example #22
Source File: TextPipelineTest.java From deeplearning4j with Apache License 2.0
@Test
public void testFilterMinWordAddVocab() throws Exception {
    JavaSparkContext sc = getContext();
    JavaRDD<String> corpusRDD = getCorpusRDD(sc);
    Broadcast<Map<String, Object>> broadcastTokenizerVarMap = sc.broadcast(word2vec.getTokenizerVarMap());

    TextPipeline pipeline = new TextPipeline(corpusRDD, broadcastTokenizerVarMap);
    JavaRDD<List<String>> tokenizedRDD = pipeline.tokenize();
    pipeline.updateAndReturnAccumulatorVal(tokenizedRDD);

    Counter<String> wordFreqCounter = pipeline.getWordFreqAcc().value();
    pipeline.filterMinWordAddVocab(wordFreqCounter);
    VocabCache<VocabWord> vocabCache = pipeline.getVocabCache();

    assertTrue(vocabCache != null);

    VocabWord redVocab = vocabCache.tokenFor("red");
    VocabWord flowerVocab = vocabCache.tokenFor("flowers");
    VocabWord worldVocab = vocabCache.tokenFor("world");
    VocabWord strangeVocab = vocabCache.tokenFor("strange");

    assertEquals(redVocab.getWord(), "red");
    assertEquals(redVocab.getElementFrequency(), 1, 0);
    assertEquals(flowerVocab.getWord(), "flowers");
    assertEquals(flowerVocab.getElementFrequency(), 1, 0);
    assertEquals(worldVocab.getWord(), "world");
    assertEquals(worldVocab.getElementFrequency(), 1, 0);
    assertEquals(strangeVocab.getWord(), "strange");
    assertEquals(strangeVocab.getElementFrequency(), 2, 0);

    sc.stop();
}
Example #23
Source File: TextPipelineTest.java From deeplearning4j with Apache License 2.0
@Test
public void testBuildVocabCache() throws Exception {
    JavaSparkContext sc = getContext();
    JavaRDD<String> corpusRDD = getCorpusRDD(sc);
    Broadcast<Map<String, Object>> broadcastTokenizerVarMap = sc.broadcast(word2vec.getTokenizerVarMap());

    TextPipeline pipeline = new TextPipeline(corpusRDD, broadcastTokenizerVarMap);
    pipeline.buildVocabCache();
    VocabCache<VocabWord> vocabCache = pipeline.getVocabCache();

    assertTrue(vocabCache != null);

    log.info("VocabWords: " + vocabCache.words());
    assertEquals(5, vocabCache.numWords());

    VocabWord redVocab = vocabCache.tokenFor("red");
    VocabWord flowerVocab = vocabCache.tokenFor("flowers");
    VocabWord worldVocab = vocabCache.tokenFor("world");
    VocabWord strangeVocab = vocabCache.tokenFor("strange");

    log.info("Red word: " + redVocab);
    log.info("Flower word: " + flowerVocab);
    log.info("World word: " + worldVocab);
    log.info("Strange word: " + strangeVocab);

    assertEquals(redVocab.getWord(), "red");
    assertEquals(redVocab.getElementFrequency(), 1, 0);
    assertEquals(flowerVocab.getWord(), "flowers");
    assertEquals(flowerVocab.getElementFrequency(), 1, 0);
    assertEquals(worldVocab.getWord(), "world");
    assertEquals(worldVocab.getElementFrequency(), 1, 0);
    assertEquals(strangeVocab.getWord(), "strange");
    assertEquals(strangeVocab.getElementFrequency(), 2, 0);

    sc.stop();
}
Example #24
Source File: TextPipelineTest.java From deeplearning4j with Apache License 2.0
@Test
public void testSyn0AfterFirstIteration() throws Exception {
    JavaSparkContext sc = getContext();
    JavaRDD<String> corpusRDD = getCorpusRDD(sc);
    Broadcast<Map<String, Object>> broadcastTokenizerVarMap = sc.broadcast(word2vec.getTokenizerVarMap());

    TextPipeline pipeline = new TextPipeline(corpusRDD, broadcastTokenizerVarMap);
    pipeline.buildVocabCache();
    pipeline.buildVocabWordListRDD();

    VocabCache<VocabWord> vocabCache = pipeline.getVocabCache();

    Huffman huffman = new Huffman(vocabCache.vocabWords());
    huffman.build();

    // Get total word count and put into word2vec variable map
    Map<String, Object> word2vecVarMap = word2vec.getWord2vecVarMap();
    word2vecVarMap.put("totalWordCount", pipeline.getTotalWordCount());
    double[] expTable = word2vec.getExpTable();

    JavaRDD<AtomicLong> sentenceCountRDD = pipeline.getSentenceCountRDD();
    JavaRDD<List<VocabWord>> vocabWordListRDD = pipeline.getVocabWordListRDD();

    CountCumSum countCumSum = new CountCumSum(sentenceCountRDD);
    JavaRDD<Long> sentenceCountCumSumRDD = countCumSum.buildCumSum();

    JavaPairRDD<List<VocabWord>, Long> vocabWordListSentenceCumSumRDD =
            vocabWordListRDD.zip(sentenceCountCumSumRDD);

    Broadcast<Map<String, Object>> word2vecVarMapBroadcast = sc.broadcast(word2vecVarMap);
    Broadcast<double[]> expTableBroadcast = sc.broadcast(expTable);

    FirstIterationFunction firstIterationFunction =
            new FirstIterationFunction(word2vecVarMapBroadcast, expTableBroadcast, pipeline.getBroadCastVocabCache());
    JavaRDD<Pair<VocabWord, INDArray>> pointSyn0Vec =
            vocabWordListSentenceCumSumRDD.mapPartitions(firstIterationFunction).map(new MapToPairFunction());
}
Example #25
Source File: WordVectorSerializerTest.java From deeplearning4j with Apache License 2.0
@Test
@Ignore("AB 2019/06/24 - Failing: Ignored to get to all passing baseline to prevent regressions via CI - see issue #7912")
public void testIndexPersistence() throws Exception {
    File inputFile = Resources.asFile("big/raw_sentences.txt");
    SentenceIterator iter = UimaSentenceIterator.createWithPath(inputFile.getAbsolutePath());
    // Split on white spaces in the line to get words
    TokenizerFactory t = new DefaultTokenizerFactory();
    t.setTokenPreProcessor(new CommonPreprocessor());

    Word2Vec vec = new Word2Vec.Builder().minWordFrequency(5).iterations(1).epochs(1).layerSize(100)
                    .stopWords(new ArrayList<String>()).useAdaGrad(false).negativeSample(5).seed(42).windowSize(5)
                    .iterate(iter).tokenizerFactory(t).build();

    vec.fit();

    VocabCache orig = vec.getVocab();

    File tempFile = File.createTempFile("temp", "w2v");
    tempFile.deleteOnExit();

    WordVectorSerializer.writeWordVectors(vec, tempFile);

    WordVectors vec2 = WordVectorSerializer.loadTxtVectors(tempFile);

    VocabCache rest = vec2.vocab();

    assertEquals(orig.totalNumberOfDocs(), rest.totalNumberOfDocs());

    for (VocabWord word : vec.getVocab().vocabWords()) {
        INDArray array1 = vec.getWordVectorMatrix(word.getLabel());
        INDArray array2 = vec2.getWordVectorMatrix(word.getLabel());
        assertEquals(array1, array2);
    }
}
Example #26
Source File: PartitionTrainingFunction.java From deeplearning4j with Apache License 2.0
public PartitionTrainingFunction(@NonNull Broadcast<VocabCache<ShallowSequenceElement>> vocabCacheBroadcast,
                @NonNull Broadcast<VectorsConfiguration> vectorsConfigurationBroadcast,
                @NonNull Broadcast<VoidConfiguration> paramServerConfigurationBroadcast) {
    this.vocabCacheBroadcast = vocabCacheBroadcast;
    this.configurationBroadcast = vectorsConfigurationBroadcast;
    this.paramServerConfigurationBroadcast = paramServerConfigurationBroadcast;
}
Example #27
Source File: DistributedFunction.java From deeplearning4j with Apache License 2.0
public DistributedFunction(@NonNull Broadcast<VoidConfiguration> configurationBroadcast,
                @NonNull Broadcast<VectorsConfiguration> vectorsConfigurationBroadcast,
                @NonNull Broadcast<VocabCache<ShallowSequenceElement>> shallowVocabBroadcast) {
    this.configurationBroadcast = configurationBroadcast;
    this.vectorsConfigurationBroadcast = vectorsConfigurationBroadcast;
    this.shallowVocabBroadcast = shallowVocabBroadcast;
}
Example #28
Source File: TrainingFunction.java From deeplearning4j with Apache License 2.0
public TrainingFunction(@NonNull Broadcast<VocabCache<ShallowSequenceElement>> vocabCacheBroadcast,
                @NonNull Broadcast<VectorsConfiguration> vectorsConfigurationBroadcast,
                @NonNull Broadcast<VoidConfiguration> paramServerConfigurationBroadcast) {
    this.vocabCacheBroadcast = vocabCacheBroadcast;
    this.configurationBroadcast = vectorsConfigurationBroadcast;
    this.paramServerConfigurationBroadcast = paramServerConfigurationBroadcast;
}
Example #29
Source File: WordVectorSerializer.java From deeplearning4j with Apache License 2.0
public static Pair<InMemoryLookupTable, VocabCache> loadTxt(@NonNull File file) {
    try (InputStream inputStream = fileStream(file)) {
        return loadTxt(inputStream);
    } catch (IOException readTestException) {
        throw new RuntimeException(readTestException);
    }
}
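A short usage sketch for loadTxt; fromPair is assumed here as the WordVectorSerializer helper that wraps the (lookup table, vocab) pair into a queryable WordVectors instance, and the file name and query word are illustrative:

Pair<InMemoryLookupTable, VocabCache> pair = WordVectorSerializer.loadTxt(new File("words.txt"));
VocabCache vocab = pair.getSecond();
System.out.println("Words in vocab: " + vocab.numWords());

WordVectors vectors = WordVectorSerializer.fromPair(pair); // assumed helper
System.out.println(vectors.wordsNearest("day", 5));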