org.nd4j.linalg.api.rng.distribution.Distribution Java Examples

The following examples show how to use org.nd4j.linalg.api.rng.distribution.Distribution. You can vote up the examples you like or vote down the ones you don't like, and follow the links above each example to visit the original project or source file. You can also check out the related API usage on the sidebar.
Example #1
Source File: UpdaterTest.java    From deeplearning4j with Apache License 2.0 6 votes vote down vote up
@Test
public void testNesterovs() {
    int rows = 10;
    int cols = 2;

    // Nesterovs updater (lr = 0.5, momentum = 0.9) backed by an external state view.
    NesterovsUpdater grad = new NesterovsUpdater(new Nesterovs(0.5, 0.9));
    grad.setStateViewArray(Nd4j.zeros(1, rows * cols), new long[] {rows, cols}, 'c', true);

    // Initialize the weight matrix row by row from a N(1, 1) distribution.
    INDArray W = Nd4j.zeros(rows, cols);
    Distribution dist = Nd4j.getDistributions().createNormal(1, 1);
    for (int i = 0; i < W.rows(); i++)
        W.putRow(i, Nd4j.create(dist.sample(W.columns())));

    // Perturb the weights with Gaussian noise for a few iterations.
    // NOTE(review): the updater is never applied and nothing is asserted —
    // this test only checks that setup and sampling run without throwing.
    for (int i = 0; i < 5; i++) {
        W.addi(Nd4j.randn(rows, cols));
    }
}
 
Example #2
Source File: UpdaterTest.java    From deeplearning4j with Apache License 2.0 6 votes vote down vote up
@Test
public void testAdaGrad() {
    int rows = 10;
    int cols = 2;

    // AdaGrad updater (lr = 0.1, default epsilon) backed by an external state view.
    AdaGradUpdater grad = new AdaGradUpdater(new AdaGrad(0.1, AdaGrad.DEFAULT_ADAGRAD_EPSILON));
    grad.setStateViewArray(Nd4j.zeros(1, rows * cols), new long[] {rows, cols}, 'c', true);

    // Initialize the weight matrix row by row from a N(1, 1) distribution.
    INDArray W = Nd4j.zeros(rows, cols);
    Distribution dist = Nd4j.getDistributions().createNormal(1, 1);
    for (int i = 0; i < W.rows(); i++)
        W.putRow(i, Nd4j.create(dist.sample(W.columns())));

    // Perturb the weights with Gaussian noise for a few iterations.
    // NOTE(review): the updater is never applied and nothing is asserted —
    // this test only checks that setup and sampling run without throwing.
    for (int i = 0; i < 5; i++) {
        W.addi(Nd4j.randn(rows, cols));
    }
}
 
Example #3
Source File: UpdaterTest.java    From deeplearning4j with Apache License 2.0 6 votes vote down vote up
@Test
public void testAdaDelta() {
    int rows = 10;
    int cols = 2;

    // AdaDelta keeps two state arrays (msg and msdx), hence the 2x state view size.
    AdaDeltaUpdater grad = new AdaDeltaUpdater(new AdaDelta());
    grad.setStateViewArray(Nd4j.zeros(1, 2 * rows * cols), new long[] {rows, cols}, 'c', true);

    // Initialize the weight matrix row by row from a N(1e-3, 1e-3) distribution.
    INDArray W = Nd4j.zeros(rows, cols);
    Distribution dist = Nd4j.getDistributions().createNormal(1e-3, 1e-3);
    for (int i = 0; i < W.rows(); i++)
        W.putRow(i, Nd4j.create(dist.sample(W.columns())));

    // Perturb the weights with Gaussian noise for a few iterations.
    // NOTE(review): the updater is never applied and nothing is asserted —
    // this test only checks that setup and sampling run without throwing.
    for (int i = 0; i < 5; i++) {
        W.addi(Nd4j.randn(rows, cols));
    }
}
 
Example #4
Source File: UpdaterTest.java    From deeplearning4j with Apache License 2.0 6 votes vote down vote up
@Test
public void testAdam() {
    int rows = 10;
    int cols = 2;

    // Adam keeps two state arrays (first and second moments), hence the 2x state view size.
    AdamUpdater grad = new AdamUpdater(new Adam());
    grad.setStateViewArray(Nd4j.zeros(1, 2 * rows * cols), new long[] {rows, cols}, 'c', true);

    // Initialize the weight matrix row by row from a N(1e-3, 1e-3) distribution.
    INDArray W = Nd4j.zeros(rows, cols);
    Distribution dist = Nd4j.getDistributions().createNormal(1e-3, 1e-3);
    for (int i = 0; i < W.rows(); i++)
        W.putRow(i, Nd4j.create(dist.sample(W.columns())));

    // Perturb the weights with Gaussian noise for a few iterations.
    // NOTE(review): the updater is never applied and nothing is asserted —
    // this test only checks that setup and sampling run without throwing.
    for (int i = 0; i < 5; i++) {
        W.addi(Nd4j.randn(rows, cols));
    }
}
 
Example #5
Source File: UpdaterTest.java    From deeplearning4j with Apache License 2.0 6 votes vote down vote up
@Test
public void testNadam() {
    int rows = 10;
    int cols = 2;

    // Nadam keeps two state arrays (first and second moments), hence the 2x state view size.
    NadamUpdater grad = new NadamUpdater(new Nadam());
    grad.setStateViewArray(Nd4j.zeros(1, 2 * rows * cols), new long[] {rows, cols}, 'c', true);

    // Initialize the weight matrix row by row from a N(1e-3, 1e-3) distribution.
    INDArray W = Nd4j.zeros(rows, cols);
    Distribution dist = Nd4j.getDistributions().createNormal(1e-3, 1e-3);
    for (int i = 0; i < W.rows(); i++)
        W.putRow(i, Nd4j.create(dist.sample(W.columns())));

    // Perturb the weights with Gaussian noise for a few iterations.
    // NOTE(review): the updater is never applied and nothing is asserted —
    // this test only checks that setup and sampling run without throwing.
    for (int i = 0; i < 5; i++) {
        W.addi(Nd4j.randn(rows, cols));
    }
}
 
Example #6
Source File: ProbabilityTestCase.java    From jstarcraft-ai with Apache License 2.0 6 votes vote down vote up
@Test
public void testSample() {
    int seed = 1000;

    // Compare a legacy Distribution against the replacement MathProbability,
    // both seeded identically. ("newFuction" typo in local names fixed.)
    Distribution oldFunction = getOldFunction(seed);
    MathProbability<Number> newFunction = getNewFunction(seed);

    // Re-seed both implementations with the same value and compare their samples.
    for (int index = 0; index < seed; index++) {
        newFunction.setSeed(index);
        oldFunction.reseedRandomGenerator(index);
        assertSample(newFunction, oldFunction);
    }

    // Support bounds, inverse CDF at 0/1 and CDF at the extremes must agree.
    Assert.assertThat(newFunction.getMaximum().doubleValue(), CoreMatchers.equalTo(oldFunction.getSupportUpperBound()));
    Assert.assertThat(newFunction.getMinimum().doubleValue(), CoreMatchers.equalTo(oldFunction.getSupportLowerBound()));
    Assert.assertThat(newFunction.inverseDistribution(1D).doubleValue(), CoreMatchers.equalTo(oldFunction.getSupportUpperBound()));
    Assert.assertThat(newFunction.inverseDistribution(0D).doubleValue(), CoreMatchers.equalTo(oldFunction.getSupportLowerBound()));
    Assert.assertThat(newFunction.cumulativeDistribution(newFunction.getMaximum()), CoreMatchers.equalTo(oldFunction.cumulativeProbability(oldFunction.getSupportUpperBound())));
    Assert.assertThat(newFunction.cumulativeDistribution(newFunction.getMinimum()), CoreMatchers.equalTo(oldFunction.cumulativeProbability(oldFunction.getSupportLowerBound())));
}
 
Example #7
Source File: UpdaterTest.java    From deeplearning4j with Apache License 2.0 6 votes vote down vote up
@Test
public void testAdaMax() {
    int rows = 10;
    int cols = 2;

    // AdaMax keeps two state arrays, hence the 2x state view size.
    AdaMaxUpdater grad = new AdaMaxUpdater(new AdaMax());
    grad.setStateViewArray(Nd4j.zeros(1, 2 * rows * cols), new long[] {rows, cols}, 'c', true);

    // Initialize the weight matrix row by row from a N(1e-3, 1e-3) distribution.
    INDArray W = Nd4j.zeros(rows, cols);
    Distribution dist = Nd4j.getDistributions().createNormal(1e-3, 1e-3);
    for (int i = 0; i < W.rows(); i++)
        W.putRow(i, Nd4j.create(dist.sample(W.columns())));

    // Perturb the weights with Gaussian noise for a few iterations.
    // NOTE(review): the updater is never applied and nothing is asserted —
    // this test only checks that setup and sampling run without throwing.
    for (int i = 0; i < 5; i++) {
        W.addi(Nd4j.randn(rows, cols));
    }
}
 
Example #8
Source File: UpdaterTest.java    From nd4j with Apache License 2.0 6 votes vote down vote up
@Test
public void testAdaMax() {
    int rows = 10;
    int cols = 2;

    // AdaMax keeps two state arrays, hence the 2x state view size.
    AdaMaxUpdater grad = new AdaMaxUpdater(new AdaMax());
    grad.setStateViewArray(Nd4j.zeros(1, 2 * rows * cols), new long[] {rows, cols}, 'c', true);

    // Initialize the weight matrix row by row from a N(1e-3, 1e-3) distribution.
    INDArray W = Nd4j.zeros(rows, cols);
    Distribution dist = Nd4j.getDistributions().createNormal(1e-3, 1e-3);
    for (int i = 0; i < W.rows(); i++)
        W.putRow(i, Nd4j.create(dist.sample(W.columns())));

    // Perturb the weights with Gaussian noise for a few iterations.
    // NOTE(review): the updater is never applied and nothing is asserted —
    // this test only checks that setup and sampling run without throwing.
    for (int i = 0; i < 5; i++) {
        W.addi(Nd4j.randn(rows, cols));
    }
}
 
Example #9
Source File: UpdaterTest.java    From nd4j with Apache License 2.0 6 votes vote down vote up
@Test
public void testNadam() {
    int rows = 10;
    int cols = 2;

    // Nadam keeps two state arrays, hence the 2x state view size.
    NadamUpdater grad = new NadamUpdater(new Nadam());
    grad.setStateViewArray(Nd4j.zeros(1, 2 * rows * cols), new long[] {rows, cols}, 'c', true);

    // Initialize the weight matrix row by row from a N(1e-3, 1e-3) distribution.
    INDArray W = Nd4j.zeros(rows, cols);
    Distribution dist = Nd4j.getDistributions().createNormal(1e-3, 1e-3);
    for (int i = 0; i < W.rows(); i++)
        W.putRow(i, Nd4j.create(dist.sample(W.columns())));

    // Perturb the weights with Gaussian noise for a few iterations.
    // NOTE(review): the updater is never applied and nothing is asserted —
    // this test only checks that setup and sampling run without throwing.
    for (int i = 0; i < 5; i++) {
        W.addi(Nd4j.randn(rows, cols));
    }
}
 
Example #10
Source File: UpdaterTest.java    From nd4j with Apache License 2.0 6 votes vote down vote up
@Test
public void testAdam() {
    int rows = 10;
    int cols = 2;

    // Adam keeps two state arrays, hence the 2x state view size.
    AdamUpdater grad = new AdamUpdater(new Adam());
    grad.setStateViewArray(Nd4j.zeros(1, 2 * rows * cols), new long[] {rows, cols}, 'c', true);

    // Initialize the weight matrix row by row from a N(1e-3, 1e-3) distribution.
    INDArray W = Nd4j.zeros(rows, cols);
    Distribution dist = Nd4j.getDistributions().createNormal(1e-3, 1e-3);
    for (int i = 0; i < W.rows(); i++)
        W.putRow(i, Nd4j.create(dist.sample(W.columns())));

    // Perturb the weights with Gaussian noise for a few iterations.
    // NOTE(review): the updater is never applied and nothing is asserted —
    // this test only checks that setup and sampling run without throwing.
    for (int i = 0; i < 5; i++) {
        W.addi(Nd4j.randn(rows, cols));
    }
}
 
Example #11
Source File: UpdaterTest.java    From nd4j with Apache License 2.0 6 votes vote down vote up
@Test
public void testAdaDelta() {
    int rows = 10;
    int cols = 2;

    // AdaDelta keeps two state arrays, hence the 2x state view size.
    AdaDeltaUpdater grad = new AdaDeltaUpdater(new AdaDelta());
    grad.setStateViewArray(Nd4j.zeros(1, 2 * rows * cols), new long[] {rows, cols}, 'c', true);

    // Initialize the weight matrix row by row from a N(1e-3, 1e-3) distribution.
    INDArray W = Nd4j.zeros(rows, cols);
    Distribution dist = Nd4j.getDistributions().createNormal(1e-3, 1e-3);
    for (int i = 0; i < W.rows(); i++)
        W.putRow(i, Nd4j.create(dist.sample(W.columns())));

    // Perturb the weights with Gaussian noise for a few iterations.
    // NOTE(review): the updater is never applied and nothing is asserted —
    // this test only checks that setup and sampling run without throwing.
    for (int i = 0; i < 5; i++) {
        W.addi(Nd4j.randn(rows, cols));
    }
}
 
Example #12
Source File: UpdaterTest.java    From nd4j with Apache License 2.0 6 votes vote down vote up
@Test
public void testAdaGrad() {
    int rows = 10;
    int cols = 2;

    // AdaGrad updater (lr = 0.1, default epsilon) backed by an external state view.
    AdaGradUpdater grad = new AdaGradUpdater(new AdaGrad(0.1, AdaGrad.DEFAULT_ADAGRAD_EPSILON));
    grad.setStateViewArray(Nd4j.zeros(1, rows * cols), new long[] {rows, cols}, 'c', true);

    // Initialize the weight matrix row by row from a N(1, 1) distribution.
    INDArray W = Nd4j.zeros(rows, cols);
    Distribution dist = Nd4j.getDistributions().createNormal(1, 1);
    for (int i = 0; i < W.rows(); i++)
        W.putRow(i, Nd4j.create(dist.sample(W.columns())));

    // Perturb the weights with Gaussian noise for a few iterations.
    // NOTE(review): the updater is never applied and nothing is asserted —
    // this test only checks that setup and sampling run without throwing.
    for (int i = 0; i < 5; i++) {
        W.addi(Nd4j.randn(rows, cols));
    }
}
 
Example #13
Source File: UpdaterTest.java    From nd4j with Apache License 2.0 6 votes vote down vote up
@Test
public void testNesterovs() {
    int rows = 10;
    int cols = 2;

    // Nesterovs updater (lr = 0.5, momentum = 0.9) backed by an external state view.
    NesterovsUpdater grad = new NesterovsUpdater(new Nesterovs(0.5, 0.9));
    grad.setStateViewArray(Nd4j.zeros(1, rows * cols), new long[] {rows, cols}, 'c', true);

    // Initialize the weight matrix row by row from a N(1, 1) distribution.
    INDArray W = Nd4j.zeros(rows, cols);
    Distribution dist = Nd4j.getDistributions().createNormal(1, 1);
    for (int i = 0; i < W.rows(); i++)
        W.putRow(i, Nd4j.create(dist.sample(W.columns())));

    // Perturb the weights with Gaussian noise for a few iterations.
    // NOTE(review): the updater is never applied and nothing is asserted —
    // this test only checks that setup and sampling run without throwing.
    for (int i = 0; i < 5; i++) {
        W.addi(Nd4j.randn(rows, cols));
    }
}
 
Example #14
Source File: ShufflesTest.java    From nd4j with Apache License 2.0 5 votes vote down vote up
@Test
public void testBinomial() {
    // Binomial(n=3) with a probability vector that is ~0 everywhere
    // (only index 1 is set, to 1e-5), so samples should be all zeros.
    Distribution distribution = Nd4j.getDistributions().createBinomial(3, Nd4j.create(10).putScalar(1, 0.00001));

    for (int x = 0; x < 10000; x++) {
        INDArray z = distribution.sample(new int[]{1, 10});

        // Every element of the sample must equal 0.0.
        // (Removed a no-op System.out.println() that only slowed this 10k-iteration loop.)
        MatchCondition condition = new MatchCondition(z, Conditions.equals(0.0));
        int match = Nd4j.getExecutioner().exec(condition, Integer.MAX_VALUE).getInt(0);
        assertEquals(z.length(), match);
    }
}
 
Example #15
Source File: CDAEParameter.java    From jstarcraft-rns with Apache License 2.0 5 votes vote down vote up
/**
 * Builds the per-user weight matrix for this layer.
 *
 * @param conf                 layer configuration (must hold a {@code FeedForwardLayer})
 * @param weightParamView      flattened parameter view backing the weights
 * @param initializeParameters whether to fill the view from the configured distribution
 * @return the weight matrix of shape (numberOfUsers, nOut) backed by {@code weightParamView}
 */
private INDArray createUserWeightMatrix(NeuralNetConfiguration conf, INDArray weightParamView, boolean initializeParameters) {
    FeedForwardLayer layerConf = (FeedForwardLayer) conf.getLayer();
    if (!initializeParameters) {
        // Wrap the existing view without touching its contents.
        return createWeightMatrix(numberOfUsers, layerConf.getNOut(), null, null, weightParamView, false);
    }
    // Sample fresh weights from the layer's configured distribution.
    Distribution dist = Distributions.createDistribution(layerConf.getDist());
    return createWeightMatrix(numberOfUsers, layerConf.getNOut(), layerConf.getWeightInit(), dist, weightParamView, true);
}
 
Example #16
Source File: DeepFMParameter.java    From jstarcraft-rns with Apache License 2.0 5 votes vote down vote up
/**
 * Builds the per-feature weight matrix for this layer.
 *
 * @param configuration layer configuration (must hold a {@code FeedForwardLayer})
 * @param view          flattened parameter view backing the weights
 * @param initialize    whether to fill the view from the configured distribution
 * @return the weight matrix of shape (numberOfFeatures, nOut) backed by {@code view}
 */
protected INDArray createWeightMatrix(NeuralNetConfiguration configuration, INDArray view, boolean initialize) {
    FeedForwardLayer layerConfiguration = (FeedForwardLayer) configuration.getLayer();
    if (!initialize) {
        // Wrap the existing view without touching its contents.
        return super.createWeightMatrix(numberOfFeatures, layerConfiguration.getNOut(), null, null, view, false);
    }
    // Sample fresh weights from the layer's configured distribution.
    Distribution distribution = Distributions.createDistribution(layerConfiguration.getDist());
    return super.createWeightMatrix(numberOfFeatures, layerConfiguration.getNOut(), layerConfiguration.getWeightInit(), distribution, view, true);
}
 
Example #17
Source File: RandomTests.java    From nd4j with Apache License 2.0 5 votes vote down vote up
/**
 * Uses a test of Gaussianity for testing the values out of GaussianDistribution
 * See https://en.wikipedia.org/wiki/Anderson%E2%80%93Darling_test
 *
 * @throws Exception
 */
@Test
public void testAndersonDarling() throws Exception {

    Random random1 = Nd4j.getRandomFactory().getNewRandomInstance(119);
    INDArray z1 = Nd4j.create(1000);

    // Draw 1000 samples from N(0, 1).
    GaussianDistribution op1 = new GaussianDistribution(z1, 0.0, 1.0);
    Nd4j.getExecutioner().exec(op1, random1);

    val n = z1.length();
    //using this just for the cdf
    Distribution nd = new NormalDistribution(random1, 0.0, 1.0);
    Nd4j.sort(z1, true);

    System.out.println("Data for Anderson-Darling: " + z1);

    // Map each sorted sample to its CDF value; clamp exact 0/1 so the
    // logarithms below stay finite. Use a primitive double instead of the
    // original boxed Double to avoid repeated auto-(un)boxing in the loop.
    for (int i = 0; i < n; i++) {

        double res = nd.cumulativeProbability(z1.getDouble(i));
        assertTrue (res >= 0.0);
        assertTrue (res <= 1.0);
        // avoid overflow when taking log later.
        if (res == 0) res = 0.0000001;
        if (res == 1) res = 0.9999999;
        z1.putScalar(i, res);
    }

    // Anderson-Darling statistic A^2 over the transformed samples.
    double A = 0.0;
    for (int i = 0; i < n; i++) {

        A -= (2*i+1) * (Math.log(z1.getDouble(i)) + Math.log(1-z1.getDouble(n - i - 1)));
    }

    A = A / n - n;
    // Small-sample correction factor.
    A *= (1 + 4.0/n - 25.0/(n*n));

    assertTrue("Critical (max) value for 1000 points and confidence α = 0.0001 is 1.8692, received: "+ A, A < 1.8692);
}
 
Example #18
Source File: RandomTests.java    From deeplearning4j with Apache License 2.0 5 votes vote down vote up
/**
     * Uses a test of Gaussianity for testing the values out of GaussianDistribution
     * See https://en.wikipedia.org/wiki/Anderson%E2%80%93Darling_test
     *
     * @throws Exception
     */
    @Test
    public void testAndersonDarling() {

        Random random1 = Nd4j.getRandomFactory().getNewRandomInstance(119);
        INDArray z1 = Nd4j.create(1000);

        // Draw 1000 samples from N(0, 1).
        GaussianDistribution op1 = new GaussianDistribution(z1, 0.0, 1.0);
        Nd4j.getExecutioner().exec(op1, random1);

        val n = z1.length();
        //using this just for the cdf
        Distribution nd = new NormalDistribution(random1, 0.0, 1.0);
        Nd4j.sort(z1, true);

//        System.out.println("Data for Anderson-Darling: " + z1);

        // Map each sorted sample to its CDF value; clamp exact 0/1 so the
        // logarithms below stay finite. Use a primitive double instead of the
        // original boxed Double to avoid repeated auto-(un)boxing in the loop.
        for (int i = 0; i < n; i++) {

            double res = nd.cumulativeProbability(z1.getDouble(i));
            assertTrue (res >= 0.0);
            assertTrue (res <= 1.0);
            // avoid overflow when taking log later.
            if (res == 0) res = 0.0000001;
            if (res == 1) res = 0.9999999;
            z1.putScalar(i, res);
        }

        // Anderson-Darling statistic A^2 over the transformed samples.
        double A = 0.0;
        for (int i = 0; i < n; i++) {

            A -= (2*i+1) * (Math.log(z1.getDouble(i)) + Math.log(1-z1.getDouble(n - i - 1)));
        }

        A = A / n - n;
        // Small-sample correction factor.
        A *= (1 + 4.0/n - 25.0/(n*n));

        assertTrue("Critical (max) value for 1000 points and confidence α = 0.0001 is 1.8692, received: "+ A, A < 1.8692);
    }
 
Example #19
Source File: WeightInitUtil.java    From deeplearning4j with Apache License 2.0 4 votes vote down vote up
/**
 * Initializes weights in-place into {@code paramView} according to {@code initScheme},
 * then returns the view reshaped to {@code shape} with element ordering {@code order}.
 *
 * @param fanIn      number of inputs to the layer (used by most scaling schemes)
 * @param fanOut     number of outputs of the layer
 * @param shape      target weight shape; must be square 2d for IDENTITY
 * @param initScheme which initialization scheme to apply
 * @param dist       distribution to sample from; only used for DISTRIBUTION
 * @param order      element order ('c' or 'f') for the returned reshaped view
 * @param paramView  flattened parameter view that is written in-place
 * @return {@code paramView} reshaped to {@code shape}
 * @throws IllegalStateException for IDENTITY with a non-square shape, or an unknown scheme
 */
public static INDArray initWeights(double fanIn, double fanOut, long[] shape, WeightInit initScheme,
                Distribution dist, char order, INDArray paramView) {
    switch (initScheme) {
        case DISTRIBUTION:
            // OrthogonalDistribution needs the full matrix shape to orthogonalize;
            // other distributions can fill the flat view directly.
            if (dist instanceof OrthogonalDistribution) {
                dist.sample(paramView.reshape(order, shape));
            } else {
                dist.sample(paramView);
            }
            break;
        case RELU:
            Nd4j.randn(paramView).muli(FastMath.sqrt(2.0 / fanIn)); //N(0, 2/nIn)
            break;
        case RELU_UNIFORM:
            double u = Math.sqrt(6.0 / fanIn);
            Nd4j.rand(paramView, Nd4j.getDistributions().createUniform(-u, u)); //U(-sqrt(6/fanIn), sqrt(6/fanIn)
            break;
        case SIGMOID_UNIFORM:
            double r = 4.0 * Math.sqrt(6.0 / (fanIn + fanOut));
            Nd4j.rand(paramView, Nd4j.getDistributions().createUniform(-r, r));
            break;
        case UNIFORM:
            double a = 1.0 / Math.sqrt(fanIn);
            Nd4j.rand(paramView, Nd4j.getDistributions().createUniform(-a, a));
            break;
        case LECUN_UNIFORM:
            double b = 3.0 / Math.sqrt(fanIn);
            Nd4j.rand(paramView, Nd4j.getDistributions().createUniform(-b, b));
            break;
        case XAVIER:
            Nd4j.randn(paramView).muli(FastMath.sqrt(2.0 / (fanIn + fanOut)));
            break;
        case XAVIER_UNIFORM:
            //As per Glorot and Bengio 2010: Uniform distribution U(-s,s) with s = sqrt(6/(fanIn + fanOut))
            //Eq 16: http://jmlr.org/proceedings/papers/v9/glorot10a/glorot10a.pdf
            double s = Math.sqrt(6.0) / Math.sqrt(fanIn + fanOut);
            Nd4j.rand(paramView, Nd4j.getDistributions().createUniform(-s, s));
            break;
        case LECUN_NORMAL:  //Fall through: these 3 are equivalent
        case NORMAL:
        case XAVIER_FAN_IN:
            Nd4j.randn(paramView).divi(FastMath.sqrt(fanIn));
            break;
        case XAVIER_LEGACY:
            Nd4j.randn(paramView).divi(FastMath.sqrt(shape[0] + shape[1]));
            break;
        case ZERO:
            paramView.assign(0.0);
            break;
        case ONES:
            paramView.assign(1.0);
            break;
        case IDENTITY:
            if(shape.length != 2 || shape[0] != shape[1]){
                throw new IllegalStateException("Cannot use IDENTITY init with parameters of shape "
                        + Arrays.toString(shape) + ": weights must be a square matrix for identity");
            }
            INDArray ret;
            // Build the identity in the requested order; copy through an
            // uninitialized buffer when the default order doesn't match.
            if(order == Nd4j.order()){
                ret = Nd4j.eye(shape[0]);
            } else {
                ret = Nd4j.createUninitialized(shape, order).assign(Nd4j.eye(shape[0]));
            }
            INDArray flat = Nd4j.toFlattened(order, ret);
            paramView.assign(flat);
            break;
        case VAR_SCALING_NORMAL_FAN_IN:
            Nd4j.exec(new TruncatedNormalDistribution(paramView, 0.0, Math.sqrt(1.0 / fanIn)));
            break;
        case VAR_SCALING_NORMAL_FAN_OUT:
            Nd4j.exec(new TruncatedNormalDistribution(paramView, 0.0, Math.sqrt(1.0 / fanOut)));
            break;
        case VAR_SCALING_NORMAL_FAN_AVG:
            Nd4j.exec(new TruncatedNormalDistribution(paramView, 0.0, Math.sqrt(2.0 / (fanIn + fanOut))));
            break;
        case VAR_SCALING_UNIFORM_FAN_IN:
            double scalingFanIn = 3.0 / Math.sqrt(fanIn);
            Nd4j.rand(paramView, Nd4j.getDistributions().createUniform(-scalingFanIn, scalingFanIn));
            break;
        case VAR_SCALING_UNIFORM_FAN_OUT:
            double scalingFanOut = 3.0 / Math.sqrt(fanOut);
            Nd4j.rand(paramView, Nd4j.getDistributions().createUniform(-scalingFanOut, scalingFanOut));
            break;
        case VAR_SCALING_UNIFORM_FAN_AVG:
            double scalingFanAvg = 3.0 / Math.sqrt((fanIn + fanOut) / 2);
            Nd4j.rand(paramView, Nd4j.getDistributions().createUniform(-scalingFanAvg, scalingFanAvg));
            break;
        default:
            throw new IllegalStateException("Illegal weight init value: " + initScheme);
    }

    // paramView was filled in-place above; hand back a shaped view of it.
    return paramView.reshape(order, shape);
}
 
Example #20
Source File: WeightInitUtil.java    From deeplearning4j with Apache License 2.0 4 votes vote down vote up
/**
 * Legacy int[]-shape overload; converts the shape to long[] and delegates.
 *
 * @deprecated use {@code initWeights(double, double, long[], WeightInit, Distribution, char, INDArray)} instead
 */
@Deprecated
public static INDArray initWeights(double fanIn, double fanOut, int[] shape, WeightInit initScheme,
                                   Distribution dist, char order, INDArray paramView) {
    return initWeights(fanIn, fanOut, ArrayUtil.toLongArray(shape), initScheme, dist, order, paramView);
}
 
Example #21
Source File: DefaultDistributionFactory.java    From deeplearning4j with Apache License 2.0 4 votes vote down vote up
@Override
public Distribution createUniform(double min, double max) {
    return new UniformDistribution(min, max);
}
 
Example #22
Source File: DefaultDistributionFactory.java    From deeplearning4j with Apache License 2.0 4 votes vote down vote up
@Override
public Distribution createConstant(double value) {
    return new ConstantDistribution(value);
}
 
Example #23
Source File: DefaultDistributionFactory.java    From deeplearning4j with Apache License 2.0 4 votes vote down vote up
@Override
public Distribution createOrthogonal(double gain) {
    return new OrthogonalDistribution(gain);
}
 
Example #24
Source File: DefaultDistributionFactory.java    From deeplearning4j with Apache License 2.0 4 votes vote down vote up
@Override
public Distribution createNormal(INDArray mean, double std) {
    return new NormalDistribution(mean, std);
}
 
Example #25
Source File: DefaultDistributionFactory.java    From deeplearning4j with Apache License 2.0 4 votes vote down vote up
@Override
public Distribution createNormal(double mean, double std) {
    return new NormalDistribution(mean, std);
}
 
Example #26
Source File: DefaultDistributionFactory.java    From deeplearning4j with Apache License 2.0 4 votes vote down vote up
@Override
public Distribution createLogNormal(double mean, double std) {
    return new LogNormalDistribution(mean, std);
}
 
Example #27
Source File: DefaultDistributionFactory.java    From deeplearning4j with Apache License 2.0 4 votes vote down vote up
@Override
public Distribution createTruncatedNormal(double mean, double std) {
    return new TruncatedNormalDistribution(mean, std);
}
 
Example #28
Source File: DistributionInitScheme.java    From deeplearning4j with Apache License 2.0 4 votes vote down vote up
@Builder
public DistributionInitScheme(char order, Distribution distribution) {
    super(order);
    this.distribution = distribution;
}
 
Example #29
Source File: DefaultDistributionFactory.java    From deeplearning4j with Apache License 2.0 4 votes vote down vote up
@Override
public Distribution createBinomial(int n, double p) {
    return new BinomialDistribution(n, p);
}
 
Example #30
Source File: DefaultDistributionFactory.java    From deeplearning4j with Apache License 2.0 4 votes vote down vote up
@Override
public Distribution createBinomial(int n, INDArray p) {
    return new BinomialDistribution(n, p);
}