Java Code Examples for org.nd4j.linalg.factory.Nd4j#shuffle()

The following examples show how to use org.nd4j.linalg.factory.Nd4j#shuffle(). You can vote up the examples you like or vote down the ones you don't, and go to the original project or source file by following the links above each example. You may also check out the related API usage in the sidebar.
Example 1
Source File: DataSet.java | From: deeplearning4j | License: Apache License 2.0 | Votes: 5
/**
 * Shuffles the examples of this dataset in place using a seeded random number generator, so the
 * resulting order is reproducible for a given seed.
 *
 * <p>Features, labels and, when present, the features/labels masks are passed to
 * {@code Nd4j.shuffle} together so they receive the same permutation. Example metadata, if any,
 * is reordered with a map rebuilt from a fresh {@code Random} created from the same seed.
 *
 * <p>This will modify the dataset in place!
 *
 * @param seed seed for the random number generator
 */
public void shuffle(long seed) {
    // Nothing to reorder with fewer than two examples.
    if (numExamples() < 2) {
        return;
    }

    // Every array that must move in lock-step with the examples.
    List<INDArray> toShuffle = new ArrayList<>();
    toShuffle.add(getFeatures());
    toShuffle.add(getLabels());
    if (featuresMask != null) {
        toShuffle.add(getFeaturesMaskArray());
    }
    if (labelsMask != null) {
        toShuffle.add(getLabelsMaskArray());
    }

    // TAD dimensions are 1..rank-1 for each array, leaving dimension 0
    // (presumably the example axis) as the one being permuted.
    List<int[]> tadDimensions = new ArrayList<>();
    for (INDArray array : toShuffle) {
        tadDimensions.add(ArrayUtil.range(1, array.rank()));
    }

    Nd4j.shuffle(toShuffle, new Random(seed), tadDimensions);

    // A second Random with the same seed is relied on to reproduce the same order for metadata.
    // As per CpuNDArrayFactory.shuffle(List<INDArray> arrays, Random rnd, List<int[]> dimensions)
    // and libnd4j transforms.h shuffleKernelGeneric.
    if (exampleMetaData != null) {
        int[] permutation = ArrayUtil.buildInterleavedVector(new Random(seed), numExamples());
        ArrayUtil.shuffleWithMap(exampleMetaData, permutation);
    }
}
 
Example 2
Source File: ShufflesTests.java | From: deeplearning4j | License: Apache License 2.0 | Votes: 5
/**
 * Shuffles a rank-3 features array and a rank-3 labels array together with
 * {@code Nd4j.shuffle(list, 1, 2)}. It then checks them with {@code OrderScanner3D.compareSlice}
 * and verifies that both received the same permutation: every labels slice must still carry the
 * tag of the features slice at the same index.
 */
@Test
public void testSymmetricShuffle2() {
    INDArray feat = Nd4j.zeros(10, 10, 20);
    INDArray lab = Nd4j.zeros(10, 10, 3);

    // Tag slice i of both arrays with the value i so the permutation is observable afterwards.
    for (int i = 0; i < 10; i++) {
        feat.slice(i).assign(i);
        lab.slice(i).assign(i);
    }

    // Snapshot the pre-shuffle slice order.
    OrderScanner3D featOrder = new OrderScanner3D(feat);
    OrderScanner3D labOrder = new OrderScanner3D(lab);

    List<INDArray> shuffled = new ArrayList<>();
    shuffled.add(feat);
    shuffled.add(lab);

    Nd4j.shuffle(shuffled, 1, 2);

    assertTrue(featOrder.compareSlice(feat));
    assertTrue(labOrder.compareSlice(lab));

    // Both arrays must have been permuted identically.
    for (int i = 0; i < 10; i++) {
        double expected = feat.slice(i).getDouble(0);
        INDArray labSlice = lab.slice(i);
        for (int j = 0; j < labSlice.length(); j++) {
            assertEquals(expected, labSlice.getDouble(j), 0.001);
        }
    }
}
 
Example 3
Source File: ShufflesTests.java | From: deeplearning4j | License: Apache License 2.0 | Votes: 5
/**
 * Shuffles a 10x10 features matrix and a 10x3 labels matrix together with
 * {@code Nd4j.shuffle(list, 1)}. It checks the features with {@code OrderScanner2D.compareRow}
 * and verifies that every labels row still carries the tag of the features row at the same index,
 * i.e. both matrices received the same permutation.
 */
@Test
public void testSymmetricShuffle1() {
    INDArray features = Nd4j.zeros(10, 10);
    INDArray labels = Nd4j.zeros(10, 3);
    // Tag row x of both matrices with the value x.
    for (int x = 0; x < 10; x++) {
        features.getRow(x).assign(x);
        labels.getRow(x).assign(x);
    }

    OrderScanner2D scanner = new OrderScanner2D(features);

    // Before shuffling, the scanner's map must equal the row tags in their original order.
    assertArrayEquals(new float[] {0f, 1f, 2f, 3f, 4f, 5f, 6f, 7f, 8f, 9f}, scanner.getMap(), 0.01f);

    List<INDArray> list = new ArrayList<>();
    list.add(features);
    list.add(labels);

    Nd4j.shuffle(list, 1);

    assertTrue(scanner.compareRow(features));

    // Labels must follow the same permutation as features.
    for (int x = 0; x < 10; x++) {
        double val = features.getRow(x).getDouble(0);
        INDArray row = labels.getRow(x);

        for (int y = 0; y < row.length(); y++) {
            assertEquals(val, row.getDouble(y), 0.001);
        }
    }
}
 
Example 4
Source File: ShufflesTests.java | From: deeplearning4j | License: Apache License 2.0 | Votes: 5
/**
 * Row shuffle on a non-square (11x10) matrix. Each row is tagged with its index, the matrix is
 * shuffled with {@code Nd4j.shuffle(matrix, 1)}, and {@code OrderScanner2D.compareRow} must accept
 * the result.
 */
@Test
public void testSimpleShuffle3() {
    final int rows = 11;
    INDArray matrix = Nd4j.zeros(rows, 10);
    for (int r = 0; r < rows; r++) {
        matrix.getRow(r).assign(r);
    }

    OrderScanner2D rowOrder = new OrderScanner2D(matrix);
    float[] expectedInitial = {0f, 1f, 2f, 3f, 4f, 5f, 6f, 7f, 8f, 9f, 10f};
    assertArrayEquals(expectedInitial, rowOrder.getMap(), 0.01f);

    Nd4j.shuffle(matrix, 1);
    assertTrue(rowOrder.compareRow(matrix));
}
 
Example 5
Source File: ShufflesTests.java | From: deeplearning4j | License: Apache License 2.0 | Votes: 5
/**
 * Column shuffle on a 10x10 matrix. Column c is filled with the value c, the matrix is shuffled
 * with {@code Nd4j.shuffle(matrix, 0)}, and {@code OrderScanner2D.compareColumn} must accept the
 * result.
 */
@Test
public void testSimpleShuffle2() {
    INDArray matrix = Nd4j.zeros(10, 10);
    for (int c = 0; c < 10; c++) {
        matrix.getColumn(c).assign(c);
    }

    OrderScanner2D columnOrder = new OrderScanner2D(matrix);
    // Expected initial map is all zeros; presumably the scanner samples column 0, which holds
    // 0 in every row — confirm against OrderScanner2D.
    float[] allZeros = new float[10];
    assertArrayEquals(allZeros, columnOrder.getMap(), 0.01f);

    Nd4j.shuffle(matrix, 0);
    assertTrue(columnOrder.compareColumn(matrix));
}
 
Example 6
Source File: DataSet.java | From: nd4j | License: Apache License 2.0 | Votes: 5
/**
 * Shuffles the examples of this dataset in place using a seeded random number generator, so the
 * resulting order is reproducible for a given seed.
 *
 * <p>Features, labels and, when present, the features/labels masks are passed to
 * {@code Nd4j.shuffle} together so they receive the same permutation. Example metadata, if any,
 * is reordered with a map rebuilt from a fresh {@code Random} created from the same seed.
 *
 * <p>This will modify the dataset in place!
 *
 * @param seed seed for the random number generator
 */
public void shuffle(long seed) {
    // Nothing to reorder with fewer than two examples.
    if (numExamples() < 2) {
        return;
    }

    // Every array that must move in lock-step with the examples.
    List<INDArray> toShuffle = new ArrayList<>();
    toShuffle.add(getFeatures());
    toShuffle.add(getLabels());
    if (featuresMask != null) {
        toShuffle.add(getFeaturesMaskArray());
    }
    if (labelsMask != null) {
        toShuffle.add(getLabelsMaskArray());
    }

    // TAD dimensions are 1..rank-1 for each array, leaving dimension 0
    // (presumably the example axis) as the one being permuted.
    List<int[]> tadDimensions = new ArrayList<>();
    for (INDArray array : toShuffle) {
        tadDimensions.add(ArrayUtil.range(1, array.rank()));
    }

    Nd4j.shuffle(toShuffle, new Random(seed), tadDimensions);

    // A second Random with the same seed is relied on to reproduce the same order for metadata.
    // As per CpuNDArrayFactory.shuffle(List<INDArray> arrays, Random rnd, List<int[]> dimensions)
    // and libnd4j transforms.h shuffleKernelGeneric.
    if (exampleMetaData != null) {
        int[] permutation = ArrayUtil.buildInterleavedVector(new Random(seed), numExamples());
        ArrayUtil.shuffleWithMap(exampleMetaData, permutation);
    }
}
 
Example 7
Source File: ShufflesTests.java | From: nd4j | License: Apache License 2.0 | Votes: 4
/**
 * Shuffles a rank-3 features array and a rank-3 labels array together with
 * {@code Nd4j.shuffle(list, 1, 2)}. It checks both with {@code OrderScanner3D.compareSlice} and
 * verifies that every labels slice still carries the tag of the features slice at the same index,
 * i.e. both received the same permutation.
 */
@Test
public void testSymmetricShuffle2() throws Exception {
    INDArray features = Nd4j.zeros(10, 10, 20);
    INDArray labels = Nd4j.zeros(10, 10, 3);

    // Tag slice x of both arrays with the value x so the permutation is observable.
    for (int x = 0; x < 10; x++) {
        features.slice(x).assign(x);
        labels.slice(x).assign(x);
    }

    // Snapshot the pre-shuffle slice order.
    OrderScanner3D scannerFeatures = new OrderScanner3D(features);
    OrderScanner3D scannerLabels = new OrderScanner3D(labels);

    List<INDArray> list = new ArrayList<>();
    list.add(features);
    list.add(labels);

    Nd4j.shuffle(list, 1, 2);

    assertTrue(scannerFeatures.compareSlice(features));
    assertTrue(scannerLabels.compareSlice(labels));

    // Labels must follow the same permutation as features.
    for (int x = 0; x < 10; x++) {
        double val = features.slice(x).getDouble(0);
        INDArray row = labels.slice(x);

        for (int y = 0; y < row.length(); y++) {
            assertEquals(val, row.getDouble(y), 0.001);
        }
    }
}
 
Example 8
Source File: ShufflesTests.java | From: deeplearning4j | License: Apache License 2.0 | Votes: 4
/**
 * Shuffles features, labels and both masks together with a fixed seed (11). Each array uses TAD
 * dimensions 1..rank-1. Afterwards every array is checked with
 * {@code OrderScanner3D.compareSlice}, and every companion array (labels, labels mask, features
 * mask) must carry at index i the same tag as the features slice at index i.
 */
@Test
public void testSymmetricShuffle3() {
    INDArray features = Nd4j.zeros(10, 10, 20);
    INDArray featuresMask = Nd4j.zeros(10, 20);
    INDArray labels = Nd4j.zeros(10, 10, 3);
    INDArray labelsMask = Nd4j.zeros(10, 3);

    // Tag slice i of every array with the value i.
    for (int i = 0; i < 10; i++) {
        features.slice(i).assign(i);
        featuresMask.slice(i).assign(i);
        labels.slice(i).assign(i);
        labelsMask.slice(i).assign(i);
    }

    OrderScanner3D featuresOrder = new OrderScanner3D(features);
    OrderScanner3D labelsOrder = new OrderScanner3D(labels);
    OrderScanner3D featuresMaskOrder = new OrderScanner3D(featuresMask);
    OrderScanner3D labelsMaskOrder = new OrderScanner3D(labelsMask);

    List<INDArray> arrays = new ArrayList<>();
    arrays.add(features);
    arrays.add(labels);
    arrays.add(featuresMask);
    arrays.add(labelsMask);

    // Every dimension except 0 belongs to the TAD, so only dimension 0 is permuted.
    List<int[]> dimensions = new ArrayList<>();
    for (INDArray arr : arrays) {
        dimensions.add(ArrayUtil.range(1, arr.rank()));
    }

    Nd4j.shuffle(arrays, new Random(11), dimensions);

    assertTrue(featuresOrder.compareSlice(features));
    assertTrue(labelsOrder.compareSlice(labels));
    assertTrue(featuresMaskOrder.compareSlice(featuresMask));
    assertTrue(labelsMaskOrder.compareSlice(labelsMask));

    // All companion arrays must follow the same permutation as features.
    INDArray[] companions = {labels, labelsMask, featuresMask};
    for (int i = 0; i < 10; i++) {
        double expected = features.slice(i).getDouble(0);
        for (INDArray companion : companions) {
            INDArray slice = companion.slice(i);
            for (int j = 0; j < slice.length(); j++) {
                assertEquals(expected, slice.getDouble(j), 0.001);
            }
        }
    }
}
 
Example 9
Source File: ShufflesTests.java | From: deeplearning4j | License: Apache License 2.0 | Votes: 4
/**
 * Tags each row of a 10x10 matrix with its index, shuffles it with
 * {@code Nd4j.shuffle(array, 1)}, and checks the result with {@code OrderScanner2D.compareRow}.
 */
@Test
public void testSimpleShuffle1() {
    INDArray array = Nd4j.zeros(10, 10);
    for (int x = 0; x < 10; x++) {
        array.getRow(x).assign(x);
    }

    OrderScanner2D scanner = new OrderScanner2D(array);

    // Before shuffling, the scanner's map must equal the row tags in their original order.
    assertArrayEquals(new float[] {0f, 1f, 2f, 3f, 4f, 5f, 6f, 7f, 8f, 9f}, scanner.getMap(), 0.01f);

    Nd4j.shuffle(array, 1);

    assertTrue(scanner.compareRow(array));
}
 
Example 10
Source File: ShufflesTests.java | From: nd4j | License: Apache License 2.0 | Votes: 4
/**
 * Shuffles features, labels and both masks together with a fixed seed (11). Each array uses TAD
 * dimensions 1..rank-1. Afterwards every array is checked with
 * {@code OrderScanner3D.compareSlice}, and every companion array (labels, labels mask, features
 * mask) must carry at index i the same tag as the features slice at index i.
 */
@Test
public void testSymmetricShuffle3() throws Exception {
    INDArray features = Nd4j.zeros(10, 10, 20);
    INDArray featuresMask = Nd4j.zeros(10, 20);
    INDArray labels = Nd4j.zeros(10, 10, 3);
    INDArray labelsMask = Nd4j.zeros(10, 3);

    // Tag slice i of every array with the value i.
    for (int i = 0; i < 10; i++) {
        features.slice(i).assign(i);
        featuresMask.slice(i).assign(i);
        labels.slice(i).assign(i);
        labelsMask.slice(i).assign(i);
    }

    OrderScanner3D featuresOrder = new OrderScanner3D(features);
    OrderScanner3D labelsOrder = new OrderScanner3D(labels);
    OrderScanner3D featuresMaskOrder = new OrderScanner3D(featuresMask);
    OrderScanner3D labelsMaskOrder = new OrderScanner3D(labelsMask);

    List<INDArray> arrays = new ArrayList<>();
    arrays.add(features);
    arrays.add(labels);
    arrays.add(featuresMask);
    arrays.add(labelsMask);

    // Every dimension except 0 belongs to the TAD, so only dimension 0 is permuted.
    List<int[]> dimensions = new ArrayList<>();
    for (INDArray arr : arrays) {
        dimensions.add(ArrayUtil.range(1, arr.rank()));
    }

    Nd4j.shuffle(arrays, new Random(11), dimensions);

    assertTrue(featuresOrder.compareSlice(features));
    assertTrue(labelsOrder.compareSlice(labels));
    assertTrue(featuresMaskOrder.compareSlice(featuresMask));
    assertTrue(labelsMaskOrder.compareSlice(labelsMask));

    // All companion arrays must follow the same permutation as features.
    INDArray[] companions = {labels, labelsMask, featuresMask};
    for (int i = 0; i < 10; i++) {
        double expected = features.slice(i).getDouble(0);
        for (INDArray companion : companions) {
            INDArray slice = companion.slice(i);
            for (int j = 0; j < slice.length(); j++) {
                assertEquals(expected, slice.getDouble(j), 0.001);
            }
        }
    }
}
 
Example 11
Source File: ShufflesTests.java | From: nd4j | License: Apache License 2.0 | Votes: 4
/**
 * Tags each row of a 10x10 matrix with its index, shuffles it with
 * {@code Nd4j.shuffle(array, 1)}, and checks the result with {@code OrderScanner2D.compareRow}.
 */
@Test
public void testSimpleShuffle1() {
    INDArray array = Nd4j.zeros(10, 10);
    for (int x = 0; x < 10; x++) {
        array.getRow(x).assign(x);
    }

    OrderScanner2D scanner = new OrderScanner2D(array);

    // Before shuffling, the scanner's map must equal the row tags in their original order.
    assertArrayEquals(new float[] {0f, 1f, 2f, 3f, 4f, 5f, 6f, 7f, 8f, 9f}, scanner.getMap(), 0.01f);

    Nd4j.shuffle(array, 1);

    assertTrue(scanner.compareRow(array));
}
 
Example 12
Source File: ShufflesTests.java | From: nd4j | License: Apache License 2.0 | Votes: 4
/**
 * Shuffles a 10x10 features matrix and a 10x3 labels matrix together with
 * {@code Nd4j.shuffle(list, 1)}. It checks the features with {@code OrderScanner2D.compareRow}
 * and verifies that every labels row still carries the tag of the features row at the same index,
 * i.e. both matrices received the same permutation.
 */
@Test
public void testSymmetricShuffle1() {
    INDArray features = Nd4j.zeros(10, 10);
    INDArray labels = Nd4j.zeros(10, 3);
    // Tag row x of both matrices with the value x.
    for (int x = 0; x < 10; x++) {
        features.getRow(x).assign(x);
        labels.getRow(x).assign(x);
    }

    OrderScanner2D scanner = new OrderScanner2D(features);

    // Before shuffling, the scanner's map must equal the row tags in their original order.
    assertArrayEquals(new float[] {0f, 1f, 2f, 3f, 4f, 5f, 6f, 7f, 8f, 9f}, scanner.getMap(), 0.01f);

    List<INDArray> list = new ArrayList<>();
    list.add(features);
    list.add(labels);

    Nd4j.shuffle(list, 1);

    assertTrue(scanner.compareRow(features));

    // Labels must follow the same permutation as features.
    for (int x = 0; x < 10; x++) {
        double val = features.getRow(x).getDouble(0);
        INDArray row = labels.getRow(x);

        for (int y = 0; y < row.length(); y++) {
            assertEquals(val, row.getDouble(y), 0.001);
        }
    }
}
 
Example 13
Source File: ShufflesTests.java | From: nd4j | License: Apache License 2.0 | Votes: 4
/**
 * Row shuffle on a non-square (11x10) matrix. Each row is tagged with its index, the matrix is
 * shuffled with {@code Nd4j.shuffle(array, 1)}, and {@code OrderScanner2D.compareRow} must accept
 * the result.
 */
@Test
public void testSimpleShuffle3() {
    INDArray array = Nd4j.zeros(11, 10);
    for (int x = 0; x < 11; x++) {
        array.getRow(x).assign(x);
    }

    OrderScanner2D scanner = new OrderScanner2D(array);

    // Before shuffling, the scanner's map must equal the row tags in their original order.
    assertArrayEquals(new float[] {0f, 1f, 2f, 3f, 4f, 5f, 6f, 7f, 8f, 9f, 10f}, scanner.getMap(), 0.01f);

    Nd4j.shuffle(array, 1);

    assertTrue(scanner.compareRow(array));
}
 
Example 14
Source File: ShufflesTests.java | From: nd4j | License: Apache License 2.0 | Votes: 4
/**
 * Column shuffle on a 10x10 matrix. Column x is filled with the value x, the matrix is shuffled
 * with {@code Nd4j.shuffle(array, 0)}, and {@code OrderScanner2D.compareColumn} must accept the
 * result.
 */
@Test
public void testSimpleShuffle2() {
    INDArray array = Nd4j.zeros(10, 10);
    for (int x = 0; x < 10; x++) {
        array.getColumn(x).assign(x);
    }

    OrderScanner2D scanner = new OrderScanner2D(array);

    // Expected initial map is all zeros; presumably the scanner samples column 0 — confirm.
    assertArrayEquals(new float[] {0f, 0f, 0f, 0f, 0f, 0f, 0f, 0f, 0f, 0f}, scanner.getMap(), 0.01f);

    Nd4j.shuffle(array, 0);

    assertTrue(scanner.compareColumn(array));
}
 
Example 15
Source File: ShufflesTests.java | From: nd4j | License: Apache License 2.0 | Votes: 4
/**
 * Tags each row of a 10x10 matrix with its index, shuffles it with
 * {@code Nd4j.shuffle(array, 1)}, and checks the result with {@code OrderScanner2D.compareRow}.
 */
@Test
public void testSimpleShuffle1() {
    INDArray array = Nd4j.zeros(10, 10);
    for (int x = 0; x < 10; x++) {
        array.getRow(x).assign(x);
    }

    OrderScanner2D scanner = new OrderScanner2D(array);

    // Before shuffling, the scanner's map must equal the row tags in their original order.
    assertArrayEquals(new float[] {0f, 1f, 2f, 3f, 4f, 5f, 6f, 7f, 8f, 9f}, scanner.getMap(), 0.01f);

    Nd4j.shuffle(array, 1);

    assertTrue(scanner.compareRow(array));
}
 
Example 16
Source File: ShufflesTest.java | From: nd4j | License: Apache License 2.0 | Votes: 4
/**
 * Tags each row of a 10x10 matrix with its index and shuffles it with
 * {@code Nd4j.shuffle(array, 1)}. It then verifies that rows were only reordered, never mixed:
 * every row is still uniform, and the row tags still form a permutation of 0..9.
 *
 * <p>Failures are reported with {@code AssertionError} so that no extra assertion imports are
 * required in this file.
 */
@Test
public void testSimpleShuffle1() {
    INDArray array = Nd4j.zeros(10, 10);
    for (int x = 0; x < 10; x++) {
        array.getRow(x).assign(x);
    }

    Nd4j.shuffle(array, 1);

    boolean[] seen = new boolean[10];
    for (int x = 0; x < 10; x++) {
        INDArray row = array.getRow(x);
        double tag = row.getDouble(0);
        int idx = (int) Math.round(tag);
        // The tag must be an integral value in 0..9 that has not been seen in another row.
        if (idx < 0 || idx >= seen.length || Math.abs(tag - idx) > 0.001 || seen[idx]) {
            throw new AssertionError("Row " + x + " has unexpected or duplicate tag " + tag);
        }
        seen[idx] = true;

        // Every element of the row must still equal the row's tag.
        for (int y = 0; y < row.length(); y++) {
            if (Math.abs(row.getDouble(y) - tag) > 0.001) {
                throw new AssertionError("Row " + x + " is not uniform after shuffle at column " + y);
            }
        }
    }
}
 
Example 17
Source File: ShufflesTests.java | From: nd4j | License: Apache License 2.0 | Votes: 4
/**
 * Shuffles features, labels and both masks together. Each array uses TAD dimensions 1..rank-1.
 * Afterwards every array is checked with {@code OrderScanner3D.compareSlice}, and labels, labels
 * mask and features mask must carry at index x the same tag as the features slice at index x.
 *
 * <p>The generator is seeded (11) so the test is reproducible; an unseeded {@code Random} made
 * the shuffle order, and any failure, differ from run to run.
 */
@Test
public void testSymmetricShuffle3() throws Exception {
    INDArray features = Nd4j.zeros(10, 10, 20);
    INDArray featuresMask = Nd4j.zeros(10, 20);
    INDArray labels = Nd4j.zeros(10, 10, 3);
    INDArray labelsMask = Nd4j.zeros(10, 3);

    // Tag slice x of every array with the value x.
    for (int x = 0; x < 10; x++) {
        features.slice(x).assign(x);
        featuresMask.slice(x).assign(x);
        labels.slice(x).assign(x);
        labelsMask.slice(x).assign(x);
    }

    OrderScanner3D scannerFeatures = new OrderScanner3D(features);
    OrderScanner3D scannerLabels = new OrderScanner3D(labels);
    OrderScanner3D scannerFeaturesMask = new OrderScanner3D(featuresMask);
    OrderScanner3D scannerLabelsMask = new OrderScanner3D(labelsMask);

    List<INDArray> arrays = new ArrayList<>();
    arrays.add(features);
    arrays.add(labels);
    arrays.add(featuresMask);
    arrays.add(labelsMask);

    List<int[]> dimensions = new ArrayList<>();
    dimensions.add(ArrayUtil.range(1, features.rank()));
    dimensions.add(ArrayUtil.range(1, labels.rank()));
    dimensions.add(ArrayUtil.range(1, featuresMask.rank()));
    dimensions.add(ArrayUtil.range(1, labelsMask.rank()));

    Nd4j.shuffle(arrays, new Random(11), dimensions);

    assertTrue(scannerFeatures.compareSlice(features));
    assertTrue(scannerLabels.compareSlice(labels));
    assertTrue(scannerFeaturesMask.compareSlice(featuresMask));
    assertTrue(scannerLabelsMask.compareSlice(labelsMask));

    // All companion arrays must follow the same permutation as features.
    for (int x = 0; x < 10; x++) {
        double val = features.slice(x).getDouble(0);
        INDArray sliceLabels = labels.slice(x);
        INDArray sliceLabelsMask = labelsMask.slice(x);
        INDArray sliceFeaturesMask = featuresMask.slice(x);

        for (int y = 0; y < sliceLabels.length(); y++) {
            assertEquals(val, sliceLabels.getDouble(y), 0.001);
        }

        for (int y = 0; y < sliceLabelsMask.length(); y++) {
            assertEquals(val, sliceLabelsMask.getDouble(y), 0.001);
        }

        for (int y = 0; y < sliceFeaturesMask.length(); y++) {
            assertEquals(val, sliceFeaturesMask.getDouble(y), 0.001);
        }
    }
}
 
Example 18
Source File: ShufflesTests.java | From: nd4j | License: Apache License 2.0 | Votes: 4
/**
 * Shuffles a rank-3 features array and a rank-3 labels array together with
 * {@code Nd4j.shuffle(list, 1, 2)}. It checks both with {@code OrderScanner3D.compareSlice} and
 * verifies that every labels slice still carries the tag of the features slice at the same index,
 * i.e. both received the same permutation.
 */
@Test
public void testSymmetricShuffle2() throws Exception {
    INDArray features = Nd4j.zeros(10, 10, 20);
    INDArray labels = Nd4j.zeros(10, 10, 3);

    // Tag slice x of both arrays with the value x so the permutation is observable.
    for (int x = 0; x < 10; x++) {
        features.slice(x).assign(x);
        labels.slice(x).assign(x);
    }

    // Snapshot the pre-shuffle slice order.
    OrderScanner3D scannerFeatures = new OrderScanner3D(features);
    OrderScanner3D scannerLabels = new OrderScanner3D(labels);

    List<INDArray> list = new ArrayList<>();
    list.add(features);
    list.add(labels);

    Nd4j.shuffle(list, 1, 2);

    assertTrue(scannerFeatures.compareSlice(features));
    assertTrue(scannerLabels.compareSlice(labels));

    // Labels must follow the same permutation as features.
    for (int x = 0; x < 10; x++) {
        double val = features.slice(x).getDouble(0);
        INDArray row = labels.slice(x);

        for (int y = 0; y < row.length(); y++) {
            assertEquals(val, row.getDouble(y), 0.001);
        }
    }
}
 
Example 19
Source File: ShufflesTests.java | From: nd4j | License: Apache License 2.0 | Votes: 4
/**
 * Shuffles a 10x10 features matrix and a 10x3 labels matrix together with
 * {@code Nd4j.shuffle(list, 1)}. It checks the features with {@code OrderScanner2D.compareRow}
 * and verifies that every labels row still carries the tag of the features row at the same index,
 * i.e. both matrices received the same permutation.
 */
@Test
public void testSymmetricShuffle1() {
    INDArray features = Nd4j.zeros(10, 10);
    INDArray labels = Nd4j.zeros(10, 3);
    // Tag row x of both matrices with the value x.
    for (int x = 0; x < 10; x++) {
        features.getRow(x).assign(x);
        labels.getRow(x).assign(x);
    }

    OrderScanner2D scanner = new OrderScanner2D(features);

    // Before shuffling, the scanner's map must equal the row tags in their original order.
    assertArrayEquals(new float[] {0f, 1f, 2f, 3f, 4f, 5f, 6f, 7f, 8f, 9f}, scanner.getMap(), 0.01f);

    List<INDArray> list = new ArrayList<>();
    list.add(features);
    list.add(labels);

    Nd4j.shuffle(list, 1);

    assertTrue(scanner.compareRow(features));

    // Labels must follow the same permutation as features.
    for (int x = 0; x < 10; x++) {
        double val = features.getRow(x).getDouble(0);
        INDArray row = labels.getRow(x);

        for (int y = 0; y < row.length(); y++) {
            assertEquals(val, row.getDouble(y), 0.001);
        }
    }
}
 
Example 20
Source File: ShufflesTests.java | From: nd4j | License: Apache License 2.0 | Votes: 4
/**
 * Column shuffle on a 10x10 matrix. Column x is filled with the value x, the matrix is shuffled
 * with {@code Nd4j.shuffle(array, 0)}, and {@code OrderScanner2D.compareColumn} must accept the
 * result.
 */
@Test
public void testSimpleShuffle2() {
    INDArray array = Nd4j.zeros(10, 10);
    for (int x = 0; x < 10; x++) {
        array.getColumn(x).assign(x);
    }

    OrderScanner2D scanner = new OrderScanner2D(array);

    // Expected initial map is all zeros; presumably the scanner samples column 0 — confirm.
    assertArrayEquals(new float[] {0f, 0f, 0f, 0f, 0f, 0f, 0f, 0f, 0f, 0f}, scanner.getMap(), 0.01f);

    Nd4j.shuffle(array, 0);

    assertTrue(scanner.compareColumn(array));
}