org.nd4j.linalg.dataset.DataSet Java Examples
The following examples show how to use org.nd4j.linalg.dataset.DataSet.
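Before the examples, here is a minimal, self-contained sketch of the core idea they all share: a DataSet pairs a features matrix with a labels matrix. The class name, shapes, and values below are illustrative only, not taken from any project in this list.

import java.util.Arrays;

import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.factory.Nd4j;

public class DataSetBasics {
    public static void main(String[] args) {
        // 4 examples with 3 features each, and one-hot labels over 2 classes
        INDArray features = Nd4j.rand(4, 3);
        INDArray labels = Nd4j.create(new double[][] {{1, 0}, {0, 1}, {1, 0}, {0, 1}});

        // A DataSet is the (features, labels) pair that the examples below iterate over
        DataSet ds = new DataSet(features, labels);

        System.out.println(ds.numExamples());                          // 4
        System.out.println(Arrays.toString(ds.getFeatures().shape())); // [4, 3]
        System.out.println(ds.getLabels().sum(1));                     // each row sums to 1.0
    }
}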
Example #1
Source File: PathSparkDataSetIterator.java From deeplearning4j with Apache License 2.0
@Override
public DataSet next() {
    DataSet ds;
    if (preloadedDataSet != null) {
        ds = preloadedDataSet;
        preloadedDataSet = null;
    } else {
        ds = load(iter.next());
    }

    totalOutcomes = ds.getLabels() == null ? 0 : (int) ds.getLabels().size(1); //May be null for layerwise pretraining
    inputColumns = (int) ds.getFeatures().size(1);
    batch = ds.numExamples();

    if (preprocessor != null)
        preprocessor.preProcess(ds);

    return ds;
}
Example #2
Source File: ScoreUtil.java From deeplearning4j with Apache License 2.0
/**
 * Score based on the loss function
 *
 * @param model the model to score with
 * @param testData the test data to score
 * @param average whether to average the score for the whole batch or not
 * @return the score for the given test set
 */
public static double score(ComputationGraph model, DataSetIterator testData, boolean average) {
    //TODO: do this properly taking into account division by N, L1/L2 etc
    double sumScore = 0.0;
    int totalExamples = 0;
    while (testData.hasNext()) {
        DataSet ds = testData.next();
        int numExamples = ds.numExamples();

        sumScore += numExamples * model.score(ds);
        totalExamples += numExamples;
    }

    if (!average)
        return sumScore;
    return sumScore / totalExamples;
}
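A hedged usage sketch for the helper above; trainedGraph and testIter are hypothetical, pre-built objects, not part of the source file:

// Illustrative call: average per-example score over an entire test iterator
double avgScore = ScoreUtil.score(trainedGraph, testIter, true);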
Example #3
Source File: RandomDataSetIteratorTest.java From deeplearning4j with Apache License 2.0
@Test
public void testDSI() {
    DataSetIterator iter = new RandomDataSetIterator(5, new long[]{3, 4}, new long[]{3, 5},
            RandomDataSetIterator.Values.RANDOM_UNIFORM, RandomDataSetIterator.Values.ONE_HOT);

    int count = 0;
    while (iter.hasNext()) {
        count++;
        DataSet ds = iter.next();
        assertArrayEquals(new long[]{3, 4}, ds.getFeatures().shape());
        assertArrayEquals(new long[]{3, 5}, ds.getLabels().shape());
        assertTrue(ds.getFeatures().minNumber().doubleValue() >= 0.0
                && ds.getFeatures().maxNumber().doubleValue() <= 1.0);
        assertEquals(Nd4j.ones(3), ds.getLabels().sum(1));
    }
    assertEquals(5, count);
}
Example #4
Source File: TransferLearningHelper.java From deeplearning4j with Apache License 2.0
/**
 * During training, frozen vertices/layers can be treated as "featurizing" the input.
 * The forward pass through these frozen layers/vertices can be done in advance and the dataset
 * saved to disk, to iterate quickly on the smaller unfrozen part of the model.
 * Currently does not support datasets with feature masks.
 *
 * @param input dataset to feed into the computation graph with frozen layer vertices
 * @return a dataset with input features that are the outputs of the frozen layer vertices, and the original labels.
 */
public DataSet featurize(DataSet input) {
    if (isGraph) {
        //trying to featurize for a computation graph
        if (origGraph.getNumInputArrays() > 1 || origGraph.getNumOutputArrays() > 1) {
            throw new IllegalArgumentException(
                    "Input or output size to a computation graph is greater than one. Requires use of a MultiDataSet.");
        } else {
            if (input.getFeaturesMaskArray() != null) {
                throw new IllegalArgumentException(
                        "Currently cannot support featurizing datasets with feature masks");
            }
            MultiDataSet inbW = new MultiDataSet(new INDArray[] {input.getFeatures()},
                    new INDArray[] {input.getLabels()}, null, new INDArray[] {input.getLabelsMaskArray()});
            MultiDataSet ret = featurize(inbW);
            return new DataSet(ret.getFeatures()[0], input.getLabels(), ret.getLabelsMaskArrays()[0],
                    input.getLabelsMaskArray());
        }
    } else {
        if (input.getFeaturesMaskArray() != null)
            throw new UnsupportedOperationException("Feature masks not supported with featurizing currently");
        return new DataSet(origMLN.feedForwardToLayer(frozenInputLayer + 1, input.getFeatures(), false)
                .get(frozenInputLayer + 1), input.getLabels(), null, input.getLabelsMaskArray());
    }
}
Example #5
Source File: DrawMnist.java From Canova with Apache License 2.0
public static void drawMnist(DataSet mnist, INDArray reconstruct) throws InterruptedException {
    for (int j = 0; j < mnist.numExamples(); j++) {
        INDArray draw1 = mnist.get(j).getFeatureMatrix().mul(255);
        INDArray reconstructed2 = reconstruct.getRow(j);
        INDArray draw2 = Sampling.binomial(reconstructed2, 1, new MersenneTwister(123)).mul(255);

        DrawReconstruction d = new DrawReconstruction(draw1);
        d.title = "REAL";
        d.draw();
        DrawReconstruction d2 = new DrawReconstruction(draw2, 1000, 1000);
        d2.title = "TEST";
        d2.draw();
        Thread.sleep(1000);
        d.frame.dispose();
        d2.frame.dispose();
    }
}
Example #6
Source File: CnnTextEmbeddingInstanceIteratorTest.java From wekaDeeplearning4j with GNU General Public License v3.0
/**
 * Test getDataSetIterator
 */
@Test
public void testGetIteratorNominalClass() throws Exception {
    final Instances data = DatasetLoader.loadReutersMinimal();
    final int batchSize = 1;
    final DataSetIterator it = this.cteii.getDataSetIterator(data, SEED, batchSize);

    Set<Integer> labels = new HashSet<>();
    for (Instance inst : data) {
        int label = Integer.parseInt(inst.stringValue(data.classIndex()));
        final DataSet next = Utils.getNext(it);
        int itLabel = next.getLabels().argMax().getInt(0);
        Assert.assertEquals(label, itLabel);
        labels.add(label);
    }
    final Set<Integer> collect =
            it.getLabels().stream().map(s -> Double.valueOf(s).intValue()).collect(Collectors.toSet());
    Assert.assertEquals(2, labels.size());
    Assert.assertTrue(labels.containsAll(collect));
    Assert.assertTrue(collect.containsAll(labels));
}
Example #7
Source File: DataSetExportFunction.java From deeplearning4j with Apache License 2.0
@Override
public void call(Iterator<DataSet> iter) throws Exception {
    String jvmuid = UIDProvider.getJVMUID();
    uid = Thread.currentThread().getId() + jvmuid.substring(0, Math.min(8, jvmuid.length()));

    Configuration c = conf == null ? DefaultHadoopConfig.get() : conf.getValue().getConfiguration();

    while (iter.hasNext()) {
        DataSet next = iter.next();
        String filename = "dataset_" + uid + "_" + (outputCount++) + ".bin";

        String path = outputDir.getPath();
        URI uri = new URI(path + (path.endsWith("/") || path.endsWith("\\") ? "" : "/") + filename);
        FileSystem file = FileSystem.get(uri, c);
        try (FSDataOutputStream out = file.create(new Path(uri))) {
            next.save(out);
        }
    }
}
Example #8
Source File: ModelSerializerTest.java From deeplearning4j with Apache License 2.0
@Test
public void testJavaSerde_1() throws Exception {
    int nIn = 5;
    int nOut = 6;

    ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder().seed(12345).l1(0.01)
            .graphBuilder()
            .addInputs("in")
            .layer("0", new OutputLayer.Builder().nIn(nIn).nOut(nOut).build(), "in")
            .setOutputs("0")
            .validateOutputLayerConfig(false)
            .build();

    ComputationGraph net = new ComputationGraph(conf);
    net.init();

    DataSet dataSet = trivialDataSet();
    NormalizerStandardize norm = new NormalizerStandardize();
    norm.fit(dataSet);

    val b = SerializationUtils.serialize(net);

    ComputationGraph restored = SerializationUtils.deserialize(b);

    assertEquals(net, restored);
}
Example #9
Source File: MultiDataSetWrapperIterator.java From deeplearning4j with Apache License 2.0
@Override
public DataSet next() {
    MultiDataSet mds = iterator.next();
    if (mds.getFeatures().length > 1 || mds.getLabels().length > 1)
        throw new UnsupportedOperationException(
                "This iterator is only able to convert MultiDataSets with a single input and a single output");

    INDArray features = mds.getFeatures()[0];
    INDArray labels = mds.getLabels() != null ? mds.getLabels()[0] : features;
    INDArray fMask = mds.getFeaturesMaskArrays() != null ? mds.getFeaturesMaskArrays()[0] : null;
    INDArray lMask = mds.getLabelsMaskArrays() != null ? mds.getLabelsMaskArrays()[0] : null;

    DataSet ds = new DataSet(features, labels, fMask, lMask);

    if (preProcessor != null)
        preProcessor.preProcess(ds);

    return ds;
}
Example #10
Source File: ImageInstanceIteratorTest.java From wekaDeeplearning4j with GNU General Public License v3.0
/**
 * Test getDataSetIterator
 */
@Test
public void testGetIterator() throws Exception {
    final Instances metaData = DatasetLoader.loadMiniMnistMeta();
    this.idi.setImagesLocation(new File("datasets/nominal/mnist-minimal"));
    final int batchSize = 1;
    final DataSetIterator it = this.idi.getDataSetIterator(metaData, SEED, batchSize);

    Set<Integer> labels = new HashSet<>();
    for (Instance inst : metaData) {
        int label = Integer.parseInt(inst.stringValue(1));
        final DataSet next = Utils.getNext(it);
        int itLabel = next.getLabels().argMax().getInt(0);
        Assert.assertEquals(label, itLabel);
        labels.add(label);
    }
    final List<Integer> collect =
            it.getLabels().stream().map(Integer::valueOf).collect(Collectors.toList());
    Assert.assertEquals(10, labels.size());
    Assert.assertTrue(labels.containsAll(collect));
    Assert.assertTrue(collect.containsAll(labels));
}
Example #11
Source File: RecordReaderDataSetiteratorTest.java From deeplearning4j with Apache License 2.0
@Test
@Ignore
public void specialRRTest4() throws Exception {
    RecordReader rr = new SpecialImageRecordReader(25000, 10, 3, 224, 224);
    RecordReaderDataSetIterator rrdsi = new RecordReaderDataSetIterator(rr, 128);

    int cnt = 0;
    int examples = 0;
    while (rrdsi.hasNext()) {
        DataSet ds = rrdsi.next();
        assertEquals(128, ds.numExamples());
        for (int i = 0; i < ds.numExamples(); i++) {
            INDArray example = ds.getFeatures().tensorAlongDimension(i, 1, 2, 3).dup();
            // assertEquals("Failed on DataSet [" + cnt + "], example [" + i + "]", (double) examples, example.meanNumber().doubleValue(), 0.01);
            // assertEquals("Failed on DataSet [" + cnt + "], example [" + i + "]", (double) examples, ds.getLabels().getRow(i).meanNumber().doubleValue(), 0.01);
            examples++;
        }
        cnt++;
    }
}
Example #12
Source File: MultiLayerTest.java From deeplearning4j with Apache License 2.0
@Test
public void testOutput() throws Exception {
    Nd4j.getRandom().setSeed(12345);
    MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
            .weightInit(WeightInit.XAVIER).seed(12345L).list()
            .layer(0, new DenseLayer.Builder().nIn(784).nOut(50).activation(Activation.RELU).build())
            .layer(1, new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT)
                    .activation(Activation.SOFTMAX).nIn(50).nOut(10).build())
            .setInputType(InputType.convolutional(28, 28, 1)).build();

    MultiLayerNetwork net = new MultiLayerNetwork(conf);
    net.init();

    DataSetIterator fullData = new MnistDataSetIterator(1, 2);
    net.fit(fullData);

    fullData.reset();
    DataSet expectedSet = fullData.next(2);
    INDArray expectedOut = net.output(expectedSet.getFeatures(), false);

    fullData.reset();
    INDArray actualOut = net.output(fullData);

    assertEquals(expectedOut, actualOut);
}
Example #13
Source File: LFWLoader.java From Canova with Apache License 2.0
public DataSet convertListPairs(List<DataSet> images) {
    INDArray inputs = Nd4j.create(images.size(), numPixelColumns);
    INDArray outputs = Nd4j.create(images.size(), numNames);

    for (int i = 0; i < images.size(); i++) {
        inputs.putRow(i, images.get(i).getFeatureMatrix());
        outputs.putRow(i, images.get(i).getLabels());
    }
    return new DataSet(inputs, outputs);
}
Example #14
Source File: DL4JSequenceRecommender.java From inception with Apache License 2.0
private MultiLayerNetwork train(List<Sample> aTrainingData, Object2IntMap<String> aTagset)
        throws IOException {
    // Configure the neural network
    MultiLayerNetwork model = createConfiguredNetwork(traits, wordVectors.dimensions());

    final int limit = traits.getTrainingSetSizeLimit();
    final int batchSize = traits.getBatchSize();

    // First vectorizing all sentences and then passing them to the model would consume
    // huge amounts of memory. Thus, every sentence is vectorized and then immediately
    // passed on to the model.
    nextEpoch: for (int epoch = 0; epoch < traits.getnEpochs(); epoch++) {
        int sentNum = 0;
        Iterator<Sample> sampleIterator = aTrainingData.iterator();
        while (sampleIterator.hasNext()) {
            List<DataSet> batch = new ArrayList<>();
            while (sampleIterator.hasNext() && batch.size() < batchSize && sentNum < limit) {
                Sample sample = sampleIterator.next();
                DataSet trainingData = vectorize(asList(sample), aTagset, true);
                batch.add(trainingData);
                sentNum++;
            }

            model.fit(new ListDataSetIterator<DataSet>(batch, batch.size()));
            log.trace("Epoch {}: processed {} of {} sentences", epoch, sentNum, aTrainingData.size());

            if (sentNum >= limit) {
                continue nextEpoch;
            }
        }
    }

    return model;
}
Example #15
Source File: CompositeDataSetPreProcessorTest.java From deeplearning4j with Apache License 2.0
@Test
public void when_dataSetIsEmpty_expect_emptyDataSet() {
    // Assemble
    CompositeDataSetPreProcessor sut = new CompositeDataSetPreProcessor();
    DataSet ds = new DataSet(null, null);

    // Act
    sut.preProcess(ds);

    // Assert
    assertTrue(ds.isEmpty());
}
Example #16
Source File: SparkADSI.java From deeplearning4j with Apache License 2.0
public SparkADSI(DataSetIterator iterator, int queueSize, BlockingQueue<DataSet> queue,
                 boolean useWorkspace, DataSetCallback callback, Integer deviceId) {
    this();

    if (queueSize < 2)
        queueSize = 2;

    this.deviceId = deviceId;
    this.callback = callback;
    this.useWorkspace = useWorkspace;
    this.buffer = queue;
    this.prefetchSize = queueSize;
    this.backedIterator = iterator;
    this.workspaceId = "SADSI_ITER-" + java.util.UUID.randomUUID().toString();

    if (iterator.resetSupported())
        this.backedIterator.reset();

    context = TaskContext.get();

    this.thread = new SparkPrefetchThread(buffer, iterator, terminator, null,
            Nd4j.getAffinityManager().getDeviceForCurrentThread());

    /*
     * We want to ensure that the background thread has the same thread->device affinity
     * as the master thread.
     */
    thread.setDaemon(true);
    thread.start();
}
Example #17
Source File: ConvolutionInstancesIteratorTest.java From wekaDeeplearning4j with GNU General Public License v3.0
/**
 * Test getDataSetIterator
 */
@Test
public void testGetIterator() throws Exception {
    final int batchSize = 1;
    final DataSetIterator it = this.cii.getDataSetIterator(mnistMiniArff, SEED, batchSize);

    Set<Integer> labels = new HashSet<>();
    for (int i = 0; i < mnistMiniArff.size(); i++) {
        Instance inst = mnistMiniArff.get(i);
        int instLabel = Integer.parseInt(inst.stringValue(inst.numAttributes() - 1));
        final DataSet next = Utils.getNext(it);
        int dsLabel = next.getLabels().argMax().getInt(0);
        Assert.assertEquals(instLabel, dsLabel);
        labels.add(instLabel);

        INDArray reshaped = next.getFeatures().reshape(1, inst.numAttributes() - 1);
        // Compare each attribute value
        for (int j = 0; j < inst.numAttributes() - 1; j++) {
            double instVal = inst.value(j);
            double dsVal = reshaped.getDouble(j);
            Assert.assertEquals(instVal, dsVal, 10e-8);
        }
    }
    final List<Integer> collect =
            it.getLabels().stream().map(Integer::valueOf).collect(Collectors.toList());
    Assert.assertEquals(10, labels.size());
    Assert.assertTrue(labels.containsAll(collect));
    Assert.assertTrue(collect.containsAll(labels));
}
Example #18
Source File: FileSplitDataSetIterator.java From deeplearning4j with Apache License 2.0
@Override
public DataSet next() {
    // long time1 = System.nanoTime();
    DataSet ds = callback.call(files.get(counter.getAndIncrement()));

    if (preProcessor != null && ds != null)
        preProcessor.preProcess(ds);

    // long time2 = System.nanoTime();
    // if (counter.get() % 5 == 0)
    //     log.info("Device: [{}]; Time: [{}] ns;", Nd4j.getAffinityManager().getDeviceForCurrentThread(), time2 - time1);

    return ds;
}
Example #19
Source File: BenchmarkDataSetIterator.java From deeplearning4j with Apache License 2.0
/**
 * @param example         DataSet to return on each call of next()
 * @param totalIterations Total number of iterations
 */
public BenchmarkDataSetIterator(DataSet example, int totalIterations) {
    this.baseFeatures = example.getFeatures().dup();
    this.baseLabels = example.getLabels().dup();

    Nd4j.getExecutioner().commit();
    this.limit = totalIterations;
}
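A usage sketch, under the assumption (suggested by the class name and constructor, but not shown in the excerpt) that BenchmarkDataSetIterator implements DataSetIterator and replays the duplicated arrays on every call; shapes and the iteration count are illustrative:

DataSet example = new DataSet(Nd4j.rand(32, 10), Nd4j.rand(32, 2));
DataSetIterator benchIter = new BenchmarkDataSetIterator(example, 100);
while (benchIter.hasNext()) {
    DataSet ds = benchIter.next(); // same data every call; useful for timing fit()/output()
}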
Example #20
Source File: BagOfWordsVectorizer.java From deeplearning4j with Apache License 2.0
/**
 * Text coming from an input stream is treated as one document.
 *
 * @param is the input stream to read from
 * @param label the label to assign
 * @return a dataset with a set of weights (relative to the implementation; could be word counts or tf-idf scores)
 */
@Override
public DataSet vectorize(InputStream is, String label) {
    try {
        BufferedReader reader = new BufferedReader(new InputStreamReader(is, "UTF-8"));
        String line = "";
        StringBuilder builder = new StringBuilder();
        while ((line = reader.readLine()) != null) {
            builder.append(line);
        }
        return vectorize(builder.toString(), label);
    } catch (Exception e) {
        throw new RuntimeException(e);
    }
}
Example #21
Source File: MLLibUtil.java From deeplearning4j with Apache License 2.0
/**
 * Converts a continuous JavaRDD of LabeledPoint to a JavaRDD of DataSet.
 *
 * @param data JavaRDD of LabeledPoint
 * @param preCache whether to pre-cache the rdd before the operation
 * @return the converted JavaRDD of DataSet
 */
public static JavaRDD<DataSet> fromContinuousLabeledPoint(JavaRDD<LabeledPoint> data, boolean preCache) {
    if (preCache && !data.getStorageLevel().useMemory()) {
        data.cache();
    }
    return data.map(new Function<LabeledPoint, DataSet>() {
        @Override
        public DataSet call(LabeledPoint lp) {
            return convertToDataset(lp);
        }
    });
}
Example #22
Source File: BaseDataFetcher.java From deeplearning4j with Apache License 2.0
/**
 * Initializes this data fetcher from the passed-in datasets
 *
 * @param examples the examples to use
 */
protected void initializeCurrFromList(List<DataSet> examples) {
    if (examples.isEmpty())
        log.warn("Warning: empty dataset from the fetcher");

    INDArray inputs = createInputMatrix(examples.size());
    INDArray labels = createOutputMatrix(examples.size());
    for (int i = 0; i < examples.size(); i++) {
        inputs.putRow(i, examples.get(i).getFeatures());
        labels.putRow(i, examples.get(i).getLabels());
    }
    curr = new DataSet(inputs, labels);
}
Example #23
Source File: SparkUtils.java From deeplearning4j with Apache License 2.0
/**
 * Randomly shuffle the examples in each DataSet object, and recombine them into new DataSet objects
 * with the specified batch size
 *
 * @param rdd DataSets to shuffle/recombine
 * @param newBatchSize New batch size for the DataSet objects, after shuffling/recombining
 * @param numPartitions Number of partitions to use when splitting/recombining
 * @return A new {@link JavaRDD<DataSet>}, with the examples shuffled/combined in each
 */
public static JavaRDD<DataSet> shuffleExamples(JavaRDD<DataSet> rdd, int newBatchSize, int numPartitions) {
    //Step 1: split into individual examples, mapping to a pair RDD (random key in range 0 to numPartitions)
    JavaPairRDD<Integer, DataSet> singleExampleDataSets =
            rdd.flatMapToPair(new SplitDataSetExamplesPairFlatMapFunction(numPartitions));

    //Step 2: repartition according to the random keys
    singleExampleDataSets = singleExampleDataSets.partitionBy(new HashPartitioner(numPartitions));

    //Step 3: Recombine
    return singleExampleDataSets.values().mapPartitions(new BatchDataSetsFunction(newBatchSize));
}
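A one-line usage sketch, assuming an existing JavaRDD<DataSet> named rdd (a hypothetical variable, not from the source above):

// Illustrative: rebatch single-example DataSets into shuffled minibatches of 32
JavaRDD<DataSet> shuffled = SparkUtils.shuffleExamples(rdd, 32, rdd.getNumPartitions());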
Example #24
Source File: SingletonDataSetIterator.java From deeplearning4j with Apache License 2.0
@Override
public DataSet next() {
    if (!hasNext) {
        throw new NoSuchElementException("No elements remaining");
    }
    hasNext = false;
    if (preProcessor != null && !preprocessed) {
        preProcessor.preProcess(dataSet);
        preprocessed = true;
    }
    return dataSet;
}
Example #25
Source File: ScrollableDataSetIterator.java From deeplearning4j with Apache License 2.0
public ScrollableDataSetIterator(int num, DataSetIterator backedIterator, AtomicLong counter,
                                 AtomicBoolean resetPending, DataSet firstTrain, int[] itemsPerPart) {
    this.thisPart = num;
    this.bottom = itemsPerPart[0];
    this.top = bottom + itemsPerPart[1];
    this.itemsPerPart = top;

    this.backedIterator = backedIterator;
    this.counter = counter;
    //this.resetPending = resetPending;
    this.firstTrain = firstTrain;
    //this.totalExamples = totalExamples;
    this.current = 0;
}
Example #26
Source File: MLLibUtil.java From deeplearning4j with Apache License 2.0
/**
 * Convert an RDD of DataSet to an RDD of LabeledPoint.
 *
 * @param data the dataset to convert
 * @param preCache whether to pre-cache the rdd before the operation
 * @return an rdd of labeled point
 */
public static JavaRDD<LabeledPoint> fromDataSet(JavaRDD<DataSet> data, boolean preCache) {
    if (preCache && !data.getStorageLevel().useMemory()) {
        data.cache();
    }
    return data.map(new Function<DataSet, LabeledPoint>() {
        @Override
        public LabeledPoint call(DataSet dataSet) {
            return toLabeledPoint(dataSet);
        }
    });
}
Example #27
Source File: EvaluativeListener.java From deeplearning4j with Apache License 2.0
public EvaluativeListener(@NonNull DataSet dataSet, int frequency, @NonNull InvocationType type,
                          IEvaluation... evaluations) {
    this.ds = dataSet;
    this.frequency = frequency;
    this.evaluations = evaluations;
    this.invocationType = type;
}
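A usage sketch for the constructor above; model, testData, and the choice of InvocationType.EPOCH_END are illustrative assumptions, not taken from the source file:

// Illustrative: evaluate on a fixed DataSet at the end of every epoch
model.setListeners(new EvaluativeListener(testData, 1, InvocationType.EPOCH_END, new Evaluation()));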
Example #28
Source File: TestMasking.java From deeplearning4j with Apache License 2.0
@Test
public void testCompGraphEvalWithMask() {
    int minibatch = 3;
    int layerSize = 6;
    int nIn = 5;
    int nOut = 4;

    ComputationGraphConfiguration conf2 = new NeuralNetConfiguration.Builder().updater(new NoOp())
            .dist(new NormalDistribution(0, 1)).seed(12345)
            .graphBuilder().addInputs("in")
            .addLayer("0", new DenseLayer.Builder().nIn(nIn).nOut(layerSize).activation(Activation.TANH)
                    .build(), "in")
            .addLayer("1", new OutputLayer.Builder().nIn(layerSize).nOut(nOut)
                    .lossFunction(LossFunctions.LossFunction.XENT).activation(Activation.SIGMOID)
                    .build(), "0")
            .setOutputs("1").build();

    ComputationGraph graph = new ComputationGraph(conf2);
    graph.init();

    INDArray f = Nd4j.create(minibatch, nIn);
    INDArray l = Nd4j.create(minibatch, nOut);
    INDArray lMask = Nd4j.ones(minibatch, nOut);

    DataSet ds = new DataSet(f, l, null, lMask);
    DataSetIterator iter = new ExistingDataSetIterator(Collections.singletonList(ds).iterator());

    EvaluationBinary eb = new EvaluationBinary();
    graph.doEvaluation(iter, eb);
}
Example #29
Source File: ScoreFlatMapFunctionCGDataSet.java From deeplearning4j with Apache License 2.0
@Override
public Iterator<Tuple2<Long, Double>> call(Iterator<DataSet> dataSetIterator) throws Exception {
    if (!dataSetIterator.hasNext()) {
        return Collections.singletonList(new Tuple2<>(0L, 0.0)).iterator();
    }

    DataSetIterator iter = new IteratorDataSetIterator(dataSetIterator, minibatchSize); //Does batching where appropriate

    ComputationGraph network = new ComputationGraph(ComputationGraphConfiguration.fromJson(json));
    network.init();
    INDArray val = params.value().unsafeDuplication(); //.value() is shared by all executors on a single machine -> OK, as params are not changed in the score function
    if (val.length() != network.numParams(false))
        throw new IllegalStateException(
                "Network did not have same number of parameters as the broadcast set parameters");
    network.setParams(val);

    List<Tuple2<Long, Double>> out = new ArrayList<>();
    while (iter.hasNext()) {
        DataSet ds = iter.next();
        double score = network.score(ds, false);

        long numExamples = ds.getFeatures().size(0);
        out.add(new Tuple2<>(numExamples, score * numExamples));
    }

    Nd4j.getExecutioner().commit();

    return out.iterator();
}
Example #30
Source File: RecordReaderDataSetiteratorTest.java From deeplearning4j with Apache License 2.0
@Test
public void testSeqRRDSIArrayWritableOneReaderRegression() {
    //Regression, where the output is an array writable
    List<List<Writable>> sequence1 = new ArrayList<>();
    sequence1.add(Arrays.asList((Writable) new NDArrayWritable(Nd4j.create(new double[] {1, 2, 3}, new long[]{1, 3})),
            new NDArrayWritable(Nd4j.create(new double[] {100, 200, 300}, new long[]{1, 3}))));
    sequence1.add(Arrays.asList((Writable) new NDArrayWritable(Nd4j.create(new double[] {4, 5, 6}, new long[]{1, 3})),
            new NDArrayWritable(Nd4j.create(new double[] {400, 500, 600}, new long[]{1, 3}))));
    List<List<Writable>> sequence2 = new ArrayList<>();
    sequence2.add(Arrays.asList((Writable) new NDArrayWritable(Nd4j.create(new double[] {7, 8, 9}, new long[]{1, 3})),
            new NDArrayWritable(Nd4j.create(new double[] {700, 800, 900}, new long[]{1, 3}))));
    sequence2.add(Arrays.asList((Writable) new NDArrayWritable(Nd4j.create(new double[] {10, 11, 12}, new long[]{1, 3})),
            new NDArrayWritable(Nd4j.create(new double[] {1000, 1100, 1200}, new long[]{1, 3}))));

    SequenceRecordReader rr = new CollectionSequenceRecordReader(Arrays.asList(sequence1, sequence2));
    SequenceRecordReaderDataSetIterator iter = new SequenceRecordReaderDataSetIterator(rr, 2, -1, 1, true);

    DataSet ds = iter.next();

    //2 examples, 3 values per time step, 2 time steps
    INDArray expFeatures = Nd4j.create(2, 3, 2);
    expFeatures.tensorAlongDimension(0, 1, 2).assign(Nd4j.create(new double[][] {{1, 4}, {2, 5}, {3, 6}}));
    expFeatures.tensorAlongDimension(1, 1, 2).assign(Nd4j.create(new double[][] {{7, 10}, {8, 11}, {9, 12}}));

    INDArray expLabels = Nd4j.create(2, 3, 2);
    expLabels.tensorAlongDimension(0, 1, 2)
            .assign(Nd4j.create(new double[][] {{100, 400}, {200, 500}, {300, 600}}));
    expLabels.tensorAlongDimension(1, 1, 2)
            .assign(Nd4j.create(new double[][] {{700, 1000}, {800, 1100}, {900, 1200}}));

    assertEquals(expFeatures, ds.getFeatures());
    assertEquals(expLabels, ds.getLabels());
}