Java Code Examples for org.nd4j.linalg.api.rng.Random#setSeed()

The following examples show how to use `org.nd4j.linalg.api.rng.Random#setSeed()`. You can vote up the examples you find useful or vote down the ones you don't, and you can follow the link above each example to its original project or source file. Related API usage is listed in the sidebar.
Example 1
Source file: RandomTests.java | Project: nd4j (Apache License 2.0) | Votes: 5
/**
 * Two generators created with the same seed (119) must produce matching uniform samples
 * (bounds 1.0 and 2.0), and must keep matching after both are re-seeded with 1999.
 */
@Test
public void testSetSeed1() throws Exception {
    Random first = Nd4j.getRandomFactory().getNewRandomInstance(119);
    Random second = Nd4j.getRandomFactory().getNewRandomInstance(119);

    // Phase 1: sample using the construction-time seed.
    INDArray firstInitial = Nd4j.create(1000);
    INDArray secondInitial = Nd4j.create(1000);
    Nd4j.getExecutioner().exec(new UniformDistribution(firstInitial, 1.0, 2.0), first);
    Nd4j.getExecutioner().exec(new UniformDistribution(secondInitial, 1.0, 2.0), second);

    // Phase 2: re-seed both generators identically and sample again.
    first.setSeed(1999);
    second.setSeed(1999);

    INDArray firstReseeded = Nd4j.create(100);
    Nd4j.getExecutioner().exec(new UniformDistribution(firstReseeded, 1.0, 2.0), first);
    INDArray secondReseeded = Nd4j.create(100);
    Nd4j.getExecutioner().exec(new UniformDistribution(secondReseeded, 1.0, 2.0), second);

    // Element-wise comparison first, so a failure reports the offending index.
    int idx = 0;
    while (idx < firstInitial.length()) {
        assertEquals("Failed on element: [" + idx + "]", firstInitial.getFloat(idx), secondInitial.getFloat(idx), 0.01f);
        idx++;
    }
    assertEquals(firstInitial, secondInitial);

    idx = 0;
    while (idx < firstReseeded.length()) {
        assertEquals("Failed on element: [" + idx + "]", firstReseeded.getFloat(idx), secondReseeded.getFloat(idx), 0.01f);
        idx++;
    }
    assertEquals(firstReseeded, secondReseeded);
}
 
Example 2
Source file: Nd4jTest.java | Project: nd4j (Apache License 2.0) | Votes: 5
/**
 * Two calls to {@code Nd4j.getRandom()} must return equal generators, and re-seeding
 * one of them must not break that equality.
 */
@Test
public void testGetRandomSetSeed() {
    Random first = Nd4j.getRandom();
    Random second = Nd4j.getRandom();
    assertEquals(first, second);

    first.setSeed(123);
    assertEquals(first, second);
}
 
Example 3
Source file: RandomFactory.java | Project: nd4j (Apache License 2.0) | Votes: 5
/**
 * Creates a new object implementing the {@link Random} interface and seeds it.
 *
 * <p>The concrete implementation is instantiated reflectively from {@code randomClass}
 * through its no-argument constructor.
 *
 * @param seed seed for the new RNG object
 * @return a freshly created object implementing {@link Random}, seeded with {@code seed}
 * @throws RuntimeException if the implementation class cannot be instantiated
 */
public Random getNewRandomInstance(long seed) {
    Random t;
    try {
        // Class#newInstance() is deprecated: it can rethrow checked constructor exceptions
        // undeclared. Going through the Constructor wraps them in InvocationTargetException.
        t = (Random) randomClass.getDeclaredConstructor().newInstance();
    } catch (Exception e) {
        throw new RuntimeException("Failed to instantiate Random implementation: " + randomClass, e);
    }
    // TODO: when t.getStatePointer() != null, attach the native state to a deallocator.
    // Stateless implementations need no cleanup.
    t.setSeed(seed);
    return t;
}
 
Example 4
Source file: A3CDiscrete.java | Project: deeplearning4j (Apache License 2.0) | Votes: 5
/**
 * Builds an A3C learner over a discrete action space.
 *
 * @param mdp          the environment to learn from
 * @param iActorCritic the actor-critic network, shared with the async global state
 * @param conf         learning configuration; its optional seed, if present, is applied to
 *                     the RNG returned by {@code Nd4j.getRandom()}
 */
public A3CDiscrete(MDP<OBSERVATION, Integer, DiscreteSpace> mdp, IActorCritic iActorCritic, A3CLearningConfiguration conf) {
    this.iActorCritic = iActorCritic;
    this.mdp = mdp;
    this.configuration = conf;
    asyncGlobal = new AsyncGlobal<>(iActorCritic, conf);

    Long configuredSeed = conf.getSeed();
    Random policyRandom = Nd4j.getRandom();
    // Only re-seed when the configuration explicitly provides a seed.
    if (configuredSeed != null) {
        policyRandom.setSeed(configuredSeed);
    }

    policy = new ACPolicy<>(iActorCritic, policyRandom);
}
 
Example 5
Source file: RandomTests.java | Project: deeplearning4j (Apache License 2.0) | Votes: 5
/**
 * Two generators created with the same seed (119) must produce matching uniform samples
 * (bounds 1.0 and 2.0), and must keep matching after both are re-seeded with 1999.
 */
@Test
public void testSetSeed1() {
    Random first = Nd4j.getRandomFactory().getNewRandomInstance(119);
    Random second = Nd4j.getRandomFactory().getNewRandomInstance(119);

    // Phase 1: sample using the construction-time seed.
    INDArray firstInitial = Nd4j.create(1000);
    INDArray secondInitial = Nd4j.create(1000);
    Nd4j.getExecutioner().exec(new UniformDistribution(firstInitial, 1.0, 2.0), first);
    Nd4j.getExecutioner().exec(new UniformDistribution(secondInitial, 1.0, 2.0), second);

    // Phase 2: re-seed both generators identically and sample again.
    first.setSeed(1999);
    second.setSeed(1999);

    INDArray firstReseeded = Nd4j.create(100);
    Nd4j.getExecutioner().exec(new UniformDistribution(firstReseeded, 1.0, 2.0), first);
    INDArray secondReseeded = Nd4j.create(100);
    Nd4j.getExecutioner().exec(new UniformDistribution(secondReseeded, 1.0, 2.0), second);

    // Element-wise comparison first, so a failure reports the offending index.
    int idx = 0;
    while (idx < firstInitial.length()) {
        assertEquals("Failed on element: [" + idx + "]", firstInitial.getFloat(idx), secondInitial.getFloat(idx), 0.01f);
        idx++;
    }
    assertEquals(firstInitial, secondInitial);

    idx = 0;
    while (idx < firstReseeded.length()) {
        assertEquals("Failed on element: [" + idx + "]", firstReseeded.getFloat(idx), secondReseeded.getFloat(idx), 0.01f);
        idx++;
    }
    assertEquals(firstReseeded, secondReseeded);
}
 
Example 6
Source file: Nd4jTest.java | Project: deeplearning4j (Apache License 2.0) | Votes: 5
/**
 * Two calls to {@code Nd4j.getRandom()} must return equal generators, and re-seeding
 * one of them must not break that equality.
 */
@Test
public void testGetRandomSetSeed() {
    Random first = Nd4j.getRandom();
    Random second = Nd4j.getRandom();
    assertEquals(first, second);

    first.setSeed(123);
    assertEquals(first, second);
}
 
Example 7
Source file: SortCooTests.java | Project: deeplearning4j (Apache License 2.0) | Votes: 5
/**
 * Sorts randomly generated 3-D COO indices with the native {@code sortCooIndices}. Then checks
 * that consecutive index rows are in non-decreasing lexicographic order.
 * The test is skipped on CUDA executioners.
 */
@Test
public void sortSparseCooIndicesSort3() {
    // FIXME: we don't want this test running on cuda for now
    if (Nd4j.getExecutioner().getClass().getCanonicalName().toLowerCase().contains("cuda"))
        return;

    Random rng = Nd4j.getRandom();
    rng.setSeed(12040483421383L);
    long[] shape = {50, 50, 50};
    int nnz = 100;
    // nnz index rows of shape.length coordinates each, scaled by 50 (the extent of every
    // dimension in shape) and converted to longs.
    val indices = Nd4j.rand(new int[]{nnz, shape.length}, rng).muli(50).ravel().toLongVector();
    val values = Nd4j.rand(new long[]{nnz}).ravel().toDoubleVector();

    DataBuffer indiceBuffer = Nd4j.getDataBufferFactory().createLong(indices);
    DataBuffer valueBuffer = Nd4j.createBuffer(values);

    // Last argument is the index rank (previously the literal 3); derive it from shape.
    NativeOpsHolder.getInstance().getDeviceNativeOps().sortCooIndices(null, (LongPointer) indiceBuffer.addressPointer(),
            valueBuffer.addressPointer(), nnz, shape.length);

    // Lexicographic check: the first differing coordinate decides the order of two rows.
    for (long i = 1; i < nnz; ++i) {
        for (long j = 0; j < shape.length; ++j) {
            long prev = indiceBuffer.getLong((i - 1) * shape.length + j);
            long current = indiceBuffer.getLong(i * shape.length + j);
            if (prev < current) {
                break;
            } else if (prev > current) {
                long[] prevRow = getLongsAt(indiceBuffer, (i - 1) * shape.length, shape.length);
                long[] currentRow = getLongsAt(indiceBuffer, i * shape.length, shape.length);
                throw new AssertionError(String.format("indices are not correctly sorted between element %d and %d. %s > %s",
                        i - 1, i, Arrays.toString(prevRow), Arrays.toString(currentRow)));
            }
        }
    }
}
 
Example 8
Source file: RandomFactory.java | Project: deeplearning4j (Apache License 2.0) | Votes: 5
/**
 * Creates a new object implementing the {@link Random} interface and seeds it.
 *
 * <p>The concrete implementation is instantiated reflectively from {@code randomClass}
 * through its no-argument constructor.
 *
 * @param seed seed for the new RNG object
 * @return a freshly created object implementing {@link Random}, seeded with {@code seed}
 * @throws RuntimeException if the implementation class cannot be instantiated
 */
public Random getNewRandomInstance(long seed) {
    Random t;
    try {
        // Class#newInstance() is deprecated: it can rethrow checked constructor exceptions
        // undeclared. Going through the Constructor wraps them in InvocationTargetException.
        t = (Random) randomClass.getDeclaredConstructor().newInstance();
    } catch (Exception e) {
        throw new RuntimeException("Failed to instantiate Random implementation: " + randomClass, e);
    }
    // TODO: when t.getStatePointer() != null, attach the native state to a deallocator.
    // Stateless implementations need no cleanup.
    t.setSeed(seed);
    return t;
}
 
Example 9
Source file: RandomTests.java | Project: deeplearning4j (Apache License 2.0) | Votes: 4
/**
 * Re-seeding a NativeRandom must restart its sequence.
 * After 10,000 extra draws, setSeed(119) must reproduce the first int drawn with seed 119.
 * setSeed(120) must yield a different first int.
 * Non-native RNG implementations skip the test with a warning.
 */
@Test
public void testStepOver3() {
    Random random = Nd4j.getRandomFactory().getNewRandomInstance(119);
    if (!(random instanceof NativeRandom)) {
        log.warn("Not a NativeRandom object received, skipping test");
        return;
    }
    NativeRandom nativeRng = (NativeRandom) random;

    int firstDraw = nativeRng.nextInt();
    // Advance the stream well past the first value before re-seeding.
    for (int step = 0; step < 10000; step++) {
        nativeRng.nextInt();
    }

    random.setSeed(119);
    int drawAfterSameSeed = nativeRng.nextInt();
    assertEquals(firstDraw, drawAfterSameSeed);

    random.setSeed(120);
    int drawAfterOtherSeed = nativeRng.nextInt();
    assertNotEquals(firstDraw, drawAfterOtherSeed);
}
 
Example 10
Source file: RandomTests.java | Project: nd4j (Apache License 2.0) | Votes: 3
/**
 * Re-seeding a NativeRandom must restart its sequence.
 * After 10,000 extra draws, setSeed(119) must reproduce the first int drawn with seed 119.
 * setSeed(120) must yield a different first int.
 * Also requires the generator's buffer size to exceed 1,000,000.
 * Non-native RNG implementations skip the test with a warning.
 */
@Test
public void testStepOver3() throws Exception {
    Random random = Nd4j.getRandomFactory().getNewRandomInstance(119);
    if (!(random instanceof NativeRandom)) {
        log.warn("Not a NativeRandom object received, skipping test");
        return;
    }
    NativeRandom nativeRng = (NativeRandom) random;
    assertTrue(nativeRng.getBufferSize() > 1000000L);

    int firstDraw = nativeRng.nextInt();
    // Advance the stream well past the first value before re-seeding.
    for (int step = 0; step < 10000; step++) {
        nativeRng.nextInt();
    }

    random.setSeed(119);
    int drawAfterSameSeed = nativeRng.nextInt();
    assertEquals(firstDraw, drawAfterSameSeed);

    random.setSeed(120);
    int drawAfterOtherSeed = nativeRng.nextInt();
    assertNotEquals(firstDraw, drawAfterOtherSeed);
}
 
Example 11
Source file: RandomTests.java | Project: nd4j (Apache License 2.0) | Votes: 2
/**
 * Sanity-checks Gaussian sampling with mean 0.0 and std 1.0 at two sizes.
 * A 1,000,000-element sample drawn with the executioner's default RNG must have
 * mean ~0 and std ~1 (tolerance 0.01).
 * A 55,000,000-element sample drawn with a re-seeded (119) RNG must contain no NaNs
 * and satisfy the same statistics.
 */
@Test
public void testStepOver1() throws Exception {
    Random random1 = Nd4j.getRandomFactory().getNewRandomInstance(119);

    log.info("1: ----------------");

    INDArray z0 = Nd4j.getExecutioner().exec(new GaussianDistribution(Nd4j.createUninitialized(1000000), 0.0, 1.0));

    assertEquals(0.0, z0.meanNumber().doubleValue(), 0.01);
    assertEquals(1.0, z0.stdNumber().doubleValue(), 0.01);

    random1.setSeed(119);

    log.info("2: ----------------");

    // Presumably sized to force the RNG past its internal buffer ("step over"); TODO confirm.
    // The former z2 array is gone: it was another 55M-element allocation that was never filled.
    INDArray z1 = Nd4j.zeros(55000000);
    GaussianDistribution op1 = new GaussianDistribution(z1, 0.0, 1.0);
    Nd4j.getExecutioner().exec(op1, random1);

    log.info("3: ----------------");
    log.info("Sum: {}", z1.sumNumber().doubleValue());

    INDArray match = Nd4j.getExecutioner().exec(new MatchCondition(z1, Conditions.isNan()), Integer.MAX_VALUE);
    log.info("NaNs: {}", match);
    assertEquals(0.0f, match.getFloat(0), 0.01f);

    assertEquals(1.0, z1.stdNumber().doubleValue(), 0.01);
    assertEquals(0.0, z1.meanNumber().doubleValue(), 0.01);
}
 
Example 12
Source file: RandomTests.java | Project: deeplearning4j (Apache License 2.0) | Votes: 2
@Test
public void testStepOver1() {
    // Sanity-checks Gaussian sampling with mean 0.0 and std 1.0 in DOUBLE precision at two
    // sizes: 1,000,000 elements with the executioner's default RNG, then 55,000,000 elements
    // with an RNG re-seeded to 119. Both must have mean ~0 and std ~1 (tolerance 0.01),
    // and the large sample must contain no NaNs.
    Random random1 = Nd4j.getRandomFactory().getNewRandomInstance(119);

    INDArray z0 = Nd4j.getExecutioner().exec(new GaussianDistribution(Nd4j.createUninitialized(DataType.DOUBLE, 1000000), 0.0, 1.0));

    assertEquals(0.0, z0.meanNumber().doubleValue(), 0.01);
    assertEquals(1.0, z0.stdNumber().doubleValue(), 0.01);

    random1.setSeed(119);

    // Presumably sized to force the RNG past its internal buffer ("step over"); TODO confirm.
    // The former z2 array is gone: its only uses were commented out, so it just wasted a
    // second 55M-element DOUBLE allocation.
    INDArray z1 = Nd4j.zeros(DataType.DOUBLE, 55000000);

    GaussianDistribution op1 = new GaussianDistribution(z1, 0.0, 1.0);
    Nd4j.getExecutioner().exec(op1, random1);

    INDArray match = Nd4j.getExecutioner().exec(new MatchCondition(z1, Conditions.isNan()));
    assertEquals(0.0f, match.getFloat(0), 0.01f);

    assertEquals(0.0, z1.meanNumber().doubleValue(), 0.01);
    assertEquals(1.0, z1.stdNumber().doubleValue(), 0.01);
}