org.nd4j.linalg.learning.GradientUpdater Java Examples
The following examples show how to use org.nd4j.linalg.learning.GradientUpdater.
You can vote up the ones you like or vote down the ones you don't like, and go to the original project or source file by following the links above each example. You can also check out related API usage in the sidebar.
Example #1
Source File: LearnerTestCase.java From jstarcraft-ai with Apache License 2.0 | 6 votes |
/**
 * Runs the reference ND4J {@link GradientUpdater} and the new {@link Learner} side by side for
 * ten iterations on a 5x2 matrix built from {@code Nd4j.linspace(-2.5, 2.0, 10)}, asserting
 * after every iteration that {@code gradient} and {@code array} still match.
 *
 * <p>NOTE(review): this relies on both updaters modifying their inputs in place and on
 * {@code getMatrix} producing an independent copy of {@code array} — confirm against the helpers.
 */
@Test
public void testGradient() throws Exception {
    EnvironmentContext context = EnvironmentFactory.getContext();
    Future<?> task = context.doTask(() -> {
        long[] shape = { 5L, 2L };
        INDArray array = Nd4j.linspace(-2.5D, 2.0D, 10).reshape(shape);
        GradientUpdater<?> oldFunction = getOldFunction(shape);
        DenseMatrix gradient = getMatrix(array);
        Map<String, MathMatrix> gradients = new HashMap<>();
        gradients.put("gradients", gradient);
        Learner newFunction = getNewFunction(shape);
        newFunction.doCache(gradients);
        for (int iteration = 0; iteration < 10; iteration++) {
            oldFunction.applyUpdater(array, iteration, 0);
            newFunction.learn(gradients, iteration, 0);
            // Only build the (potentially large) diagnostic string when the comparison fails.
            if (!equalMatrix(gradient, array)) {
                Assert.fail("Learner diverged from reference updater at iteration " + iteration
                        + ": expected " + array + " but was " + gradients);
            }
        }
    });
    // Propagates any assertion failure raised inside the task.
    task.get();
}
Example #2
Source File: AdaDeltaLearnerTestCase.java From jstarcraft-ai with Apache License 2.0 | 5 votes |
/**
 * Builds the reference ND4J {@link AdaDeltaUpdater} with a zero-filled state view.
 *
 * <p>The view holds {@code shape[0] * stateSize(shape[1])} elements. It is attached with 'c'
 * ordering and the initialize flag set to {@code true}.
 *
 * @param shape gradient shape, indexed here as {@code {rows, columns}}
 * @return the configured updater
 * @throws ArithmeticException if the state size overflows a {@code long} or exceeds {@code int}
 */
@Override
protected GradientUpdater<?> getOldFunction(long[] shape) {
    AdaDelta configuration = new AdaDelta();
    GradientUpdater<?> oldFunction = new AdaDeltaUpdater(configuration);
    // Exact arithmetic instead of an (int) cast: an oversized state must fail, not truncate.
    int length = Math.toIntExact(Math.multiplyExact(shape[0], configuration.stateSize(shape[1])));
    INDArray view = Nd4j.zeros(length);
    oldFunction.setStateViewArray(view, shape, 'c', true);
    return oldFunction;
}
Example #3
Source File: AdaMax.java From deeplearning4j with Apache License 2.0 | 5 votes |
/**
 * Creates an {@link AdaMaxUpdater} whose state lives in {@code viewArray}.
 *
 * <p>The gradient shape given to the updater is the view shape with its second dimension
 * divided by 2.
 *
 * @param viewArray state view; must be non-null, rank >= 2, second dimension divisible by 2
 * @param initializeViewArray forwarded to {@code setStateViewArray}
 * @return the new updater
 * @throws IllegalStateException if {@code viewArray} is null or has an unsuitable shape
 */
@Override
public GradientUpdater instantiate(INDArray viewArray, boolean initializeViewArray) {
    if (viewArray == null) {
        throw new IllegalStateException("A view array is required for AdaMax updater");
    }
    long[] viewShape = viewArray.shape();
    if (viewShape.length < 2 || viewShape[1] % 2 != 0) {
        throw new IllegalStateException("Invalid view array shape for AdaMax updater: "
                + Arrays.toString(viewShape) + " (expected rank >= 2 with second dimension divisible by 2)");
    }
    // Copy before dividing: shape() may return the array's own shape buffer.
    long[] gradientShape = Arrays.copyOf(viewShape, viewShape.length);
    gradientShape[1] /= 2;
    AdaMaxUpdater u = new AdaMaxUpdater(this);
    u.setStateViewArray(viewArray, gradientShape, viewArray.ordering(), initializeViewArray);
    return u;
}
Example #4
Source File: AdaDelta.java From deeplearning4j with Apache License 2.0 | 5 votes |
/**
 * Creates an {@link AdaDeltaUpdater} whose state lives in {@code viewArray}.
 *
 * <p>The gradient shape given to the updater is the view shape with its second dimension
 * divided by 2.
 *
 * @param viewArray state view; must be non-null, rank >= 2, second dimension divisible by 2
 * @param initializeViewArray forwarded to {@code setStateViewArray}
 * @return the new updater
 * @throws IllegalStateException if {@code viewArray} is null or has an unsuitable shape
 */
@Override
public GradientUpdater instantiate(INDArray viewArray, boolean initializeViewArray) {
    if (viewArray == null) {
        throw new IllegalStateException("A view array is required for AdaDelta updater");
    }
    long[] viewShape = viewArray.shape();
    if (viewShape.length < 2 || viewShape[1] % 2 != 0) {
        throw new IllegalStateException("Invalid view array shape for AdaDelta updater: "
                + Arrays.toString(viewShape) + " (expected rank >= 2 with second dimension divisible by 2)");
    }
    // Copy before dividing: shape() may return the array's own shape buffer.
    long[] gradientShape = Arrays.copyOf(viewShape, viewShape.length);
    gradientShape[1] /= 2;
    AdaDeltaUpdater u = new AdaDeltaUpdater(this);
    u.setStateViewArray(viewArray, gradientShape, viewArray.ordering(), initializeViewArray);
    return u;
}
Example #5
Source File: Sgd.java From deeplearning4j with Apache License 2.0 | 5 votes |
/**
 * Creates an {@link SgdUpdater}. No state view may be supplied for this updater.
 *
 * @param viewArray must be {@code null}
 * @param initializeViewArray not read by this method
 * @return a new {@link SgdUpdater} bound to this configuration
 * @throws IllegalStateException if a non-null {@code viewArray} is passed
 */
@Override
public GradientUpdater instantiate(INDArray viewArray, boolean initializeViewArray) {
    final boolean viewSupplied = viewArray != null;
    if (!viewSupplied) {
        return new SgdUpdater(this);
    }
    throw new IllegalStateException("View arrays are not supported/required for SGD updater");
}
Example #6
Source File: Adam.java From deeplearning4j with Apache License 2.0 | 5 votes |
/**
 * Creates an {@link AdamUpdater} whose state lives in {@code viewArray}.
 *
 * <p>The gradient shape given to the updater is the view shape with its second dimension
 * divided by 2.
 *
 * @param viewArray state view; must be non-null, rank >= 2, second dimension divisible by 2
 * @param initializeViewArray forwarded to {@code setStateViewArray}
 * @return the new updater
 * @throws IllegalStateException if {@code viewArray} is null or has an unsuitable shape
 */
@Override
public GradientUpdater instantiate(INDArray viewArray, boolean initializeViewArray) {
    if (viewArray == null) {
        throw new IllegalStateException("A view array is required for Adam updater");
    }
    long[] viewShape = viewArray.shape();
    if (viewShape.length < 2 || viewShape[1] % 2 != 0) {
        throw new IllegalStateException("Invalid view array shape for Adam updater: "
                + Arrays.toString(viewShape) + " (expected rank >= 2 with second dimension divisible by 2)");
    }
    // Copy before dividing: shape() may return the array's own shape buffer.
    long[] gradientShape = Arrays.copyOf(viewShape, viewShape.length);
    gradientShape[1] /= 2;
    AdamUpdater u = new AdamUpdater(this);
    u.setStateViewArray(viewArray, gradientShape, viewArray.ordering(), initializeViewArray);
    return u;
}
Example #7
Source File: NoOp.java From deeplearning4j with Apache License 2.0 | 5 votes |
/**
 * Creates a {@link NoOpUpdater}. This updater has no state, so a view array is rejected.
 *
 * @param viewArray must be {@code null}
 * @param initializeViewArray not read by this method
 * @return a new {@link NoOpUpdater} bound to this configuration
 * @throws IllegalStateException if {@code viewArray} is non-null
 */
@Override
public GradientUpdater instantiate(INDArray viewArray, boolean initializeViewArray) {
    if (viewArray == null) {
        return new NoOpUpdater(this);
    }
    throw new IllegalStateException("Cannot use view array with NoOp updater");
}
Example #8
Source File: Nadam.java From deeplearning4j with Apache License 2.0 | 5 votes |
/**
 * Creates a {@link NadamUpdater} whose state lives in {@code viewArray}.
 *
 * <p>The gradient shape given to the updater is the view shape with its second dimension
 * divided by 2.
 *
 * @param viewArray state view; must be non-null, rank >= 2, second dimension divisible by 2
 * @param initializeViewArray forwarded to {@code setStateViewArray}
 * @return the new updater
 * @throws IllegalStateException if {@code viewArray} is null or has an unsuitable shape
 */
@Override
public GradientUpdater instantiate(INDArray viewArray, boolean initializeViewArray) {
    if (viewArray == null) {
        throw new IllegalStateException("A view array is required for Nadam updater");
    }
    long[] viewShape = viewArray.shape();
    if (viewShape.length < 2 || viewShape[1] % 2 != 0) {
        throw new IllegalStateException("Invalid view array shape for Nadam updater: "
                + Arrays.toString(viewShape) + " (expected rank >= 2 with second dimension divisible by 2)");
    }
    // Copy before dividing: shape() may return the array's own shape buffer.
    long[] gradientShape = Arrays.copyOf(viewShape, viewShape.length);
    gradientShape[1] /= 2;
    NadamUpdater u = new NadamUpdater(this);
    u.setStateViewArray(viewArray, gradientShape, viewArray.ordering(), initializeViewArray);
    return u;
}
Example #9
Source File: AMSGrad.java From nd4j with Apache License 2.0 | 5 votes |
/**
 * Creates an {@link AMSGradUpdater} whose state lives in {@code viewArray}.
 *
 * <p>The gradient shape given to the updater is the view shape with its second dimension
 * divided by 3.
 *
 * @param viewArray state view; must be non-null, rank >= 2, second dimension divisible by 3
 * @param initializeViewArray forwarded to {@code setStateViewArray}
 * @return the new updater
 * @throws IllegalStateException if {@code viewArray} is null or has an unsuitable shape
 */
@Override
public GradientUpdater instantiate(INDArray viewArray, boolean initializeViewArray) {
    if (viewArray == null) {
        throw new IllegalStateException("A view array is required for AMSGrad updater");
    }
    long[] viewShape = viewArray.shape();
    if (viewShape.length < 2 || viewShape[1] % 3 != 0) {
        throw new IllegalStateException("Invalid view array shape for AMSGrad updater: "
                + Arrays.toString(viewShape) + " (expected rank >= 2 with second dimension divisible by 3)");
    }
    // Copy before dividing: shape() may return the array's own shape buffer.
    long[] gradientShape = Arrays.copyOf(viewShape, viewShape.length);
    gradientShape[1] /= 3;
    AMSGradUpdater u = new AMSGradUpdater(this);
    u.setStateViewArray(viewArray, gradientShape, viewArray.ordering(), initializeViewArray);
    return u;
}
Example #10
Source File: Nadam.java From nd4j with Apache License 2.0 | 5 votes |
/**
 * Creates a {@link NadamUpdater} whose state lives in {@code viewArray}.
 *
 * <p>The gradient shape given to the updater is the view shape with its second dimension
 * divided by 2.
 *
 * @param viewArray state view; must be non-null, rank >= 2, second dimension divisible by 2
 * @param initializeViewArray forwarded to {@code setStateViewArray}
 * @return the new updater
 * @throws IllegalStateException if {@code viewArray} is null or has an unsuitable shape
 */
@Override
public GradientUpdater instantiate(INDArray viewArray, boolean initializeViewArray) {
    if (viewArray == null) {
        throw new IllegalStateException("A view array is required for Nadam updater");
    }
    long[] viewShape = viewArray.shape();
    if (viewShape.length < 2 || viewShape[1] % 2 != 0) {
        throw new IllegalStateException("Invalid view array shape for Nadam updater: "
                + Arrays.toString(viewShape) + " (expected rank >= 2 with second dimension divisible by 2)");
    }
    // Copy before dividing: shape() may return the array's own shape buffer.
    long[] gradientShape = Arrays.copyOf(viewShape, viewShape.length);
    gradientShape[1] /= 2;
    NadamUpdater u = new NadamUpdater(this);
    u.setStateViewArray(viewArray, gradientShape, viewArray.ordering(), initializeViewArray);
    return u;
}
Example #11
Source File: Sgd.java From nd4j with Apache License 2.0 | 5 votes |
/**
 * Creates an {@link SgdUpdater}. No state view may be supplied for this updater.
 *
 * @param viewArray must be {@code null}
 * @param initializeViewArray not read by this method
 * @return a new {@link SgdUpdater} bound to this configuration
 * @throws IllegalStateException if a non-null {@code viewArray} is passed
 */
@Override
public GradientUpdater instantiate(INDArray viewArray, boolean initializeViewArray) {
    final boolean viewSupplied = viewArray != null;
    if (!viewSupplied) {
        return new SgdUpdater(this);
    }
    throw new IllegalStateException("View arrays are not supported/required for SGD updater");
}
Example #12
Source File: AMSGrad.java From deeplearning4j with Apache License 2.0 | 5 votes |
/**
 * Creates an {@link AMSGradUpdater} whose state lives in {@code viewArray}.
 *
 * <p>The gradient shape given to the updater is the view shape with its second dimension
 * divided by 3.
 *
 * @param viewArray state view; must be non-null, rank >= 2, second dimension divisible by 3
 * @param initializeViewArray forwarded to {@code setStateViewArray}
 * @return the new updater
 * @throws IllegalStateException if {@code viewArray} is null or has an unsuitable shape
 */
@Override
public GradientUpdater instantiate(INDArray viewArray, boolean initializeViewArray) {
    if (viewArray == null) {
        throw new IllegalStateException("A view array is required for AMSGrad updater");
    }
    long[] viewShape = viewArray.shape();
    if (viewShape.length < 2 || viewShape[1] % 3 != 0) {
        throw new IllegalStateException("Invalid view array shape for AMSGrad updater: "
                + Arrays.toString(viewShape) + " (expected rank >= 2 with second dimension divisible by 3)");
    }
    // Copy before dividing: shape() may return the array's own shape buffer.
    long[] gradientShape = Arrays.copyOf(viewShape, viewShape.length);
    gradientShape[1] /= 3;
    AMSGradUpdater u = new AMSGradUpdater(this);
    u.setStateViewArray(viewArray, gradientShape, viewArray.ordering(), initializeViewArray);
    return u;
}
Example #13
Source File: AdaMax.java From nd4j with Apache License 2.0 | 5 votes |
/**
 * Creates an {@link AdaMaxUpdater} whose state lives in {@code viewArray}.
 *
 * <p>The gradient shape given to the updater is the view shape with its second dimension
 * divided by 2.
 *
 * @param viewArray state view; must be non-null, rank >= 2, second dimension divisible by 2
 * @param initializeViewArray forwarded to {@code setStateViewArray}
 * @return the new updater
 * @throws IllegalStateException if {@code viewArray} is null or has an unsuitable shape
 */
@Override
public GradientUpdater instantiate(INDArray viewArray, boolean initializeViewArray) {
    if (viewArray == null) {
        throw new IllegalStateException("A view array is required for AdaMax updater");
    }
    long[] viewShape = viewArray.shape();
    if (viewShape.length < 2 || viewShape[1] % 2 != 0) {
        throw new IllegalStateException("Invalid view array shape for AdaMax updater: "
                + Arrays.toString(viewShape) + " (expected rank >= 2 with second dimension divisible by 2)");
    }
    // Copy before dividing: shape() may return the array's own shape buffer.
    long[] gradientShape = Arrays.copyOf(viewShape, viewShape.length);
    gradientShape[1] /= 2;
    AdaMaxUpdater u = new AdaMaxUpdater(this);
    u.setStateViewArray(viewArray, gradientShape, viewArray.ordering(), initializeViewArray);
    return u;
}
Example #14
Source File: Adam.java From nd4j with Apache License 2.0 | 5 votes |
/**
 * Creates an {@link AdamUpdater} whose state lives in {@code viewArray}.
 *
 * <p>The gradient shape given to the updater is the view shape with its second dimension
 * divided by 2.
 *
 * @param viewArray state view; must be non-null, rank >= 2, second dimension divisible by 2
 * @param initializeViewArray forwarded to {@code setStateViewArray}
 * @return the new updater
 * @throws IllegalStateException if {@code viewArray} is null or has an unsuitable shape
 */
@Override
public GradientUpdater instantiate(INDArray viewArray, boolean initializeViewArray) {
    if (viewArray == null) {
        throw new IllegalStateException("A view array is required for Adam updater");
    }
    long[] viewShape = viewArray.shape();
    if (viewShape.length < 2 || viewShape[1] % 2 != 0) {
        throw new IllegalStateException("Invalid view array shape for Adam updater: "
                + Arrays.toString(viewShape) + " (expected rank >= 2 with second dimension divisible by 2)");
    }
    // Copy before dividing: shape() may return the array's own shape buffer.
    long[] gradientShape = Arrays.copyOf(viewShape, viewShape.length);
    gradientShape[1] /= 2;
    AdamUpdater u = new AdamUpdater(this);
    u.setStateViewArray(viewArray, gradientShape, viewArray.ordering(), initializeViewArray);
    return u;
}
Example #15
Source File: NadamLearnerTestCase.java From jstarcraft-ai with Apache License 2.0 | 5 votes |
/**
 * Builds the reference ND4J {@link NadamUpdater} with a zero-filled state view.
 *
 * <p>The view holds {@code shape[0] * stateSize(shape[1])} elements. It is attached with 'c'
 * ordering and the initialize flag set to {@code true}.
 *
 * @param shape gradient shape, indexed here as {@code {rows, columns}}
 * @return the configured updater
 * @throws ArithmeticException if the state size overflows a {@code long} or exceeds {@code int}
 */
@Override
protected GradientUpdater<?> getOldFunction(long[] shape) {
    Nadam configuration = new Nadam();
    GradientUpdater<?> oldFunction = new NadamUpdater(configuration);
    // Exact arithmetic instead of an (int) cast: an oversized state must fail, not truncate.
    int length = Math.toIntExact(Math.multiplyExact(shape[0], configuration.stateSize(shape[1])));
    INDArray view = Nd4j.zeros(length);
    oldFunction.setStateViewArray(view, shape, 'c', true);
    return oldFunction;
}
Example #16
Source File: AdaMaxLearnerTestCase.java From jstarcraft-ai with Apache License 2.0 | 5 votes |
/**
 * Builds the reference ND4J {@link AdaMaxUpdater} with a zero-filled state view.
 *
 * <p>The view holds {@code shape[0] * stateSize(shape[1])} elements. It is attached with 'c'
 * ordering and the initialize flag set to {@code true}.
 *
 * @param shape gradient shape, indexed here as {@code {rows, columns}}
 * @return the configured updater
 * @throws ArithmeticException if the state size overflows a {@code long} or exceeds {@code int}
 */
@Override
protected GradientUpdater<?> getOldFunction(long[] shape) {
    AdaMax configuration = new AdaMax();
    GradientUpdater<?> oldFunction = new AdaMaxUpdater(configuration);
    // Exact arithmetic instead of an (int) cast: an oversized state must fail, not truncate.
    int length = Math.toIntExact(Math.multiplyExact(shape[0], configuration.stateSize(shape[1])));
    INDArray view = Nd4j.zeros(length);
    oldFunction.setStateViewArray(view, shape, 'c', true);
    return oldFunction;
}
Example #17
Source File: NesterovLearnerTestCase.java From jstarcraft-ai with Apache License 2.0 | 5 votes |
/**
 * Builds the reference ND4J {@link NesterovsUpdater} with a zero-filled state view.
 *
 * <p>The view holds {@code shape[0] * stateSize(shape[1])} elements. It is attached with 'c'
 * ordering and the initialize flag set to {@code true}.
 *
 * @param shape gradient shape, indexed here as {@code {rows, columns}}
 * @return the configured updater
 * @throws ArithmeticException if the state size overflows a {@code long} or exceeds {@code int}
 */
@Override
protected GradientUpdater<?> getOldFunction(long[] shape) {
    Nesterovs configuration = new Nesterovs();
    GradientUpdater<?> oldFunction = new NesterovsUpdater(configuration);
    // Exact arithmetic instead of an (int) cast: an oversized state must fail, not truncate.
    int length = Math.toIntExact(Math.multiplyExact(shape[0], configuration.stateSize(shape[1])));
    INDArray view = Nd4j.zeros(length);
    oldFunction.setStateViewArray(view, shape, 'c', true);
    return oldFunction;
}
Example #18
Source File: RmsPropLearnerTestCase.java From jstarcraft-ai with Apache License 2.0 | 5 votes |
/**
 * Builds the reference ND4J {@link RmsPropUpdater} with a zero-filled state view.
 *
 * <p>The view holds {@code shape[0] * stateSize(shape[1])} elements. It is attached with 'c'
 * ordering and the initialize flag set to {@code true}.
 *
 * @param shape gradient shape, indexed here as {@code {rows, columns}}
 * @return the configured updater
 * @throws ArithmeticException if the state size overflows a {@code long} or exceeds {@code int}
 */
@Override
protected GradientUpdater<?> getOldFunction(long[] shape) {
    RmsProp configuration = new RmsProp();
    GradientUpdater<?> oldFunction = new RmsPropUpdater(configuration);
    // Exact arithmetic instead of an (int) cast: an oversized state must fail, not truncate.
    int length = Math.toIntExact(Math.multiplyExact(shape[0], configuration.stateSize(shape[1])));
    INDArray view = Nd4j.zeros(length);
    oldFunction.setStateViewArray(view, shape, 'c', true);
    return oldFunction;
}
Example #19
Source File: CustomIUpdater.java From deeplearning4j with Apache License 2.0 | 5 votes |
/**
 * Creates a {@link CustomGradientUpdater}. No state view may be supplied for this updater.
 *
 * @param viewArray must be {@code null}
 * @param initializeViewArray not read by this method
 * @return a new {@link CustomGradientUpdater} bound to this configuration
 * @throws IllegalStateException if a non-null {@code viewArray} is passed
 */
@Override
public GradientUpdater instantiate(INDArray viewArray, boolean initializeViewArray) {
    if (viewArray != null) {
        // The message previously named the SGD updater (copy-paste from Sgd); name this class instead.
        throw new IllegalStateException("View arrays are not supported/required for CustomIUpdater");
    }
    return new CustomGradientUpdater(this);
}
Example #20
Source File: AdaGradLearnerTestCase.java From jstarcraft-ai with Apache License 2.0 | 5 votes |
/**
 * Builds the reference ND4J {@link AdaGradUpdater} with a zero-filled state view.
 *
 * <p>The view holds {@code shape[0] * stateSize(shape[1])} elements. It is attached with 'c'
 * ordering and the initialize flag set to {@code true}.
 *
 * @param shape gradient shape, indexed here as {@code {rows, columns}}
 * @return the configured updater
 * @throws ArithmeticException if the state size overflows a {@code long} or exceeds {@code int}
 */
@Override
protected GradientUpdater<?> getOldFunction(long[] shape) {
    AdaGrad configuration = new AdaGrad();
    GradientUpdater<?> oldFunction = new AdaGradUpdater(configuration);
    // Exact arithmetic instead of an (int) cast: an oversized state must fail, not truncate.
    int length = Math.toIntExact(Math.multiplyExact(shape[0], configuration.stateSize(shape[1])));
    INDArray view = Nd4j.zeros(length);
    oldFunction.setStateViewArray(view, shape, 'c', true);
    return oldFunction;
}
Example #21
Source File: AdamLearnerTestCase.java From jstarcraft-ai with Apache License 2.0 | 5 votes |
/**
 * Builds the reference ND4J {@link AdamUpdater} with a zero-filled state view.
 *
 * <p>The view holds {@code shape[0] * stateSize(shape[1])} elements. It is attached with 'c'
 * ordering and the initialize flag set to {@code true}.
 *
 * @param shape gradient shape, indexed here as {@code {rows, columns}}
 * @return the configured updater
 * @throws ArithmeticException if the state size overflows a {@code long} or exceeds {@code int}
 */
@Override
protected GradientUpdater<?> getOldFunction(long[] shape) {
    Adam configuration = new Adam();
    GradientUpdater<?> oldFunction = new AdamUpdater(configuration);
    // Exact arithmetic instead of an (int) cast: an oversized state must fail, not truncate.
    int length = Math.toIntExact(Math.multiplyExact(shape[0], configuration.stateSize(shape[1])));
    INDArray view = Nd4j.zeros(length);
    oldFunction.setStateViewArray(view, shape, 'c', true);
    return oldFunction;
}
Example #22
Source File: NoOp.java From nd4j with Apache License 2.0 | 5 votes |
/**
 * Creates a {@link NoOpUpdater}. This updater has no state, so a view array is rejected.
 *
 * @param viewArray must be {@code null}
 * @param initializeViewArray not read by this method
 * @return a new {@link NoOpUpdater} bound to this configuration
 * @throws IllegalStateException if {@code viewArray} is non-null
 */
@Override
public GradientUpdater instantiate(INDArray viewArray, boolean initializeViewArray) {
    if (viewArray == null) {
        return new NoOpUpdater(this);
    }
    throw new IllegalStateException("Cannot use view array with NoOp updater");
}
Example #23
Source File: AdaDelta.java From nd4j with Apache License 2.0 | 5 votes |
/**
 * Creates an {@link AdaDeltaUpdater} whose state lives in {@code viewArray}.
 *
 * <p>The gradient shape given to the updater is the view shape with its second dimension
 * divided by 2.
 *
 * @param viewArray state view; must be non-null, rank >= 2, second dimension divisible by 2
 * @param initializeViewArray forwarded to {@code setStateViewArray}
 * @return the new updater
 * @throws IllegalStateException if {@code viewArray} is null or has an unsuitable shape
 */
@Override
public GradientUpdater instantiate(INDArray viewArray, boolean initializeViewArray) {
    if (viewArray == null) {
        throw new IllegalStateException("A view array is required for AdaDelta updater");
    }
    long[] viewShape = viewArray.shape();
    if (viewShape.length < 2 || viewShape[1] % 2 != 0) {
        throw new IllegalStateException("Invalid view array shape for AdaDelta updater: "
                + Arrays.toString(viewShape) + " (expected rank >= 2 with second dimension divisible by 2)");
    }
    // Copy before dividing: shape() may return the array's own shape buffer.
    long[] gradientShape = Arrays.copyOf(viewShape, viewShape.length);
    gradientShape[1] /= 2;
    AdaDeltaUpdater u = new AdaDeltaUpdater(this);
    u.setStateViewArray(viewArray, gradientShape, viewArray.ordering(), initializeViewArray);
    return u;
}
Example #24
Source File: Nadam.java From deeplearning4j with Apache License 2.0 | 4 votes |
/**
 * Creates a {@link NadamUpdater} and hands it the supplied state map.
 *
 * @param updaterState state arrays forwarded to {@code setState}
 * @param initializeStateArrays forwarded to {@code setState}
 * @return the new updater
 */
@Override
public GradientUpdater instantiate(Map<String, INDArray> updaterState, boolean initializeStateArrays) {
    final NadamUpdater updater = new NadamUpdater(this);
    updater.setState(updaterState, initializeStateArrays);
    return updater;
}
Example #25
Source File: SgdLearnerTestCase.java From jstarcraft-ai with Apache License 2.0 | 4 votes |
/**
 * Builds the reference ND4J {@link SgdUpdater} with a default {@link Sgd} configuration.
 *
 * <p>Unlike the stateful updaters, no state view array is allocated here, so {@code shape}
 * is not read.
 *
 * @param shape gradient shape (unused)
 * @return the configured updater
 */
@Override
protected GradientUpdater<?> getOldFunction(long[] shape) {
    return new SgdUpdater(new Sgd());
}
Example #26
Source File: Sgd.java From deeplearning4j with Apache License 2.0 | 4 votes |
/**
 * Creates an {@link SgdUpdater} and hands it the supplied state map.
 *
 * @param updaterState state arrays forwarded to {@code setState}
 * @param initializeStateArrays forwarded to {@code setState}
 * @return the new updater
 */
@Override
public GradientUpdater instantiate(Map<String, INDArray> updaterState, boolean initializeStateArrays) {
    final SgdUpdater updater = new SgdUpdater(this);
    updater.setState(updaterState, initializeStateArrays);
    return updater;
}
Example #27
Source File: AMSGrad.java From deeplearning4j with Apache License 2.0 | 4 votes |
/**
 * Creates an {@link AMSGradUpdater} and hands it the supplied state map.
 *
 * @param updaterState state arrays forwarded to {@code setState}
 * @param initializeStateArrays forwarded to {@code setState}
 * @return the new updater
 */
@Override
public GradientUpdater instantiate(Map<String, INDArray> updaterState, boolean initializeStateArrays) {
    final AMSGradUpdater updater = new AMSGradUpdater(this);
    updater.setState(updaterState, initializeStateArrays);
    return updater;
}
Example #28
Source File: UpdaterBlock.java From deeplearning4j with Apache License 2.0 | 4 votes |
/**
 * Returns the gradient updater for this block, calling {@code init()} first if the
 * {@code gradientUpdater} field has not been set yet.
 *
 * <p>NOTE(review): the lazy initialization is unsynchronized; confirm this block is only
 * accessed from a single thread.
 *
 * @return the value of {@code gradientUpdater} after any lazy initialization
 */
public GradientUpdater getGradientUpdater() {
    if (gradientUpdater != null) {
        return gradientUpdater;
    }
    init();
    return gradientUpdater;
}
Example #29
Source File: RmsProp.java From deeplearning4j with Apache License 2.0 | 4 votes |
/**
 * Creates an {@link RmsPropUpdater} and hands it the supplied state map.
 *
 * @param updaterState state arrays forwarded to {@code setState}
 * @param initializeStateArrays forwarded to {@code setState}
 * @return the new updater
 */
@Override
public GradientUpdater instantiate(Map<String, INDArray> updaterState, boolean initializeStateArrays) {
    final RmsPropUpdater updater = new RmsPropUpdater(this);
    updater.setState(updaterState, initializeStateArrays);
    return updater;
}
Example #30
Source File: CustomIUpdater.java From deeplearning4j with Apache License 2.0 | 4 votes |
/**
 * Map-based state instantiation is not implemented for this custom updater.
 *
 * @param updaterState not read
 * @param initializeStateArrays not read
 * @return never returns normally
 * @throws UnsupportedOperationException always
 */
@Override
public GradientUpdater instantiate(Map<String, INDArray> updaterState, boolean initializeStateArrays) {
    // Same exception type as before, now with a message so failures are self-explanatory.
    throw new UnsupportedOperationException("CustomIUpdater does not support Map-based updater state");
}