org.nd4j.linalg.learning.GradientUpdater Java Examples

The following examples show how to use org.nd4j.linalg.learning.GradientUpdater. You can vote up the ones you like or vote down the ones you don't like, and go to the original project or source file by following the links above each example. You may check out the related API usage on the sidebar.
Example #1
Source File: LearnerTestCase.java    From jstarcraft-ai with Apache License 2.0 6 votes vote down vote up
/**
 * Runs the reference ND4J updater and the new learner side by side for ten
 * iterations and asserts after every iteration that the learner's gradient
 * matrix matches the array processed by the reference updater.
 *
 * @throws Exception if waiting on the submitted task fails
 */
@Test
public void testGradient() throws Exception {
    Future<?> pending = EnvironmentFactory.getContext().doTask(() -> {
        long[] dimensions = { 5L, 2L };
        // Ten evenly spaced values from -2.5 to 2.0, reshaped to 5 x 2.
        INDArray reference = Nd4j.linspace(-2.5D, 2.0D, 10).reshape(dimensions);
        GradientUpdater<?> referenceUpdater = getOldFunction(dimensions);

        DenseMatrix candidate = getMatrix(reference);
        Map<String, MathMatrix> gradients = new HashMap<>();
        gradients.put("gradients", candidate);

        Learner learner = getNewFunction(dimensions);
        learner.doCache(gradients);

        for (int step = 0; step < 10; step++) {
            // Reference implementation first, then the implementation under test.
            referenceUpdater.applyUpdater(reference, step, 0);
            learner.learn(gradients, step, 0);

            System.out.println(reference);
            System.out.println(gradients);

            Assert.assertTrue(equalMatrix(candidate, reference));
        }
    });
    // Block until the task finishes so failures inside it surface here.
    pending.get();
}
 
Example #2
Source File: AdaDeltaLearnerTestCase.java    From jstarcraft-ai with Apache License 2.0 5 votes vote down vote up
/**
 * Builds the reference AdaDelta updater and attaches a state view array to it.
 *
 * @param shape gradient shape; elements 0 and 1 are read to size the state view
 * @return the AdaDelta updater with its state view array set
 */
@Override
protected GradientUpdater<?> getOldFunction(long[] shape) {
    final AdaDelta settings = new AdaDelta();
    final GradientUpdater<?> referenceUpdater = new AdaDeltaUpdater(settings);
    // State length: shape[0] times whatever stateSize reports for shape[1].
    final int stateLength = (int) (shape[0] * settings.stateSize(shape[1]));
    referenceUpdater.setStateViewArray(Nd4j.zeros(stateLength), shape, 'c', true);
    return referenceUpdater;
}
 
Example #3
Source File: AdaMax.java    From deeplearning4j with Apache License 2.0 5 votes vote down vote up
/**
 * Creates an {@code AdaMaxUpdater} for this configuration and binds it to the
 * supplied state view.
 *
 * @param viewArray           state view array; its shape and ordering are read
 * @param initializeViewArray forwarded unchanged to {@code setStateViewArray}
 * @return the configured updater
 */
@Override
public GradientUpdater instantiate(INDArray viewArray, boolean initializeViewArray) {
    final AdaMaxUpdater updater = new AdaMaxUpdater(this);
    // Work on a copy so the array returned by shape() is never modified.
    final long[] gradientShape = viewArray.shape().clone();
    // Dimension 1 is halved (presumably the view stores two state segments - TODO confirm).
    gradientShape[1] = gradientShape[1] / 2;
    updater.setStateViewArray(viewArray, gradientShape, viewArray.ordering(), initializeViewArray);
    return updater;
}
 
Example #4
Source File: AdaDelta.java    From deeplearning4j with Apache License 2.0 5 votes vote down vote up
/**
 * Creates an {@code AdaDeltaUpdater} for this configuration and binds it to
 * the supplied state view.
 *
 * @param viewArray           state view array; its shape and ordering are read
 * @param initializeViewArray forwarded unchanged to {@code setStateViewArray}
 * @return the configured updater
 */
@Override
public GradientUpdater instantiate(INDArray viewArray, boolean initializeViewArray) {
    final AdaDeltaUpdater updater = new AdaDeltaUpdater(this);
    // Work on a copy so the array returned by shape() is never modified.
    final long[] gradientShape = viewArray.shape().clone();
    // Dimension 1 is halved (presumably the view stores two state segments - TODO confirm).
    gradientShape[1] = gradientShape[1] / 2;
    updater.setStateViewArray(viewArray, gradientShape, viewArray.ordering(), initializeViewArray);
    return updater;
}
 
Example #5
Source File: Sgd.java    From deeplearning4j with Apache License 2.0 5 votes vote down vote up
/**
 * Creates an {@code SgdUpdater} for this configuration.
 *
 * @param viewArray           must be {@code null}
 * @param initializeViewArray not read by this method
 * @return a new {@code SgdUpdater}
 * @throws IllegalStateException if {@code viewArray} is not {@code null}
 */
@Override
public GradientUpdater instantiate(INDArray viewArray, boolean initializeViewArray) {
    if (viewArray == null) {
        return new SgdUpdater(this);
    }
    throw new IllegalStateException("View arrays are not supported/required for SGD updater");
}
 
Example #6
Source File: Adam.java    From deeplearning4j with Apache License 2.0 5 votes vote down vote up
/**
 * Creates an {@code AdamUpdater} for this configuration and binds it to the
 * supplied state view.
 *
 * @param viewArray           state view array; its shape and ordering are read
 * @param initializeViewArray forwarded unchanged to {@code setStateViewArray}
 * @return the configured updater
 */
@Override
public GradientUpdater instantiate(INDArray viewArray, boolean initializeViewArray) {
    final AdamUpdater updater = new AdamUpdater(this);
    // Work on a copy so the array returned by shape() is never modified.
    final long[] gradientShape = viewArray.shape().clone();
    // Dimension 1 is halved (presumably the view stores two state segments - TODO confirm).
    gradientShape[1] = gradientShape[1] / 2;
    updater.setStateViewArray(viewArray, gradientShape, viewArray.ordering(), initializeViewArray);
    return updater;
}
 
Example #7
Source File: NoOp.java    From deeplearning4j with Apache License 2.0 5 votes vote down vote up
/**
 * Creates a {@code NoOpUpdater} for this configuration.
 *
 * @param viewArray           must be {@code null}
 * @param initializeViewArray not read by this method
 * @return a new {@code NoOpUpdater}
 * @throws IllegalStateException if {@code viewArray} is not {@code null}
 */
@Override
public GradientUpdater instantiate(INDArray viewArray, boolean initializeViewArray) {
    if (viewArray == null) {
        return new NoOpUpdater(this);
    }
    throw new IllegalStateException("Cannot use view array with NoOp updater");
}
 
Example #8
Source File: Nadam.java    From deeplearning4j with Apache License 2.0 5 votes vote down vote up
/**
 * Creates a {@code NadamUpdater} for this configuration and binds it to the
 * supplied state view.
 *
 * @param viewArray           state view array; its shape and ordering are read
 * @param initializeViewArray forwarded unchanged to {@code setStateViewArray}
 * @return the configured updater
 */
@Override
public GradientUpdater instantiate(INDArray viewArray, boolean initializeViewArray) {
    final NadamUpdater updater = new NadamUpdater(this);
    // Work on a copy so the array returned by shape() is never modified.
    final long[] gradientShape = viewArray.shape().clone();
    // Dimension 1 is halved (presumably the view stores two state segments - TODO confirm).
    gradientShape[1] = gradientShape[1] / 2;
    updater.setStateViewArray(viewArray, gradientShape, viewArray.ordering(), initializeViewArray);
    return updater;
}
 
Example #9
Source File: AMSGrad.java    From nd4j with Apache License 2.0 5 votes vote down vote up
/**
 * Creates an {@code AMSGradUpdater} for this configuration and binds it to
 * the supplied state view.
 *
 * @param viewArray           state view array; its shape and ordering are read
 * @param initializeViewArray forwarded unchanged to {@code setStateViewArray}
 * @return the configured updater
 */
@Override
public GradientUpdater instantiate(INDArray viewArray, boolean initializeViewArray) {
    final AMSGradUpdater updater = new AMSGradUpdater(this);
    // Work on a copy so the array returned by shape() is never modified.
    final long[] gradientShape = viewArray.shape().clone();
    // Dimension 1 is divided by three (presumably the view stores three state segments - TODO confirm).
    gradientShape[1] = gradientShape[1] / 3;
    updater.setStateViewArray(viewArray, gradientShape, viewArray.ordering(), initializeViewArray);
    return updater;
}
 
Example #10
Source File: Nadam.java    From nd4j with Apache License 2.0 5 votes vote down vote up
/**
 * Creates a {@code NadamUpdater} for this configuration and binds it to the
 * supplied state view.
 *
 * @param viewArray           state view array; its shape and ordering are read
 * @param initializeViewArray forwarded unchanged to {@code setStateViewArray}
 * @return the configured updater
 */
@Override
public GradientUpdater instantiate(INDArray viewArray, boolean initializeViewArray) {
    final NadamUpdater updater = new NadamUpdater(this);
    // Work on a copy so the array returned by shape() is never modified.
    final long[] gradientShape = viewArray.shape().clone();
    // Dimension 1 is halved (presumably the view stores two state segments - TODO confirm).
    gradientShape[1] = gradientShape[1] / 2;
    updater.setStateViewArray(viewArray, gradientShape, viewArray.ordering(), initializeViewArray);
    return updater;
}
 
Example #11
Source File: Sgd.java    From nd4j with Apache License 2.0 5 votes vote down vote up
/**
 * Creates an {@code SgdUpdater} for this configuration.
 *
 * @param viewArray           must be {@code null}
 * @param initializeViewArray not read by this method
 * @return a new {@code SgdUpdater}
 * @throws IllegalStateException if {@code viewArray} is not {@code null}
 */
@Override
public GradientUpdater instantiate(INDArray viewArray, boolean initializeViewArray) {
    if (viewArray == null) {
        return new SgdUpdater(this);
    }
    throw new IllegalStateException("View arrays are not supported/required for SGD updater");
}
 
Example #12
Source File: AMSGrad.java    From deeplearning4j with Apache License 2.0 5 votes vote down vote up
/**
 * Creates an {@code AMSGradUpdater} for this configuration and binds it to
 * the supplied state view.
 *
 * @param viewArray           state view array; its shape and ordering are read
 * @param initializeViewArray forwarded unchanged to {@code setStateViewArray}
 * @return the configured updater
 */
@Override
public GradientUpdater instantiate(INDArray viewArray, boolean initializeViewArray) {
    final AMSGradUpdater updater = new AMSGradUpdater(this);
    // Work on a copy so the array returned by shape() is never modified.
    final long[] gradientShape = viewArray.shape().clone();
    // Dimension 1 is divided by three (presumably the view stores three state segments - TODO confirm).
    gradientShape[1] = gradientShape[1] / 3;
    updater.setStateViewArray(viewArray, gradientShape, viewArray.ordering(), initializeViewArray);
    return updater;
}
 
Example #13
Source File: AdaMax.java    From nd4j with Apache License 2.0 5 votes vote down vote up
/**
 * Creates an {@code AdaMaxUpdater} for this configuration and binds it to the
 * supplied state view.
 *
 * @param viewArray           state view array; its shape and ordering are read
 * @param initializeViewArray forwarded unchanged to {@code setStateViewArray}
 * @return the configured updater
 */
@Override
public GradientUpdater instantiate(INDArray viewArray, boolean initializeViewArray) {
    final AdaMaxUpdater updater = new AdaMaxUpdater(this);
    // Work on a copy so the array returned by shape() is never modified.
    final long[] gradientShape = viewArray.shape().clone();
    // Dimension 1 is halved (presumably the view stores two state segments - TODO confirm).
    gradientShape[1] = gradientShape[1] / 2;
    updater.setStateViewArray(viewArray, gradientShape, viewArray.ordering(), initializeViewArray);
    return updater;
}
 
Example #14
Source File: Adam.java    From nd4j with Apache License 2.0 5 votes vote down vote up
/**
 * Creates an {@code AdamUpdater} for this configuration and binds it to the
 * supplied state view.
 *
 * @param viewArray           state view array; its shape and ordering are read
 * @param initializeViewArray forwarded unchanged to {@code setStateViewArray}
 * @return the configured updater
 */
@Override
public GradientUpdater instantiate(INDArray viewArray, boolean initializeViewArray) {
    final AdamUpdater updater = new AdamUpdater(this);
    // Work on a copy so the array returned by shape() is never modified.
    final long[] gradientShape = viewArray.shape().clone();
    // Dimension 1 is halved (presumably the view stores two state segments - TODO confirm).
    gradientShape[1] = gradientShape[1] / 2;
    updater.setStateViewArray(viewArray, gradientShape, viewArray.ordering(), initializeViewArray);
    return updater;
}
 
Example #15
Source File: NadamLearnerTestCase.java    From jstarcraft-ai with Apache License 2.0 5 votes vote down vote up
/**
 * Builds the reference Nadam updater and attaches a state view array to it.
 *
 * @param shape gradient shape; elements 0 and 1 are read to size the state view
 * @return the Nadam updater with its state view array set
 */
@Override
protected GradientUpdater<?> getOldFunction(long[] shape) {
    final Nadam settings = new Nadam();
    final GradientUpdater<?> referenceUpdater = new NadamUpdater(settings);
    // State length: shape[0] times whatever stateSize reports for shape[1].
    final int stateLength = (int) (shape[0] * settings.stateSize(shape[1]));
    referenceUpdater.setStateViewArray(Nd4j.zeros(stateLength), shape, 'c', true);
    return referenceUpdater;
}
 
Example #16
Source File: AdaMaxLearnerTestCase.java    From jstarcraft-ai with Apache License 2.0 5 votes vote down vote up
/**
 * Builds the reference AdaMax updater and attaches a state view array to it.
 *
 * @param shape gradient shape; elements 0 and 1 are read to size the state view
 * @return the AdaMax updater with its state view array set
 */
@Override
protected GradientUpdater<?> getOldFunction(long[] shape) {
    final AdaMax settings = new AdaMax();
    final GradientUpdater<?> referenceUpdater = new AdaMaxUpdater(settings);
    // State length: shape[0] times whatever stateSize reports for shape[1].
    final int stateLength = (int) (shape[0] * settings.stateSize(shape[1]));
    referenceUpdater.setStateViewArray(Nd4j.zeros(stateLength), shape, 'c', true);
    return referenceUpdater;
}
 
Example #17
Source File: NesterovLearnerTestCase.java    From jstarcraft-ai with Apache License 2.0 5 votes vote down vote up
/**
 * Builds the reference Nesterovs updater and attaches a state view array to it.
 *
 * @param shape gradient shape; elements 0 and 1 are read to size the state view
 * @return the Nesterovs updater with its state view array set
 */
@Override
protected GradientUpdater<?> getOldFunction(long[] shape) {
    final Nesterovs settings = new Nesterovs();
    final GradientUpdater<?> referenceUpdater = new NesterovsUpdater(settings);
    // State length: shape[0] times whatever stateSize reports for shape[1].
    final int stateLength = (int) (shape[0] * settings.stateSize(shape[1]));
    referenceUpdater.setStateViewArray(Nd4j.zeros(stateLength), shape, 'c', true);
    return referenceUpdater;
}
 
Example #18
Source File: RmsPropLearnerTestCase.java    From jstarcraft-ai with Apache License 2.0 5 votes vote down vote up
/**
 * Builds the reference RmsProp updater and attaches a state view array to it.
 *
 * @param shape gradient shape; elements 0 and 1 are read to size the state view
 * @return the RmsProp updater with its state view array set
 */
@Override
protected GradientUpdater<?> getOldFunction(long[] shape) {
    final RmsProp settings = new RmsProp();
    final GradientUpdater<?> referenceUpdater = new RmsPropUpdater(settings);
    // State length: shape[0] times whatever stateSize reports for shape[1].
    final int stateLength = (int) (shape[0] * settings.stateSize(shape[1]));
    referenceUpdater.setStateViewArray(Nd4j.zeros(stateLength), shape, 'c', true);
    return referenceUpdater;
}
 
Example #19
Source File: CustomIUpdater.java    From deeplearning4j with Apache License 2.0 5 votes vote down vote up
/**
 * Creates a {@code CustomGradientUpdater} for this configuration.
 *
 * @param viewArray           must be {@code null}
 * @param initializeViewArray not read by this method
 * @return a new {@code CustomGradientUpdater}
 * @throws IllegalStateException if {@code viewArray} is not {@code null}
 */
@Override
public GradientUpdater instantiate(INDArray viewArray, boolean initializeViewArray) {
    if (viewArray == null) {
        return new CustomGradientUpdater(this);
    }
    throw new IllegalStateException("View arrays are not supported/required for SGD updater");
}
 
Example #20
Source File: AdaGradLearnerTestCase.java    From jstarcraft-ai with Apache License 2.0 5 votes vote down vote up
/**
 * Builds the reference AdaGrad updater and attaches a state view array to it.
 *
 * @param shape gradient shape; elements 0 and 1 are read to size the state view
 * @return the AdaGrad updater with its state view array set
 */
@Override
protected GradientUpdater<?> getOldFunction(long[] shape) {
    final AdaGrad settings = new AdaGrad();
    final GradientUpdater<?> referenceUpdater = new AdaGradUpdater(settings);
    // State length: shape[0] times whatever stateSize reports for shape[1].
    final int stateLength = (int) (shape[0] * settings.stateSize(shape[1]));
    referenceUpdater.setStateViewArray(Nd4j.zeros(stateLength), shape, 'c', true);
    return referenceUpdater;
}
 
Example #21
Source File: AdamLearnerTestCase.java    From jstarcraft-ai with Apache License 2.0 5 votes vote down vote up
/**
 * Builds the reference Adam updater and attaches a state view array to it.
 *
 * @param shape gradient shape; elements 0 and 1 are read to size the state view
 * @return the Adam updater with its state view array set
 */
@Override
protected GradientUpdater<?> getOldFunction(long[] shape) {
    final Adam settings = new Adam();
    final GradientUpdater<?> referenceUpdater = new AdamUpdater(settings);
    // State length: shape[0] times whatever stateSize reports for shape[1].
    final int stateLength = (int) (shape[0] * settings.stateSize(shape[1]));
    referenceUpdater.setStateViewArray(Nd4j.zeros(stateLength), shape, 'c', true);
    return referenceUpdater;
}
 
Example #22
Source File: NoOp.java    From nd4j with Apache License 2.0 5 votes vote down vote up
/**
 * Creates a {@code NoOpUpdater} for this configuration.
 *
 * @param viewArray           must be {@code null}
 * @param initializeViewArray not read by this method
 * @return a new {@code NoOpUpdater}
 * @throws IllegalStateException if {@code viewArray} is not {@code null}
 */
@Override
public GradientUpdater instantiate(INDArray viewArray, boolean initializeViewArray) {
    if (viewArray == null) {
        return new NoOpUpdater(this);
    }
    throw new IllegalStateException("Cannot use view array with NoOp updater");
}
 
Example #23
Source File: AdaDelta.java    From nd4j with Apache License 2.0 5 votes vote down vote up
/**
 * Creates an {@code AdaDeltaUpdater} for this configuration and binds it to
 * the supplied state view.
 *
 * @param viewArray           state view array; its shape and ordering are read
 * @param initializeViewArray forwarded unchanged to {@code setStateViewArray}
 * @return the configured updater
 */
@Override
public GradientUpdater instantiate(INDArray viewArray, boolean initializeViewArray) {
    final AdaDeltaUpdater updater = new AdaDeltaUpdater(this);
    // Work on a copy so the array returned by shape() is never modified.
    final long[] gradientShape = viewArray.shape().clone();
    // Dimension 1 is halved (presumably the view stores two state segments - TODO confirm).
    gradientShape[1] = gradientShape[1] / 2;
    updater.setStateViewArray(viewArray, gradientShape, viewArray.ordering(), initializeViewArray);
    return updater;
}
 
Example #24
Source File: Nadam.java    From deeplearning4j with Apache License 2.0 4 votes vote down vote up
/**
 * Creates a {@code NadamUpdater} for this configuration and hands it the
 * supplied state map.
 *
 * @param updaterState          state arrays keyed by name, passed to {@code setState}
 * @param initializeStateArrays forwarded unchanged to {@code setState}
 * @return the configured updater
 */
@Override
public GradientUpdater instantiate(Map<String, INDArray> updaterState, boolean initializeStateArrays) {
    final NadamUpdater updater = new NadamUpdater(this);
    updater.setState(updaterState, initializeStateArrays);
    return updater;
}
 
Example #25
Source File: SgdLearnerTestCase.java    From jstarcraft-ai with Apache License 2.0 4 votes vote down vote up
/**
 * Builds the reference SGD updater from a default {@code Sgd} configuration.
 *
 * @param shape not read by this method
 * @return a new {@code SgdUpdater}
 */
@Override
protected GradientUpdater<?> getOldFunction(long[] shape) {
    return new SgdUpdater(new Sgd());
}
 
Example #26
Source File: Sgd.java    From deeplearning4j with Apache License 2.0 4 votes vote down vote up
/**
 * Creates an {@code SgdUpdater} for this configuration and hands it the
 * supplied state map.
 *
 * @param updaterState          state arrays keyed by name, passed to {@code setState}
 * @param initializeStateArrays forwarded unchanged to {@code setState}
 * @return the configured updater
 */
@Override
public GradientUpdater instantiate(Map<String, INDArray> updaterState, boolean initializeStateArrays) {
    final SgdUpdater updater = new SgdUpdater(this);
    updater.setState(updaterState, initializeStateArrays);
    return updater;
}
 
Example #27
Source File: AMSGrad.java    From deeplearning4j with Apache License 2.0 4 votes vote down vote up
/**
 * Creates an {@code AMSGradUpdater} for this configuration and hands it the
 * supplied state map.
 *
 * @param updaterState          state arrays keyed by name, passed to {@code setState}
 * @param initializeStateArrays forwarded unchanged to {@code setState}
 * @return the configured updater
 */
@Override
public GradientUpdater instantiate(Map<String, INDArray> updaterState, boolean initializeStateArrays) {
    final AMSGradUpdater updater = new AMSGradUpdater(this);
    updater.setState(updaterState, initializeStateArrays);
    return updater;
}
 
Example #28
Source File: UpdaterBlock.java    From deeplearning4j with Apache License 2.0 4 votes vote down vote up
/**
 * Returns this block's gradient updater, calling {@code init()} first when
 * the field is still {@code null}.
 *
 * @return the value of {@code gradientUpdater} after any needed initialisation
 */
public GradientUpdater getGradientUpdater() {
    if (gradientUpdater != null) {
        return gradientUpdater;
    }
    // Not created yet: init() is expected to populate the field.
    init();
    return gradientUpdater;
}
 
Example #29
Source File: RmsProp.java    From deeplearning4j with Apache License 2.0 4 votes vote down vote up
/**
 * Creates an {@code RmsPropUpdater} for this configuration and hands it the
 * supplied state map.
 *
 * @param updaterState          state arrays keyed by name, passed to {@code setState}
 * @param initializeStateArrays forwarded unchanged to {@code setState}
 * @return the configured updater
 */
@Override
public GradientUpdater instantiate(Map<String, INDArray> updaterState, boolean initializeStateArrays) {
    final RmsPropUpdater updater = new RmsPropUpdater(this);
    updater.setState(updaterState, initializeStateArrays);
    return updater;
}
 
Example #30
Source File: CustomIUpdater.java    From deeplearning4j with Apache License 2.0 4 votes vote down vote up
/**
 * Not supported: this updater cannot be instantiated from a state map.
 *
 * @param updaterState          not read by this method
 * @param initializeStateArrays not read by this method
 * @return never returns normally
 * @throws UnsupportedOperationException always
 */
@Override
public GradientUpdater instantiate(Map<String, INDArray> updaterState, boolean initializeStateArrays) {
    // Include a message so the failure names the unsupported operation instead of surfacing a bare exception.
    throw new UnsupportedOperationException(
            "Instantiating this updater from an updater state map is not supported");
}