org.nd4j.linalg.dataset.api.MultiDataSet Java Examples
The following examples show how to use
org.nd4j.linalg.dataset.api.MultiDataSet.
You can vote each example up or down, and follow the links above each example to its original project or source file.
Related API usage is listed in the sidebar.
Example #1
Source File: TrainUtil.java From FancyBing with GNU General Public License v3.0 | 6 votes |
/**
 * Evaluates a two-output {@link ComputationGraph}: output 0 is scored against labels 0 with a top-N
 * classification {@link Evaluation}, output 1 against labels 1 with a single-column
 * {@link RegressionEvaluation}.
 *
 * <p>Logs both reports and the time spent inside {@code output(...)}, then resets {@code testData}.
 *
 * @param model     model to evaluate; cast to {@link ComputationGraph} for each minibatch
 * @param outputNum number of classes, passed to {@code createLabels}
 * @param testData  evaluation data; reset after the pass
 * @param topN      N for top-N accuracy
 * @param batchSize minibatch size, used only to estimate the example count in the timing log
 * @return accuracy of the classification output
 */
public static double evaluate(Model model, int outputNum, MultiDataSetIterator testData, int topN, int batchSize) {
    log.info("Evaluate model....");
    Evaluation clsEval = new Evaluation(createLabels(outputNum), topN);
    RegressionEvaluation valueRegEval1 = new RegressionEvaluation(1);
    int count = 0;
    long consume = 0; // nanoseconds spent in output(), excluding data loading
    while (testData.hasNext()) {
        MultiDataSet ds = testData.next();
        long begin = System.nanoTime();
        INDArray[] output = ((ComputationGraph) model).output(false, ds.getFeatures());
        consume += System.nanoTime() - begin;
        clsEval.eval(ds.getLabels(0), output[0]);
        valueRegEval1.eval(ds.getLabels(1), output[1]);
        count++;
    }

    String stats = clsEval.stats();
    // Log the report from the first "===" onwards; if the marker is missing, log the whole report
    // instead of failing with substring(-1).
    int pos = stats.indexOf("===");
    stats = "\n" + (pos >= 0 ? stats.substring(pos) : stats);
    log.info(stats);
    log.info(valueRegEval1.stats());
    testData.reset();

    // long arithmetic avoids int overflow; the count assumes every minibatch was full
    long examples = (long) count * batchSize;
    // microseconds per example (consume is in ns); 0 when no data was evaluated instead of NaN/Infinity
    float average = examples > 0 ? (float) consume / examples / 1000 : 0f;
    log.info("Evaluate time: " + consume + " count: " + examples + " average: " + average);
    return clsEval.accuracy();
}
Example #2
Source File: BaseSparkEarlyStoppingTrainer.java From deeplearning4j with Apache License 2.0 | 6 votes |
/**
 * Creates a Spark early-stopping trainer.
 *
 * @param sc         Spark context
 * @param esConfig   early stopping configuration; must define at least one epoch or iteration
 *                   termination condition
 * @param net        network to train
 * @param train      training data as DataSets (stored as given)
 * @param trainMulti training data as MultiDataSets (stored as given)
 * @param listener   early stopping listener
 * @throws IllegalArgumentException if both the epoch and the iteration termination conditions are
 *                                  null or empty
 */
protected BaseSparkEarlyStoppingTrainer(JavaSparkContext sc, EarlyStoppingConfiguration<T> esConfig, T net,
                JavaRDD<DataSet> train, JavaRDD<MultiDataSet> trainMulti, EarlyStoppingListener<T> listener) {
    java.util.Collection<?> epochConditions = esConfig.getEpochTerminationConditions();
    if (epochConditions == null || epochConditions.isEmpty()) {
        // Iteration conditions only matter when there are no epoch conditions
        java.util.Collection<?> iterationConditions = esConfig.getIterationTerminationConditions();
        if (iterationConditions == null || iterationConditions.isEmpty()) {
            throw new IllegalArgumentException(
                    "Cannot conduct early stopping without a termination condition (both Iteration "
                                    + "and Epoch termination conditions are null/empty)");
        }
    }

    this.sc = sc;
    this.esConfig = esConfig;
    this.net = net;
    this.train = train;
    this.trainMulti = trainMulti;
    this.listener = listener;
}
Example #3
Source File: TestBertIterator.java From deeplearning4j with Apache License 2.0 | 6 votes |
/**
 * Checks that {@code featurizeSentencePairs} produces the same feature and mask arrays as the
 * iterator's own {@code next()} for the helper's sentence pairs.
 */
@Test
public void testSentencePairFeaturizer() throws IOException {
    int minibatchSize = 2;
    TestSentencePairsHelper testPairHelper = new TestSentencePairsHelper(minibatchSize);
    BertIterator b = BertIterator.builder()
            .tokenizer(testPairHelper.getTokenizer())
            .minibatchSize(minibatchSize)
            .padMinibatches(true)
            .featureArrays(BertIterator.FeatureArrays.INDICES_MASK_SEGMENTID)
            .vocabMap(testPairHelper.getTokenizer().getVocab())
            .task(BertIterator.Task.SEQ_CLASSIFICATION)
            .lengthHandling(BertIterator.LengthHandling.FIXED_LENGTH, 128)
            .sentencePairProvider(testPairHelper.getPairSentenceProvider())
            .prependToken("[CLS]")
            .appendToken("[SEP]")
            .build();

    MultiDataSet mds = b.next();
    INDArray[] featuresArr = mds.getFeatures();
    INDArray[] featuresMaskArr = mds.getFeaturesMaskArrays();

    Pair<INDArray[], INDArray[]> p = b.featurizeSentencePairs(testPairHelper.getSentencePairs());
    // JUnit's assertEquals takes (expected, actual); keep that order so failure messages are accurate
    assertEquals(2, p.getFirst().length);
    assertEquals(featuresArr[0], p.getFirst()[0]);
    assertEquals(featuresArr[1], p.getFirst()[1]);
    assertEquals(featuresMaskArr[0], p.getSecond()[0]);
}
Example #4
Source File: SparkAMDSI.java From deeplearning4j with Apache License 2.0 | 6 votes |
/**
 * Creates the iterator and immediately starts a daemon {@code SparkPrefetchThread} over
 * {@code iterator}, writing into {@code queue}.
 *
 * @param iterator     source iterator; reset first when it reports {@code resetSupported()}
 * @param queueSize    requested prefetch size; values below 2 are raised to 2
 * @param queue        buffer handed to the prefetch thread
 * @param useWorkspace stored as {@code useWorkspaces}
 * @param callback     stored as {@code callback}
 * @param deviceId     stored as {@code deviceId}
 */
public SparkAMDSI(MultiDataSetIterator iterator, int queueSize, BlockingQueue<MultiDataSet> queue,
                boolean useWorkspace, DataSetCallback callback, Integer deviceId) {
    this();

    this.backedIterator = iterator;
    this.buffer = queue;
    this.prefetchSize = Math.max(2, queueSize); // minimum prefetch size of 2
    this.callback = callback;
    this.useWorkspaces = useWorkspace;
    this.deviceId = deviceId;
    this.workspaceId = "SAMDSI_ITER-" + java.util.UUID.randomUUID().toString();

    if (iterator.resetSupported()) {
        this.backedIterator.reset();
    }

    // NOTE(review): the prefetch thread is bound to the constructing thread's device, not to the
    // deviceId argument stored above — confirm this is intended.
    int threadDevice = Nd4j.getAffinityManager().getDeviceForCurrentThread();
    this.thread = new SparkPrefetchThread(buffer, iterator, terminator, threadDevice);
    context = TaskContext.get();
    thread.setDaemon(true);
    thread.start();
}
Example #5
Source File: BatchAndExportMultiDataSetsFunction.java From deeplearning4j with Apache License 2.0 | 6 votes |
/**
 * Saves {@code dataSet} to {@code exportBaseDirectory} via the Hadoop {@link FileSystem} for the target URI.
 *
 * <p>A "/" separator is inserted unless the base directory already ends in "/" or "\".
 *
 * @param dataSet      data to save (via {@code MultiDataSet.save})
 * @param partitionIdx partition index, embedded in the file name
 * @param outputCount  per-partition output counter, embedded in the file name
 * @return the URI of the written file, as a string
 */
private String export(MultiDataSet dataSet, int partitionIdx, int outputCount) throws Exception {
    // NOTE(review): no separator between partitionIdx and jvmuid — confirm jvmuid supplies its own
    String fileName = "mds_" + partitionIdx + jvmuid + "_" + outputCount + ".bin";
    boolean endsWithSeparator = exportBaseDirectory.endsWith("/") || exportBaseDirectory.endsWith("\\");
    URI target = new URI(exportBaseDirectory + (endsWithSeparator ? "" : "/") + fileName);

    Configuration hadoopConf;
    if (conf == null) {
        hadoopConf = DefaultHadoopConfig.get();
    } else {
        hadoopConf = conf.getValue().getConfiguration();
    }

    FileSystem fs = FileSystem.get(target, hadoopConf);
    try (FSDataOutputStream out = fs.create(new Path(target))) {
        dataSet.save(out);
    }
    return target.toString();
}
Example #6
Source File: TestMultiDataSetIterator.java From deeplearning4j with Apache License 2.0 | 6 votes |
/**
 * Merges up to {@code num} stored datasets, starting at {@code curr}, into a single MultiDataSet and
 * advances {@code curr} past them. The pre-processor, if set, is applied to the merged result.
 */
@Override
public MultiDataSet next(int num) {
    int stop = Math.min(curr + num, list.size());

    List<MultiDataSet> batch = new ArrayList<>();
    while (curr < stop) {
        batch.add(list.get(curr));
        curr++;
    }

    MultiDataSet merged = org.nd4j.linalg.dataset.MultiDataSet.merge(batch);
    if (preProcessor != null) {
        preProcessor.preProcess(merged);
    }
    return merged;
}
Example #7
Source File: IEvaluateMDSFlatMapFunction.java From deeplearning4j with Apache License 2.0 | 6 votes |
/**
 * Runs the configured evaluations over one partition's MultiDataSets.
 *
 * @param dataSetIterator the partition's data
 * @return a single-element iterator holding the evaluation results, or an empty iterator when the
 *         partition is empty or the runner returned {@code null}
 * @throws Exception if waiting on the evaluation future fails
 */
@Override
public Iterator<T[]> call(Iterator<MultiDataSet> dataSetIterator) throws Exception {
    if (!dataSetIterator.hasNext()) {
        return Collections.emptyIterator();
    }

    Future<IEvaluation[]> f = EvaluationRunner.getInstance().execute(
            evaluations, evalNumWorkers, evalBatchSize, null, dataSetIterator, true, json, params);

    IEvaluation[] result = f.get();
    if (result == null) {
        return Collections.emptyIterator();
    }

    // The runner returns IEvaluation[]; callers treat it as T[] (unchecked by generics erasure)
    @SuppressWarnings("unchecked")
    T[] typed = (T[]) result;
    return Collections.singletonList(typed).iterator();
}
Example #8
Source File: AbstractMultiDataSetNormalizer.java From deeplearning4j with Apache License 2.0 | 6 votes |
/**
 * Fits statistics over every {@link MultiDataSet} produced by {@code iterator}.
 *
 * <p>The iterator is reset before fitting and is left exhausted afterwards. Feature statistics are
 * always replaced; label statistics are rebuilt only when {@link #isFitLabel()} is true.
 *
 * @param iterator source of the data to compute statistics over
 */
public void fit(@NonNull MultiDataSetIterator iterator) {
    List<S.Builder> featureBuilders = new ArrayList<>();
    List<S.Builder> labelBuilders = new ArrayList<>();

    for (iterator.reset(); iterator.hasNext(); ) {
        fitPartial(iterator.next(), featureBuilders, labelBuilders);
    }

    featureStats = buildList(featureBuilders);
    if (isFitLabel()) {
        labelStats = buildList(labelBuilders);
    }
}
Example #9
Source File: AsyncMultiDataSetIterator.java From deeplearning4j with Apache License 2.0 | 6 votes |
public AsyncMultiDataSetIterator(MultiDataSetIterator iterator, int queueSize, BlockingQueue<MultiDataSet> queue, boolean useWorkspace, DataSetCallback callback, Integer deviceId) { if (queueSize < 2) queueSize = 2; this.callback = callback; this.buffer = queue; this.backedIterator = iterator; this.useWorkspaces = useWorkspace; this.prefetchSize = queueSize; this.workspaceId = "AMDSI_ITER-" + java.util.UUID.randomUUID().toString(); this.deviceId = deviceId; if (iterator.resetSupported() && !iterator.hasNext()) this.backedIterator.reset(); this.thread = new AsyncPrefetchThread(buffer, iterator, terminator, deviceId); thread.setDaemon(true); thread.start(); }
Example #10
Source File: TestComputationGraphNetwork.java From deeplearning4j with Apache License 2.0 | 5 votes |
@Test public void testCompGraphDropoutOutputLayers(){ //https://github.com/deeplearning4j/deeplearning4j/issues/6326 ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder() .dropOut(0.8) .graphBuilder() .addInputs("in1", "in2") .addVertex("merge", new MergeVertex(), "in1", "in2") .addLayer("lstm", new Bidirectional(Bidirectional.Mode.CONCAT, new LSTM.Builder() .nIn(10).nOut(5) .activation(Activation.TANH) .dropOut(new GaussianNoise(0.05)) .build()) ,"merge") .addLayer("out1", new RnnOutputLayer.Builder().activation(Activation.SOFTMAX) .lossFunction(LossFunctions.LossFunction.MCXENT).nIn(10) .nOut(6).build(), "lstm") .addLayer("out2", new RnnOutputLayer.Builder().activation(Activation.SOFTMAX) .lossFunction(LossFunctions.LossFunction.MCXENT).nIn(10) .nOut(4).build(), "lstm") .setOutputs("out1", "out2").build(); ComputationGraph net = new ComputationGraph(conf); net.init(); INDArray[] features = new INDArray[]{Nd4j.create(1, 5, 5), Nd4j.create(1, 5, 5)}; INDArray[] labels = new INDArray[]{Nd4j.create(1, 6, 5), Nd4j.create(1, 4, 5)}; MultiDataSet mds = new org.nd4j.linalg.dataset.MultiDataSet(features, labels); net.fit(mds); }
Example #11
Source File: ImageMultiPreProcessingScaler.java From nd4j with Apache License 2.0 | 5 votes |
@Override public void preProcess(MultiDataSet multiDataSet) { for( int i=0; i<featureIndices.length; i++ ){ INDArray f = multiDataSet.getFeatures(featureIndices[i]); f.divi(this.maxPixelVal); //Scaled to 0->1 if (this.maxRange - this.minRange != 1) f.muli(this.maxRange - this.minRange); //Scaled to minRange -> maxRange if (this.minRange != 0) f.addi(this.minRange); //Offset by minRange } }
Example #12
Source File: AbstractMultiDataSetNormalizer.java From deeplearning4j with Apache License 2.0 | 5 votes |
/**
 * Fits statistics from this single {@link MultiDataSet} only.
 *
 * <p>Feature statistics are always replaced; label statistics are rebuilt only when
 * {@link #isFitLabel()} is true.
 *
 * @param dataSet the data to compute statistics from
 */
public void fit(@NonNull MultiDataSet dataSet) {
    List<S.Builder> featureBuilders = new ArrayList<>();
    List<S.Builder> labelBuilders = new ArrayList<>();
    fitPartial(dataSet, featureBuilders, labelBuilders);

    featureStats = buildList(featureBuilders);
    if (!isFitLabel()) {
        return;
    }
    labelStats = buildList(labelBuilders);
}
Example #13
Source File: EvalTest.java From deeplearning4j with Apache License 2.0 | 5 votes |
@Test public void testMultiOutputEvalCG(){ //Simple sanity check on evaluation ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder() .graphBuilder() .addInputs("in") .layer("0", new EmbeddingSequenceLayer.Builder().nIn(10).nOut(10).build(), "in") .layer("1", new LSTM.Builder().nIn(10).nOut(10).build(), "0") .layer("2", new LSTM.Builder().nIn(10).nOut(10).build(), "0") .layer("out1", new RnnOutputLayer.Builder().nIn(10).nOut(10).activation(Activation.SOFTMAX).build(), "1") .layer("out2", new RnnOutputLayer.Builder().nIn(10).nOut(20).activation(Activation.SOFTMAX).build(), "2") .setOutputs("out1", "out2") .build(); ComputationGraph cg = new ComputationGraph(conf); cg.init(); org.nd4j.linalg.dataset.MultiDataSet mds = new org.nd4j.linalg.dataset.MultiDataSet( new INDArray[]{Nd4j.create(10, 1, 10)}, new INDArray[]{Nd4j.create(10, 10, 10), Nd4j.create(10, 20, 10)}); Map<Integer,org.nd4j.evaluation.IEvaluation[]> m = new HashMap<>(); m.put(0, new org.nd4j.evaluation.IEvaluation[]{new org.nd4j.evaluation.classification.Evaluation()}); m.put(1, new org.nd4j.evaluation.IEvaluation[]{new org.nd4j.evaluation.classification.Evaluation()}); cg.evaluate(new SingletonMultiDataSetIterator(mds), m); }
Example #14
Source File: EarlyTerminationMultiDataSetIterator.java From deeplearning4j with Apache License 2.0 | 5 votes |
/**
 * Returns the next minibatch from the underlying iterator until {@code terminationPoint} minibatches
 * have been returned.
 *
 * @return the next {@link MultiDataSet}
 * @throws java.util.NoSuchElementException once the termination point has been reached; this is the
 *         exception the {@link java.util.Iterator#next()} contract specifies, and it is still a
 *         {@link RuntimeException} for existing callers
 */
@Override
public MultiDataSet next() {
    if (minibatchCount >= terminationPoint) {
        throw new java.util.NoSuchElementException("Calls to next have exceeded the allotted number of minibatches.");
    }
    minibatchCount++;
    return underlyingIterator.next();
}
Example #15
Source File: VasttextDataIterator.java From scava with Eclipse Public License 2.0 | 5 votes |
@Override public MultiDataSet next(int num) { if (!hasNext()) throw new NoSuchElementException("No next elements"); // First: load the next values from the RR / SeqRRs Map<String, List<List<Writable>>> nextRRVals = new HashMap<>(); List<RecordMetaDataComposableMap> nextMetas = (collectMetaData ? new ArrayList<RecordMetaDataComposableMap>() : null); for (Map.Entry<String, RecordReader> entry : recordReaders.entrySet()) { RecordReader rr = entry.getValue(); // Standard case List<List<Writable>> writables = new ArrayList<>(Math.min(num, 100000)); // Min op: in case user puts // batch size >> amount of // data for (int i = 0; i < num && rr.hasNext(); i++) { List<Writable> record; if (collectMetaData) { Record r = rr.nextRecord(); record = r.getRecord(); if (nextMetas.size() <= i) { nextMetas.add(new RecordMetaDataComposableMap(new HashMap<String, RecordMetaData>())); } RecordMetaDataComposableMap map = nextMetas.get(i); map.getMeta().put(entry.getKey(), r.getMetaData()); } else { record = rr.next(); } writables.add(record); } nextRRVals.put(entry.getKey(), writables); } return nextMultiDataSet(nextRRVals, nextMetas); }
Example #16
Source File: IteratorUtils.java From deeplearning4j with Apache License 2.0 | 5 votes |
/**
 * Applies a single-reader {@link RecordReaderMultiDataSetIterator} to a {@code JavaRDD<List<Writable>>}.
 * <b>NOTE</b>: the RecordReaderMultiDataSetIterator <i>must</i> use {@link SparkSourceDummyReader} in
 * place of "real" RecordReader instances.
 *
 * @param rdd      RDD with writables, one record per element
 * @param iterator RecordReaderMultiDataSetIterator configured with {@link SparkSourceDummyReader} readers
 * @return RDD of MultiDataSets produced by {@code mapRRMDSIRecords}
 */
public static JavaRDD<MultiDataSet> mapRRMDSI(JavaRDD<List<Writable>> rdd, RecordReaderMultiDataSetIterator iterator) {
    checkIterator(iterator, 1, 0);

    // Wrap each record in a singleton list; the second DataVecRecords argument is left null
    // (presumably the sequence-record slot — confirm against DataVecRecords)
    JavaRDD<DataVecRecords> wrapped = rdd.map(new Function<List<Writable>, DataVecRecords>() {
        @Override
        public DataVecRecords call(List<Writable> record) throws Exception {
            return new DataVecRecords(Collections.singletonList(record), null);
        }
    });
    return mapRRMDSIRecords(wrapped, iterator);
}
Example #17
Source File: VasttextDataIterator.java From scava with Eclipse Public License 2.0 | 5 votes |
/** * Load a multiple sequence examples to a DataSet, using the provided * RecordMetaData instances. * * @param list * List of RecordMetaData instances to load from. Should have been * produced by the record reader provided to the * SequenceRecordReaderDataSetIterator constructor * @return DataSet with the specified examples * @throws IOException * If an error occurs during loading of the data */ public MultiDataSet loadFromMetaData(List<RecordMetaData> list) throws IOException { // First: load the next values from the RR / SeqRRs Map<String, List<List<Writable>>> nextRRVals = new HashMap<>(); List<RecordMetaDataComposableMap> nextMetas = (collectMetaData ? new ArrayList<RecordMetaDataComposableMap>() : null); for (Map.Entry<String, RecordReader> entry : recordReaders.entrySet()) { RecordReader rr = entry.getValue(); List<RecordMetaData> thisRRMeta = new ArrayList<>(); for (RecordMetaData m : list) { RecordMetaDataComposableMap m2 = (RecordMetaDataComposableMap) m; thisRRMeta.add(m2.getMeta().get(entry.getKey())); } List<Record> fromMeta = rr.loadFromMetaData(thisRRMeta); List<List<Writable>> writables = new ArrayList<>(list.size()); for (Record r : fromMeta) { writables.add(r.getRecord()); } nextRRVals.put(entry.getKey(), writables); } return nextMultiDataSet(nextRRVals, nextMetas); }
Example #18
Source File: MultiNormalizerHybrid.java From deeplearning4j with Apache License 2.0 | 5 votes |
/**
 * Computes input and output statistics from this single {@link MultiDataSet}, replacing any
 * previously fitted {@code inputStats} and {@code outputStats}.
 *
 * @param dataSet the data to compute statistics from
 */
@Override
public void fit(@NonNull MultiDataSet dataSet) {
    Map<Integer, NormalizerStats.Builder> inputBuilders = new HashMap<>();
    Map<Integer, NormalizerStats.Builder> outputBuilders = new HashMap<>();

    fitPartial(dataSet, inputBuilders, outputBuilders);

    inputStats = buildAllStats(inputBuilders);
    outputStats = buildAllStats(outputBuilders);
}
Example #19
Source File: BaseTrainingMaster.java From deeplearning4j with Apache License 2.0 | 5 votes |
protected JavaRDD<String> exportIfRequiredMDS(JavaSparkContext sc, JavaRDD<MultiDataSet> trainingData) { ExportSupport.assertExportSupported(sc); if (collectTrainingStats) stats.logExportStart(); //Two possibilities here: // 1. We've seen this RDD before (i.e., multiple epochs training case) // 2. We have not seen this RDD before // (a) And we haven't got any stored data -> simply export // (b) And we previously exported some data from a different RDD -> delete the last data int currentRDDUid = trainingData.id(); //Id is a "A unique ID for this RDD (within its SparkContext)." String baseDir; if (lastExportedRDDId == Integer.MIN_VALUE) { //Haven't seen a RDD<DataSet> yet in this training master -> export data baseDir = exportMDS(trainingData); } else { if (lastExportedRDDId == currentRDDUid) { //Use the already-exported data again for another epoch baseDir = getBaseDirForRDD(trainingData); } else { //The new RDD is different to the last one // Clean up the data for the last one, and export deleteTempDir(sc, lastRDDExportPath); baseDir = exportMDS(trainingData); } } if (collectTrainingStats) stats.logExportEnd(); return sc.textFile(baseDir + "paths/"); }
Example #20
Source File: DummyBlockMultiDataSetIterator.java From deeplearning4j with Apache License 2.0 | 5 votes |
/**
 * Pulls up to {@code maxDatasets} datasets from the wrapped iterator.
 *
 * @param maxDatasets maximum number of datasets to return
 * @return the datasets read; fewer than {@code maxDatasets} if the iterator runs out
 */
@Override
public MultiDataSet[] next(int maxDatasets) {
    ArrayList<MultiDataSet> collected = new ArrayList<>(maxDatasets);
    for (int taken = 0; iterator.hasNext() && taken < maxDatasets; taken++) {
        collected.add(iterator.next());
    }
    return collected.toArray(new MultiDataSet[collected.size()]);
}
Example #21
Source File: AsyncMultiDataSetIteratorTest.java From deeplearning4j with Apache License 2.0 | 5 votes |
@Test public void testVariableTimeSeries2() throws Exception { int numBatches = isIntegrationTests() ? 1000 : 100; int batchSize = isIntegrationTests() ? 32 : 8; int timeStepsMin = 10; int timeStepsMax = isIntegrationTests() ? 500 : 100; int valuesPerTimestep = isIntegrationTests() ? 128 : 16; val iterator = new VariableMultiTimeseriesGenerator(1192, numBatches, batchSize, valuesPerTimestep, timeStepsMin, timeStepsMax, 10); for (int e = 0; e < 10; e++) { iterator.reset(); iterator.hasNext(); val amdsi = new AsyncMultiDataSetIterator(iterator, 2, true); int cnt = 0; while (amdsi.hasNext()) { MultiDataSet mds = amdsi.next(); //log.info("Features ptr: {}", AtomicAllocator.getInstance().getPointer(mds.getFeatures()[0].data()).address()); assertEquals("Failed on epoch " + e + "; iteration: " + cnt + ";", (double) cnt, mds.getFeatures()[0].meanNumber().doubleValue(), 1e-10); assertEquals("Failed on epoch " + e + "; iteration: " + cnt + ";", (double) cnt + 0.25, mds.getLabels()[0].meanNumber().doubleValue(), 1e-10); assertEquals("Failed on epoch " + e + "; iteration: " + cnt + ";", (double) cnt + 0.5, mds.getFeaturesMaskArrays()[0].meanNumber().doubleValue(), 1e-10); assertEquals("Failed on epoch " + e + "; iteration: " + cnt + ";", (double) cnt + 0.75, mds.getLabelsMaskArrays()[0].meanNumber().doubleValue(), 1e-10); cnt++; } } }
Example #22
Source File: ParameterServerTrainer.java From deeplearning4j with Apache License 2.0 | 5 votes |
@Override public void feedMultiDataSet(@NonNull MultiDataSet dataSet, long time) { // FIXME: this is wrong, and should be fixed if (getModel() instanceof ComputationGraph) { ComputationGraph computationGraph = (ComputationGraph) getModel(); computationGraph.fit(dataSet); } else { throw new IllegalArgumentException("MultiLayerNetworks can't fit multi datasets"); } log.info("Sending parameters"); //send the updated params parameterServerClient.pushNDArray(getModel().params()); }
Example #23
Source File: MultiNormalizerHybrid.java From nd4j with Apache License 2.0 | 5 votes |
/**
 * Fits input and output statistics using only this {@link MultiDataSet}; the previous
 * {@code inputStats} and {@code outputStats} are overwritten.
 *
 * @param dataSet the data to compute statistics from
 */
@Override
public void fit(@NonNull MultiDataSet dataSet) {
    final Map<Integer, NormalizerStats.Builder> inBuilders = new HashMap<>();
    final Map<Integer, NormalizerStats.Builder> outBuilders = new HashMap<>();
    fitPartial(dataSet, inBuilders, outBuilders);
    inputStats = buildAllStats(inBuilders);
    outputStats = buildAllStats(outBuilders);
}
Example #24
Source File: MultiDataSetIteratorAdapter.java From deeplearning4j with Apache License 2.0 | 5 votes |
/**
 * Fetches {@code i} examples from the wrapped iterator and converts them to a MultiDataSet,
 * applying the pre-processor when one is set.
 */
@Override
public MultiDataSet next(int i) {
    MultiDataSet converted = iter.next(i).toMultiDataSet();
    if (preProcessor == null) {
        return converted;
    }
    preProcessor.preProcess(converted);
    return converted;
}
Example #25
Source File: OpExecOrderListener.java From deeplearning4j with Apache License 2.0 | 5 votes |
/**
 * Appends each op's name to {@code opNamesList} the first time that name is seen; {@code opSet}
 * tracks names already recorded.
 */
@Override
public void opExecution(SameDiff sd, At at, MultiDataSet batch, SameDiffOp op, OpContext opContext, INDArray[] outputs) {
    String name = op.getName();
    if (opSet.contains(name)) {
        return; // already recorded
    }
    opNamesList.add(name);
    opSet.add(name);
}
Example #26
Source File: SparkAMDSI.java From deeplearning4j with Apache License 2.0 | 4 votes |
/**
 * Convenience constructor. It delegates to the four-argument overload with {@code true} as the last
 * argument (presumably {@code useWorkspace} — confirm against that overload).
 *
 * @param iterator  source iterator
 * @param queueSize requested prefetch queue size
 * @param queue     buffer for prefetched datasets
 */
public SparkAMDSI(MultiDataSetIterator iterator,
                  int queueSize,
                  BlockingQueue<MultiDataSet> queue) {
    this(iterator, queueSize, queue, true);
}
Example #27
Source File: ImageMultiPreProcessingScaler.java From deeplearning4j with Apache License 2.0 | 4 votes |
/**
 * Reverts the scaling of {@code toRevert}'s feature arrays in place by delegating to
 * {@code revertFeatures} with the features and their mask arrays.
 *
 * @param toRevert data whose features are reverted
 */
@Override
public void revert(MultiDataSet toRevert) {
    revertFeatures(
            toRevert.getFeatures(),
            toRevert.getFeaturesMaskArrays());
}
Example #28
Source File: TestBertIterator.java From deeplearning4j with Apache License 2.0 | 4 votes |
@Test public void testSentencePairsSingle() throws IOException { boolean prependAppend; int numOfSentences; TestSentenceHelper testHelper = new TestSentenceHelper(); int shortL = testHelper.getShortestL(); int longL = testHelper.getLongestL(); Triple<MultiDataSet, MultiDataSet, MultiDataSet> multiDataSetTriple; MultiDataSet fromPair, leftSide, rightSide; // check for pair max length exactly equal to sum of lengths - pop neither no padding // should be the same as hstack with segment ids 1 for second sentence prependAppend = true; numOfSentences = 1; multiDataSetTriple = generateMultiDataSets(new Triple<>(shortL + longL, shortL, longL), prependAppend, numOfSentences); fromPair = multiDataSetTriple.getFirst(); leftSide = multiDataSetTriple.getSecond(); rightSide = multiDataSetTriple.getThird(); assertEquals(fromPair.getFeatures(0), Nd4j.hstack(leftSide.getFeatures(0), rightSide.getFeatures(0))); rightSide.getFeatures(1).addi(1); //add 1 for right side segment ids assertEquals(fromPair.getFeatures(1), Nd4j.hstack(leftSide.getFeatures(1), rightSide.getFeatures(1))); assertEquals(fromPair.getFeaturesMaskArray(0), Nd4j.hstack(leftSide.getFeaturesMaskArray(0), rightSide.getFeaturesMaskArray(0))); //check for pair max length greater than sum of lengths - pop neither with padding // features should be the same as hstack of shorter and longer padded with prepend/append // segment id should 1 only in the longer for part of the length of the sentence prependAppend = true; numOfSentences = 1; multiDataSetTriple = generateMultiDataSets(new Triple<>(shortL + longL + 5, shortL, longL + 5), prependAppend, numOfSentences); fromPair = multiDataSetTriple.getFirst(); leftSide = multiDataSetTriple.getSecond(); rightSide = multiDataSetTriple.getThird(); assertEquals(fromPair.getFeatures(0), Nd4j.hstack(leftSide.getFeatures(0), rightSide.getFeatures(0))); rightSide.getFeatures(1).get(NDArrayIndex.all(), NDArrayIndex.interval(0, longL + 1)).addi(1); //segmentId stays 0 for the padded part 
assertEquals(fromPair.getFeatures(1), Nd4j.hstack(leftSide.getFeatures(1), rightSide.getFeatures(1))); assertEquals(fromPair.getFeaturesMaskArray(0), Nd4j.hstack(leftSide.getFeaturesMaskArray(0), rightSide.getFeaturesMaskArray(0))); //check for pair max length less than shorter sentence - pop both //should be the same as hstack with segment ids 1 for second sentence if no prepend/append int maxL = 5;//checking odd numOfSentences = 3; prependAppend = false; multiDataSetTriple = generateMultiDataSets(new Triple<>(maxL, maxL / 2, maxL - maxL / 2), prependAppend, numOfSentences); fromPair = multiDataSetTriple.getFirst(); leftSide = multiDataSetTriple.getSecond(); rightSide = multiDataSetTriple.getThird(); assertEquals(fromPair.getFeatures(0), Nd4j.hstack(leftSide.getFeatures(0), rightSide.getFeatures(0))); rightSide.getFeatures(1).addi(1); assertEquals(fromPair.getFeatures(1), Nd4j.hstack(leftSide.getFeatures(1), rightSide.getFeatures(1))); assertEquals(fromPair.getFeaturesMaskArray(0), Nd4j.hstack(leftSide.getFeaturesMaskArray(0), rightSide.getFeaturesMaskArray(0))); }
Example #29
Source File: BaseEvaluationListener.java From deeplearning4j with Apache License 2.0 | 4 votes |
/** * See {@link Listener#activationAvailable(SameDiff, At, MultiDataSet, SameDiffOp, String, INDArray)} */ public void activationAvailableEvaluations(SameDiff sd, At at, MultiDataSet batch, SameDiffOp op, String varName, INDArray activation){ //No op }
Example #30
Source File: BaseListener.java From deeplearning4j with Apache License 2.0 | 4 votes |
@Override public void iterationDone(SameDiff sd, At at, MultiDataSet dataSet, Loss loss) { //No op }