Java Code Examples for org.nd4j.linalg.dataset.api.MultiDataSet#getLabels()

The following examples show how to use org.nd4j.linalg.dataset.api.MultiDataSet#getLabels(). You can vote up the ones you like or vote down the ones you don't like, and go to the original project or source file by following the links above each example. You may check out the related API usage on the sidebar.
Example 1
Source File: DefaultCallback.java (from deeplearning4j, Apache License 2.0; 6 votes)
/**
 * Makes sure every array carried by the given {@link MultiDataSet} lives on the device.
 * Features, labels, feature masks and label masks are each handed to the global affinity
 * manager; a {@code null} data set, or a {@code null} group of arrays, is simply skipped.
 *
 * @param multiDataSet the data set whose arrays should be relocated; may be {@code null}
 */
@Override
public void call(MultiDataSet multiDataSet) {
    if (multiDataSet == null) {
        return;
    }

    AffinityManager affinity = Nd4j.getAffinityManager();
    AffinityManager.Location target = AffinityManager.Location.DEVICE;

    if (multiDataSet.getFeatures() != null) {
        for (int idx = 0; idx < multiDataSet.getFeatures().length; idx++) {
            affinity.ensureLocation(multiDataSet.getFeatures()[idx], target);
        }
    }

    if (multiDataSet.getLabels() != null) {
        for (int idx = 0; idx < multiDataSet.getLabels().length; idx++) {
            affinity.ensureLocation(multiDataSet.getLabels()[idx], target);
        }
    }

    if (multiDataSet.getFeaturesMaskArrays() != null) {
        for (int idx = 0; idx < multiDataSet.getFeaturesMaskArrays().length; idx++) {
            affinity.ensureLocation(multiDataSet.getFeaturesMaskArrays()[idx], target);
        }
    }

    if (multiDataSet.getLabelsMaskArrays() != null) {
        for (int idx = 0; idx < multiDataSet.getLabelsMaskArrays().length; idx++) {
            affinity.ensureLocation(multiDataSet.getLabelsMaskArrays()[idx], target);
        }
    }
}
 
Example 2
Source File: MultiLayerNetwork.java (from deeplearning4j, Apache License 2.0; 6 votes)
/**
 * Fits this network on a {@link MultiDataSet} by converting it into a single-input,
 * single-output {@link DataSet} and delegating to {@code fit(DataSet)}.
 *
 * @param dataSet data holding exactly one features array and exactly one labels array; the
 *                feature and label mask arrays are optional and are checked independently
 * @throws DL4JInvalidInputException if the data set does not hold exactly one features array
 *                                   and one labels array
 */
@Override
public void fit(MultiDataSet dataSet) {
    if (dataSet.getFeatures().length != 1 || dataSet.getLabels().length != 1) {
        throw new DL4JInvalidInputException(
                "MultiLayerNetwork can't handle MultiDataSet with more than 1 features or labels array. " +
                        "Please consider use of ComputationGraph");
    }

    INDArray features = dataSet.getFeatures(0);
    INDArray labels = dataSet.getLabels(0);

    // Each mask group must be null-checked on its own: a data set may carry feature masks
    // without label masks (or vice versa). Previously the label mask was read under the
    // feature-mask null check, which threw an NPE in that case.
    INDArray fMask = dataSet.getFeaturesMaskArrays() != null ? dataSet.getFeaturesMaskArrays()[0] : null;
    INDArray lMask = dataSet.getLabelsMaskArrays() != null ? dataSet.getLabelsMaskArrays()[0] : null;

    fit(new DataSet(features, labels, fMask, lMask));
}
 
Example 3
Source File: MultiDataSetWrapperIterator.java (from deeplearning4j, Apache License 2.0; 6 votes)
/**
 * Pulls the next {@link MultiDataSet} from the wrapped iterator and converts it into a
 * {@link DataSet}, applying {@code preProcessor} when one is set.
 * <p>
 * If the MultiDataSet carries no labels ({@code getLabels() == null}), its features array is
 * used as the labels. Missing mask arrays become {@code null} masks.
 *
 * @return the converted (and possibly pre-processed) DataSet
 * @throws UnsupportedOperationException if the MultiDataSet has more than one features array
 *                                       or more than one labels array
 */
@Override
public DataSet next() {
    MultiDataSet mds = iterator.next();
    INDArray[] featureArrays = mds.getFeatures();
    INDArray[] labelArrays = mds.getLabels();

    // Labels may legitimately be absent (see the fallback below), so the size check must be
    // null-safe; previously it dereferenced getLabels() unconditionally and threw an NPE.
    if (featureArrays.length > 1 || (labelArrays != null && labelArrays.length > 1)) {
        throw new UnsupportedOperationException(
                        "This iterator is able to convert MultiDataSet with number of inputs/outputs of 1");
    }

    INDArray features = featureArrays[0];
    INDArray labels = labelArrays != null ? labelArrays[0] : features;
    INDArray fMask = mds.getFeaturesMaskArrays() != null ? mds.getFeaturesMaskArrays()[0] : null;
    INDArray lMask = mds.getLabelsMaskArrays() != null ? mds.getLabelsMaskArrays()[0] : null;

    DataSet ds = new DataSet(features, labels, fMask, lMask);

    if (preProcessor != null) {
        preProcessor.preProcess(ds);
    }

    return ds;
}
 
Example 4
Source File: Main.java (from twse-captcha-solver-dl4j, MIT License; 5 votes)
/**
 * Evaluates a multi-output captcha model over every minibatch of {@code iterator}.
 * <p>
 * Each model output, and the labels array at the same index, corresponds to one character
 * position of the captcha. A character is decoded as the argmax of the row, mapped through the
 * alphabet 0-9 then A-Z. Every decoded (real, predicted) pair is logged. When the iterator is
 * exhausted, it is reset and the total and exactly-matching counts are printed.
 *
 * @param model    the network to evaluate; one output head per captcha character
 * @param iterator source of features and labels; it is reset after evaluation
 */
public static void modelPredict(ComputationGraph model, MultiDataSetIterator iterator) {
  // Class index -> character: indices 0-9 are the digits, 10-35 are the letters A-Z.
  final String alphabet = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ";

  int sumCount = 0;
  int correctCount = 0;

  while (iterator.hasNext()) {
    MultiDataSet mds = iterator.next();
    INDArray[] output = model.output(mds.getFeatures());
    INDArray[] labels = mds.getLabels();
    // One output head per character; this used to be hard-coded to 5.
    int digits = output.length;
    // Never index past the rows actually present in this minibatch.
    int dataNum = Math.min(batchSize, output[0].rows());
    for (int dataIndex = 0; dataIndex < dataNum; dataIndex++) {
      StringBuilder predicted = new StringBuilder(digits);
      StringBuilder real = new StringBuilder(digits);
      for (int digit = 0; digit < digits; digit++) {
        INDArray preOutput = output[digit].getRow(dataIndex);
        predicted.append(alphabet.charAt(Nd4j.argMax(preOutput, 1).getInt(0)));

        INDArray realLabel = labels[digit].getRow(dataIndex);
        real.append(alphabet.charAt(Nd4j.argMax(realLabel, 1).getInt(0)));
      }
      String peLabel = predicted.toString();
      String reLabel = real.toString();
      boolean correct = peLabel.equals(reLabel);
      if (correct) {
        correctCount++;
      }
      sumCount++;
      logger.info("real image {}  prediction {} status {}", reLabel, peLabel, correct);
    }
  }
  iterator.reset();
  System.out.println(
      "validate result : sum count =" + sumCount + " correct count=" + correctCount);
}
 
Example 5
Source File: UnderSamplingByMaskingMultiDataSetPreProcessor.java (from nd4j, Apache License 2.0; 5 votes)
/**
 * Replaces the labels mask of every labels index registered in {@code targetMinorityDistMap}.
 * For each such index, the new mask is whatever {@code adjustMasks} returns for that index's
 * labels, its current mask, its minority label (from {@code minorityLabelMap}) and its target
 * minority distribution.
 *
 * @param multiDataSet the data set whose labels masks are updated in place
 */
@Override
public void preProcess(MultiDataSet multiDataSet) {
    for (Integer labelIdx : targetMinorityDistMap.keySet()) {
        INDArray labels = multiDataSet.getLabels(labelIdx);
        INDArray currentMask = multiDataSet.getLabelsMaskArray(labelIdx);
        double minorityTarget = targetMinorityDistMap.get(labelIdx);
        int minorityClass = minorityLabelMap.get(labelIdx);
        multiDataSet.setLabelsMaskArray(labelIdx,
                        adjustMasks(labels, currentMask, minorityClass, minorityTarget));
    }
}
 
Example 6
Source File: UnderSamplingByMaskingMultiDataSetPreProcessor.java (from deeplearning4j, Apache License 2.0; 5 votes)
/**
 * Under-samples by masking. For every labels index that has a configured target minority
 * distribution, the labels mask at that index is replaced by the result of {@code adjustMasks}.
 * That call receives the index's labels, its existing mask, the minority label configured in
 * {@code minorityLabelMap} and the target distribution.
 *
 * @param multiDataSet the data set to modify in place
 */
@Override
public void preProcess(MultiDataSet multiDataSet) {
    for (Integer outputIdx : targetMinorityDistMap.keySet()) {
        final INDArray outputLabels = multiDataSet.getLabels(outputIdx);
        final INDArray existingMask = multiDataSet.getLabelsMaskArray(outputIdx);
        final double targetDist = targetMinorityDistMap.get(outputIdx);
        final int minorityClass = minorityLabelMap.get(outputIdx);

        multiDataSet.setLabelsMaskArray(outputIdx,
                        adjustMasks(outputLabels, existingMask, minorityClass, targetDist));
    }
}
 
Example 7
Source File: RecordReaderMultiDataSetIteratorTest.java (from deeplearning4j, Apache License 2.0; 4 votes)
@Test
public void testSplittingCSV() throws Exception {
    // Iris is read twice. A plain RecordReaderDataSetIterator gives the reference DataSet
    // (features = columns 0-3, labels = column 4 one-hot over 3 classes). A
    // RecordReaderMultiDataSetIterator splits each row into 2 inputs and 2 outputs:
    //   input 0 = column 0, input 1 = columns 1-2,
    //   output 0 = column 3, output 1 = column 4 one-hot.
    // The expected split arrays are sliced manually out of the reference DataSet.
    RecordReader referenceReader = new CSVRecordReader(0, ',');
    referenceReader.initialize(new FileSplit(Resources.asFile("iris.txt")));
    RecordReaderDataSetIterator referenceIter = new RecordReaderDataSetIterator(referenceReader, 10, 4, 3);

    RecordReader splitReader = new CSVRecordReader(0, ',');
    splitReader.initialize(new FileSplit(Resources.asFile("iris.txt")));

    MultiDataSetIterator splitIter = new RecordReaderMultiDataSetIterator.Builder(10)
                    .addReader("reader", splitReader)
                    .addInput("reader", 0, 0)
                    .addInput("reader", 1, 2)
                    .addOutput("reader", 3, 3)
                    .addOutputOneHot("reader", 4, 3)
                    .build();

    while (referenceIter.hasNext()) {
        DataSet reference = referenceIter.next();
        INDArray refFeatures = reference.getFeatures();
        INDArray refLabels = reference.getLabels();

        MultiDataSet split = splitIter.next();
        assertEquals(2, split.getFeatures().length);
        assertEquals(2, split.getLabels().length);
        assertNull(split.getFeaturesMaskArrays());
        assertNull(split.getLabelsMaskArrays());

        INDArray[] inputs = split.getFeatures();
        INDArray[] outputs = split.getLabels();
        assertNotNull(inputs);
        assertNotNull(outputs);
        for (INDArray arr : inputs) {
            assertNotNull(arr);
        }
        for (INDArray arr : outputs) {
            assertNotNull(arr);
        }

        // Column 3 is still part of the reference features (label index is 4), so the
        // expected output 0 is sliced from the features rather than the labels.
        INDArray expectedIn0 = refFeatures.get(all(), interval(0, 0, true));
        INDArray expectedIn1 = refFeatures.get(all(), interval(1, 2, true));
        INDArray expectedOut0 = refFeatures.get(all(), interval(3, 3, true));
        INDArray expectedOut1 = refLabels;

        assertEquals(expectedIn0, inputs[0]);
        assertEquals(expectedIn1, inputs[1]);
        assertEquals(expectedOut0, outputs[0]);
        assertEquals(expectedOut1, outputs[1]);
    }
    assertFalse(splitIter.hasNext());
}
 
Example 8
Source File: RecordReaderMultiDataSetIteratorTest.java (from deeplearning4j, Apache License 2.0; 4 votes)
@Test
public void testSplittingCSVSequence() throws Exception {
    // The feature sequences "csvsequence_i.txt" are split into two inputs: columns 0-1 and
    // column 2. The labels in "csvsequencelabels_i.txt" stay a single one-hot output over 4
    // classes. A plain SequenceRecordReaderDataSetIterator reading the same files provides the
    // reference, and the expected inputs are sliced manually out of it.
    // ("csvsequencelabelsShort_i.txt" is extracted as well but not read by this test.)
    File workDir = temporaryFolder.newFolder();
    String[] resourcePatterns = {"csvsequence_%d.txt", "csvsequencelabels_%d.txt",
                    "csvsequencelabelsShort_%d.txt"};
    for (int fileIdx = 0; fileIdx < 3; fileIdx++) {
        for (String pattern : resourcePatterns) {
            new ClassPathResource(String.format(pattern, fileIdx)).getTempFileFromArchive(workDir);
        }
    }

    String basePath = workDir.getAbsolutePath();
    String featuresPath = FilenameUtils.concat(basePath, "csvsequence_%d.txt");
    String labelsPath = FilenameUtils.concat(basePath, "csvsequencelabels_%d.txt");

    SequenceRecordReader refFeatureReader = new CSVSequenceRecordReader(1, ",");
    SequenceRecordReader refLabelReader = new CSVSequenceRecordReader(1, ",");
    refFeatureReader.initialize(new NumberedFileInputSplit(featuresPath, 0, 2));
    refLabelReader.initialize(new NumberedFileInputSplit(labelsPath, 0, 2));

    SequenceRecordReaderDataSetIterator referenceIter =
                    new SequenceRecordReaderDataSetIterator(refFeatureReader, refLabelReader, 1, 4, false);

    SequenceRecordReader splitFeatureReader = new CSVSequenceRecordReader(1, ",");
    SequenceRecordReader splitLabelReader = new CSVSequenceRecordReader(1, ",");
    splitFeatureReader.initialize(new NumberedFileInputSplit(featuresPath, 0, 2));
    splitLabelReader.initialize(new NumberedFileInputSplit(labelsPath, 0, 2));

    MultiDataSetIterator splitIter = new RecordReaderMultiDataSetIterator.Builder(1)
                    .addSequenceReader("seq1", splitFeatureReader)
                    .addSequenceReader("seq2", splitLabelReader)
                    .addInput("seq1", 0, 1)
                    .addInput("seq1", 2, 2)
                    .addOutputOneHot("seq2", 0, 4)
                    .build();

    while (referenceIter.hasNext()) {
        DataSet reference = referenceIter.next();
        INDArray refFeatures = reference.getFeatures();
        INDArray refLabels = reference.getLabels();

        MultiDataSet split = splitIter.next();
        assertEquals(2, split.getFeatures().length);
        assertEquals(1, split.getLabels().length);
        assertNull(split.getFeaturesMaskArrays());
        assertNull(split.getLabelsMaskArrays());

        INDArray[] inputs = split.getFeatures();
        INDArray[] outputs = split.getLabels();
        assertNotNull(inputs);
        assertNotNull(outputs);
        for (INDArray arr : inputs) {
            assertNotNull(arr);
        }
        for (INDArray arr : outputs) {
            assertNotNull(arr);
        }

        // The middle index selects columns, mirroring the addInput column ranges above.
        INDArray expectedIn0 = refFeatures.get(all(), NDArrayIndex.interval(0, 1, true), all());
        INDArray expectedIn1 = refFeatures.get(all(), NDArrayIndex.interval(2, 2, true), all());

        assertEquals(expectedIn0, inputs[0]);
        assertEquals(expectedIn1, inputs[1]);
        assertEquals(refLabels, outputs[0]);
    }
    assertFalse(splitIter.hasNext());
}
 
Example 9
Source File: RecordReaderMultiDataSetIteratorTest.java (from deeplearning4j, Apache License 2.0; 4 votes)
/**
 * Checks that with timeSeriesRandomOffset(true, 1234L) shorter sequences in a minibatch are
 * placed at the expected random offsets, and that the feature/label masks match.
 */
@Test
public void testTimeSeriesRandomOffset() {
    // 3 sequences of lengths 1, 3 and 5. Each time step has 2 columns: column 0 is the single
    // input (addInput("rr", 0, 0)) and column 1 is the single output (addOutput("rr", 1, 1)).

    List<List<Writable>> seq1 =
                    Arrays.asList(Arrays.<Writable>asList(new DoubleWritable(1.0), new DoubleWritable(2.0)));
    List<List<Writable>> seq2 =
                    Arrays.asList(Arrays.<Writable>asList(new DoubleWritable(10.0), new DoubleWritable(11.0)),
                                    Arrays.<Writable>asList(new DoubleWritable(20.0), new DoubleWritable(21.0)),
                                    Arrays.<Writable>asList(new DoubleWritable(30.0), new DoubleWritable(31.0)));
    List<List<Writable>> seq3 =
                    Arrays.asList(Arrays.<Writable>asList(new DoubleWritable(100.0), new DoubleWritable(101.0)),
                                    Arrays.<Writable>asList(new DoubleWritable(200.0), new DoubleWritable(201.0)),
                                    Arrays.<Writable>asList(new DoubleWritable(300.0), new DoubleWritable(301.0)),
                                    Arrays.<Writable>asList(new DoubleWritable(400.0), new DoubleWritable(401.0)),
                                    Arrays.<Writable>asList(new DoubleWritable(500.0), new DoubleWritable(501.0)));

    Collection<List<List<Writable>>> seqs = Arrays.asList(seq1, seq2, seq3);

    SequenceRecordReader rr = new CollectionSequenceRecordReader(seqs);

    // Minibatch of 3, so all sequences land in a single minibatch padded to length 5.
    RecordReaderMultiDataSetIterator rrmdsi =
                    new RecordReaderMultiDataSetIterator.Builder(3).addSequenceReader("rr", rr).addInput("rr", 0, 0)
                                    .addOutput("rr", 1, 1).timeSeriesRandomOffset(true, 1234L).build();


    // Recompute the offsets the iterator is expected to choose. An RNG seeded with the builder's
    // seed (1234) yields one seed per minibatch, and that seed drives the per-sequence offsets.
    // NOTE(review): this mirrors the iterator's internal RNG usage; confirm if that changes.
    Random r = new Random(1234); //Provides seed for each minibatch
    long seed = r.nextLong();
    Random r2 = new Random(seed); //Use same RNG seed in new RNG for each minibatch
    int expOffsetSeq1 = r2.nextInt(5 - 1 + 1); //0 to 4 inclusive (max length 5 - length 1, + 1)
    int expOffsetSeq2 = r2.nextInt(5 - 3 + 1); //0 to 2 inclusive (max length 5 - length 3, + 1)
    int expOffsetSeq3 = 0; //Longest TS, always 0
    //With current seed: 3, 1, 0
    //        System.out.println(expOffsetSeq1 + "\t" + expOffsetSeq2 + "\t" + expOffsetSeq3);

    MultiDataSet mds = rrmdsi.next();

    // Rows are sequences, columns are time steps; 1 marks a step holding real data.
    // Hard-coded for offsets 3, 1, 0 (see above), so it must be updated if the seed changes.
    INDArray expMask = Nd4j.create(new double[][] {{0, 0, 0, 1, 0}, {0, 1, 1, 1, 0}, {1, 1, 1, 1, 1}});

    assertEquals(expMask, mds.getFeaturesMaskArray(0));
    assertEquals(expMask, mds.getLabelsMaskArray(0));

    INDArray f = mds.getFeatures(0);
    INDArray l = mds.getLabels(0);

    // Expected values of each sequence, shaped [1, length] to match the slices asserted below.
    INDArray expF1 = Nd4j.create(new double[] {1.0}, new int[]{1,1});
    INDArray expL1 = Nd4j.create(new double[] {2.0}, new int[]{1,1});

    INDArray expF2 = Nd4j.create(new double[] {10, 20, 30}, new int[]{1,3});
    INDArray expL2 = Nd4j.create(new double[] {11, 21, 31}, new int[]{1,3});

    INDArray expF3 = Nd4j.create(new double[] {100, 200, 300, 400, 500}, new int[]{1,5});
    INDArray expL3 = Nd4j.create(new double[] {101, 201, 301, 401, 501}, new int[]{1,5});

    // Each sequence's values must appear starting at its expected offset along the last index.
    assertEquals(expF1, f.get(point(0), all(),
                    NDArrayIndex.interval(expOffsetSeq1, expOffsetSeq1 + 1)));
    assertEquals(expL1, l.get(point(0), all(),
                    NDArrayIndex.interval(expOffsetSeq1, expOffsetSeq1 + 1)));

    assertEquals(expF2, f.get(point(1), all(),
                    NDArrayIndex.interval(expOffsetSeq2, expOffsetSeq2 + 3)));
    assertEquals(expL2, l.get(point(1), all(),
                    NDArrayIndex.interval(expOffsetSeq2, expOffsetSeq2 + 3)));

    assertEquals(expF3, f.get(point(2), all(),
                    NDArrayIndex.interval(expOffsetSeq3, expOffsetSeq3 + 5)));
    assertEquals(expL3, l.get(point(2), all(),
                    NDArrayIndex.interval(expOffsetSeq3, expOffsetSeq3 + 5)));
}