Java Code Examples for org.deeplearning4j.models.word2vec.VocabWord#markAsLabel()
The following examples show how to use
org.deeplearning4j.models.word2vec.VocabWord#markAsLabel().
You can vote each example up or down, and you can follow the links above each example to the original project or source file.
Related API usage is listed in the sidebar.
Example 1
Source file: WordVectorSerializer.java from deeplearning4j (Apache License 2.0) | 5 votes
/** * This method restores ParagraphVectors model previously saved with writeParagraphVectors() * * @return */ public static ParagraphVectors readParagraphVectors(File file) throws IOException { Word2Vec w2v = readWord2Vec(file); // and "convert" it to ParaVec model + optionally trying to restore labels information ParagraphVectors vectors = new ParagraphVectors.Builder(w2v.getConfiguration()) .vocabCache(w2v.getVocab()) .lookupTable(w2v.getLookupTable()) .resetModel(false) .build(); try (ZipFile zipFile = new ZipFile(file)) { // now we try to restore labels information ZipEntry labels = zipFile.getEntry("labels.txt"); if (labels != null) { InputStream stream = zipFile.getInputStream(labels); try (BufferedReader reader = new BufferedReader(new InputStreamReader(stream, StandardCharsets.UTF_8))) { String line; while ((line = reader.readLine()) != null) { VocabWord word = vectors.getVocab().tokenFor(ReadHelper.decodeB64(line.trim())); if (word != null) { word.markAsLabel(true); } } } } } vectors.extractLabels(); return vectors; }
Example 2
Source file: DocumentSequenceConvertFunction.java from deeplearning4j (Apache License 2.0) | 5 votes
@Override public Sequence<VocabWord> call(LabelledDocument document) throws Exception { Sequence<VocabWord> sequence = new Sequence<>(); // get elements if (document.getReferencedContent() != null && !document.getReferencedContent().isEmpty()) { sequence.addElements(document.getReferencedContent()); } else { if (tokenizerFactory == null) instantiateTokenizerFactory(); List<String> tokens = tokenizerFactory.create(document.getContent()).getTokens(); for (String token : tokens) { if (token == null || token.isEmpty()) continue; VocabWord word = new VocabWord(1.0, token); sequence.addElement(word); } } // get labels for (String label : document.getLabels()) { if (label == null || label.isEmpty()) continue; VocabWord labelElement = new VocabWord(1.0, label); labelElement.markAsLabel(true); sequence.addSequenceLabel(labelElement); } return sequence; }
Example 3
Source file: WordVectorSerializerTest.java from deeplearning4j (Apache License 2.0) | 4 votes
/**
 * Verifies that a ParagraphVectors model survives a write/read round trip through
 * WordVectorSerializer.writeParagraphVectors() and readParagraphVectors().
 *
 * <p>Builds a 100-word vocabulary. Each word gets random points and codes, and about three in
 * ten words are randomly flagged as labels. Random syn0/syn1 matrices are attached. The test
 * then checks that the restored model has the same weights and that each vocabulary index
 * matches in label flag, word, frequency, points and codes.
 */
@Test
@Ignore("AB 2019/06/24 - Failing: Ignored to get to all passing baseline to prevent regressions via CI - see issue #7912")
public void testParaVecSerialization1() throws Exception {
    VectorsConfiguration configuration = new VectorsConfiguration();
    configuration.setIterations(14123);
    configuration.setLayersSize(156);

    INDArray syn0 = Nd4j.rand(100, configuration.getLayersSize());
    INDArray syn1 = Nd4j.rand(100, configuration.getLayersSize());

    AbstractCache<VocabWord> cache = new AbstractCache.Builder<VocabWord>().build();

    for (int i = 0; i < 100; i++) {
        VocabWord word = new VocabWord((float) i, "word_" + i);

        List<Integer> points = new ArrayList<>();
        List<Byte> codes = new ArrayList<>();
        int num = RandomUtils.nextInt(1, 20);
        for (int x = 0; x < num; x++) {
            points.add(RandomUtils.nextInt(1, 100000));
            // Only one random byte is needed per code; generating ten and keeping one is wasteful.
            codes.add(RandomUtils.nextBytes(1)[0]);
        }

        // Values 0..9 below 3 -> roughly 30% of words become labels.
        if (RandomUtils.nextInt(0, 10) < 3) {
            word.markAsLabel(true);
        }

        word.setIndex(i);
        word.setPoints(points);
        word.setCodes(codes);
        cache.addToken(word);
        cache.addWordToIndex(i, word.getLabel());
    }

    InMemoryLookupTable<VocabWord> lookupTable = (InMemoryLookupTable<VocabWord>) new InMemoryLookupTable.Builder<VocabWord>()
            .vectorLength(configuration.getLayersSize()).cache(cache).build();
    lookupTable.setSyn0(syn0);
    lookupTable.setSyn1(syn1);

    ParagraphVectors originalVectors = new ParagraphVectors.Builder(configuration).vocabCache(cache).lookupTable(lookupTable).build();

    File tempFile = File.createTempFile("paravec", "tests");
    tempFile.deleteOnExit();

    WordVectorSerializer.writeParagraphVectors(originalVectors, tempFile);
    ParagraphVectors restoredVectors = WordVectorSerializer.readParagraphVectors(tempFile);

    InMemoryLookupTable<VocabWord> restoredLookupTable = (InMemoryLookupTable<VocabWord>) restoredVectors.getLookupTable();
    AbstractCache<VocabWord> restoredVocab = (AbstractCache<VocabWord>) restoredVectors.getVocab();

    // JUnit's assertEquals takes (expected, actual); the original model is the expectation.
    assertEquals(lookupTable.getSyn0(), restoredLookupTable.getSyn0());
    assertEquals(lookupTable.getSyn1(), restoredLookupTable.getSyn1());

    for (int i = 0; i < cache.numWords(); i++) {
        VocabWord expected = cache.elementAtIndex(i);
        VocabWord actual = restoredVocab.elementAtIndex(i);

        assertEquals(expected.isLabel(), actual.isLabel());
        assertEquals(cache.wordAtIndex(i), restoredVocab.wordAtIndex(i));
        assertEquals(expected.getElementFrequency(), actual.getElementFrequency(), 0.1f);

        // List.equals compares size and then elements pairwise, replacing the manual loops.
        assertEquals(expected.getPoints(), actual.getPoints());
        assertEquals(expected.getCodes(), actual.getCodes());
    }
}