org.nd4j.linalg.learning.config.Nadam Java Examples

Example #1
Source file: from deeplearning4j, licensed under Apache License 2.0 (6 votes)
public IUpdater getValue(double[] parameterValues) {
    double lr = learningRate == null ? Nadam.DEFAULT_NADAM_LEARNING_RATE : learningRate.getValue(parameterValues);
    ISchedule lrS = learningRateSchedule == null ? null : learningRateSchedule.getValue(parameterValues);
    double b1 = beta1 == null ? Nadam.DEFAULT_NADAM_LEARNING_RATE : beta1.getValue(parameterValues);
    double b2 = beta2 == null ? Nadam.DEFAULT_NADAM_LEARNING_RATE : beta2.getValue(parameterValues);
    double eps = epsilon == null ? Nadam.DEFAULT_NADAM_LEARNING_RATE : epsilon.getValue(parameterValues);
    if(lrS == null){
        return new Nadam(lr, b1, b2, eps);
    } else {
        Nadam a = new Nadam(lrS);
        return a;
Example #2
Source file: from jstarcraft-ai, licensed under Apache License 2.0 (5 votes)
protected GradientUpdater<?> getOldFunction(long[] shape) {
    Nadam configuration = new Nadam();
    GradientUpdater<?> oldFunction = new NadamUpdater(configuration);
    int length = (int) (shape[0] * configuration.stateSize(shape[1]));
    INDArray view = Nd4j.zeros(length);
    oldFunction.setStateViewArray(view, shape, 'c', true);
    return oldFunction;
Example #3
Source file: from nd4j, licensed under Apache License 2.0 (4 votes)
/**
 * Creates an updater bound to the given Nadam configuration.
 *
 * @param config the Nadam hyperparameter configuration; must not be {@code null}
 * @throws NullPointerException if {@code config} is {@code null}
 */
public NadamUpdater(Nadam config) {
    // Reject a missing configuration at construction time instead of deferring the failure.
    this.config = java.util.Objects.requireNonNull(config, "config");
}
Example #4
Source file: from deeplearning4j, licensed under Apache License 2.0 (4 votes)
/**
 * Creates an updater bound to the given Nadam configuration.
 *
 * @param config the Nadam hyperparameter configuration; must not be {@code null}
 * @throws NullPointerException if {@code config} is {@code null}
 */
public NadamUpdater(Nadam config) {
    // Reject a missing configuration at construction time instead of deferring the failure.
    this.config = java.util.Objects.requireNonNull(config, "config");
}
Example #5
Source file: from deeplearning4j, licensed under Apache License 2.0 (4 votes)
public void testSerializationConfigurations() {

    SerializerInstance si = sc.env().serializer().newInstance();

    //Check network configurations:
    Map<Integer, Double> m = new HashMap<>();
    m.put(0, 0.5);
    m.put(10, 0.1);
    MultiLayerConfiguration mlc = new NeuralNetConfiguration.Builder()
                    .updater(new Nadam(new MapSchedule(ScheduleType.ITERATION,m))).list().layer(0, new OutputLayer.Builder().nIn(10).nOut(10).build())

    testSerialization(mlc, si);

    ComputationGraphConfiguration cgc = new NeuralNetConfiguration.Builder()
                    .dist(new UniformDistribution(-1, 1))
                    .updater(new Adam(new MapSchedule(ScheduleType.ITERATION,m)))
                    .addInputs("in").addLayer("out", new OutputLayer.Builder().nIn(10).nOut(10).build(), "in")

    testSerialization(cgc, si);

    //Check main layers:
    Layer[] layers = new Layer[] {new OutputLayer.Builder().nIn(10).nOut(10).build(),
                    new RnnOutputLayer.Builder().nIn(10).nOut(10).build(), new LossLayer.Builder().build(),
                    new CenterLossOutputLayer.Builder().nIn(10).nOut(10).build(),
                    new DenseLayer.Builder().nIn(10).nOut(10).build(),
                    new ConvolutionLayer.Builder().nIn(10).nOut(10).build(), new SubsamplingLayer.Builder().build(),
                    new Convolution1DLayer.Builder(2, 2).nIn(10).nOut(10).build(),
                    new ActivationLayer.Builder().activation(Activation.TANH).build(),
                    new GlobalPoolingLayer.Builder().build(), new GravesLSTM.Builder().nIn(10).nOut(10).build(),
                    new LSTM.Builder().nIn(10).nOut(10).build(), new DropoutLayer.Builder(0.5).build(),
                    new BatchNormalization.Builder().build(), new LocalResponseNormalization.Builder().build()};

    for (Layer l : layers) {
        testSerialization(l, si);

    //Check graph vertices
    GraphVertex[] vertices = new GraphVertex[] {new ElementWiseVertex(ElementWiseVertex.Op.Add),
                    new L2NormalizeVertex(), new LayerVertex(null, null), new MergeVertex(), new PoolHelperVertex(),
                    new PreprocessorVertex(new CnnToFeedForwardPreProcessor(28, 28, 1)),
                    new ReshapeVertex(new int[] {1, 1}), new ScaleVertex(1.0), new ShiftVertex(1.0),
                    new SubsetVertex(1, 1), new UnstackVertex(0, 2), new DuplicateToTimeSeriesVertex("in1"),
                    new LastTimeStepVertex("in1")};

    for (GraphVertex gv : vertices) {
        testSerialization(gv, si);