Java Code Examples for org.deeplearning4j.nn.multilayer.MultiLayerNetwork#getUpdater()
The following examples show how to use org.deeplearning4j.nn.multilayer.MultiLayerNetwork#getUpdater(). This method returns the network's Updater, the object that holds per-parameter optimizer state (momentum, Adam moments, and so on), creating and initializing a new one if the network has none yet. The examples are taken from open-source projects; follow the link above each example to view it in its original project or source file.
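Before the project examples, here is a minimal self-contained sketch (not taken from any of the linked projects; the two-layer configuration and the Adam updater are illustrative assumptions) showing that getUpdater() exposes the optimizer state after a fit() call:

import org.deeplearning4j.nn.api.Updater;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.layers.DenseLayer;
import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.learning.config.Adam;

public class GetUpdaterSketch {
    public static void main(String[] args) {
        MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
                .updater(new Adam(1e-3))    //Adam keeps per-parameter m/v moments in the updater state
                .list()
                .layer(new DenseLayer.Builder().nIn(4).nOut(3).build())
                .layer(new OutputLayer.Builder().nIn(3).nOut(2).activation(Activation.SOFTMAX).build())
                .build();
        MultiLayerNetwork net = new MultiLayerNetwork(conf);
        net.init();
        net.fit(Nd4j.rand(1, 4), Nd4j.create(new float[]{1, 0}, new long[]{1, 2}));

        Updater u = net.getUpdater();              //created and initialized lazily if not already set
        INDArray state = u.getStateViewArray();    //flattened view of all optimizer state
        System.out.println("Updater state length: " + (state == null ? 0 : state.length()));
    }
}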
Example 1
Source File: ParameterAveragingTrainingWorker.java, from deeplearning4j (Apache License 2.0)
@Override
public ParameterAveragingTrainingResult getFinalResult(MultiLayerNetwork network) {
    INDArray updaterState = null;
    if (saveUpdater) {
        Updater u = network.getUpdater();
        if (u != null)
            updaterState = u.getStateViewArray();
    }

    Nd4j.getExecutioner().commit();

    Collection<StorageMetaData> storageMetaData = null;
    Collection<Persistable> listenerStaticInfo = null;
    Collection<Persistable> listenerUpdates = null;
    if (listenerRouterProvider != null) {
        StatsStorageRouter r = listenerRouterProvider.getRouter();
        if (r instanceof VanillaStatsStorageRouter) { //TODO this is ugly... need to find a better solution
            VanillaStatsStorageRouter ssr = (VanillaStatsStorageRouter) r;
            storageMetaData = ssr.getStorageMetaData();
            listenerStaticInfo = ssr.getStaticInfo();
            listenerUpdates = ssr.getUpdates();
        }
    }
    return new ParameterAveragingTrainingResult(network.params(), updaterState, network.score(),
                    storageMetaData, listenerStaticInfo, listenerUpdates);
}
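Restoring a previously captured state array is the inverse operation. As a hedged sketch (savedState is a hypothetical INDArray obtained via getStateViewArray() as above, and net must have the same configuration as the network the state came from), the single-argument setStateViewArray(INDArray) overload on MultiLayerUpdater mirrors what Example 2 below uses:

MultiLayerUpdater updater = (MultiLayerUpdater) net.getUpdater();
updater.setStateViewArray(savedState);    //push the saved optimizer state back into the updater's view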
Example 2
Source File: NetworkUtils.java, from deeplearning4j (Apache License 2.0)
private static void refreshUpdater(MultiLayerNetwork net) {
    INDArray origUpdaterState = net.getUpdater().getStateViewArray();
    MultiLayerUpdater origUpdater = (MultiLayerUpdater) net.getUpdater();
    net.setUpdater(null);
    MultiLayerUpdater newUpdater = (MultiLayerUpdater) net.getUpdater();
    INDArray newUpdaterState = rebuildUpdaterStateArray(origUpdaterState,
                    origUpdater.getUpdaterBlocks(), newUpdater.getUpdaterBlocks());
    newUpdater.setStateViewArray(newUpdaterState);
}
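refreshUpdater(...) is a private helper inside NetworkUtils, invoked when the configuration of an existing network is altered, for example when its learning rate is changed after training has started. A hedged usage sketch (the learning-rate values are illustrative):

import org.deeplearning4j.util.NetworkUtils;

//Changing the learning rate of a trained network; internally this adjusts
//the layer configurations and refreshes the updater, roughly as shown above.
NetworkUtils.setLearningRate(net, 1e-4);        //all layers
NetworkUtils.setLearningRate(net, 0, 1e-5);     //layer 0 only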
Example 3
Source File: TestUpdaters.java, from deeplearning4j (Apache License 2.0)
@Test
public void testDivisionByMinibatch1() {
    //No batch norm - should be single INDArray equal to flattened gradient view
    MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
            .list()
            .layer(new DenseLayer.Builder().nIn(10).nOut(10).build())
            .layer(new DenseLayer.Builder().nIn(10).nOut(10).build())
            .layer(new OutputLayer.Builder().nIn(10).nOut(10).activation(Activation.SOFTMAX).build())
            .build();

    MultiLayerNetwork net = new MultiLayerNetwork(conf);
    net.init();

    net.fit(Nd4j.create(1, 10), Nd4j.create(1, 10));

    BaseMultiLayerUpdater u = (BaseMultiLayerUpdater) net.getUpdater();
    List<INDArray> l = u.getGradientsForMinibatchDivision();
    assertNotNull(l);
    assertEquals(1, l.size());

    INDArray arr = l.get(0);
    assertEquals(3 * (10 * 10 + 10), arr.length());
    assertEquals(net.getFlattenedGradients(), arr);
}
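The expected length 3 * (10 * 10 + 10) = 330 is simply the total parameter count of three 10-in/10-out layers (100 weights plus 10 biases each). A hedged sanity-check sketch, reusing net from the test above:

//Every parameter belongs to the single division-by-minibatch segment here,
//so the segment length equals the network's total parameter count.
long expected = 3 * (10L * 10 + 10);                         // = 330
assertEquals(expected, net.numParams());
assertEquals(expected, net.getFlattenedGradients().length());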
Example 4
Source File: TestUpdaters.java, from deeplearning4j (Apache License 2.0)
@Test
public void testDivisionByMinibatch2() {
    //With batch norm - should be multiple 'division by minibatch' array segments
    //i.e., exclude batch norm mean/variance
    MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
            .list()
            .layer(new DenseLayer.Builder().nIn(10).nOut(9).build())
            .layer(new BatchNormalization.Builder().nOut(9).build())
            .layer(new DenseLayer.Builder().nIn(9).nOut(8).build())
            .layer(new BatchNormalization.Builder().nOut(8).build())
            .layer(new OutputLayer.Builder().nIn(8).nOut(7).activation(Activation.SOFTMAX).build())
            .build();

    MultiLayerNetwork net = new MultiLayerNetwork(conf);
    net.init();

    net.fit(Nd4j.create(1, 10), Nd4j.create(1, 7));

    BaseMultiLayerUpdater u = (BaseMultiLayerUpdater) net.getUpdater();
    List<INDArray> l = u.getGradientsForMinibatchDivision();
    assertNotNull(l);
    assertEquals(3, l.size());    //3 segments

    //First subset: 0_W, 0_b, 1_gamma, 1_beta      Size 10x9 + 9 + 2x9
    //Then excluding 1_mean, 1_var
    //Second subset: 2_W, 2_b, 3_gamma, 3_beta     Size 9x8 + 8 + 2x8
    //Then excluding 3_mean, 3_var
    //Third subset: 4_W, 4_b                       Size 8x7 + 7

    assertEquals(10 * 9 + 9 + 2 * 9, l.get(0).length());
    assertEquals(9 * 8 + 8 + 2 * 8, l.get(1).length());
    assertEquals(8 * 7 + 7, l.get(2).length());

    INDArray view = ((BaseMultiLayerUpdater) net.getUpdater()).getFlattenedGradientsView();
    view.assign(Nd4j.linspace(1, view.length(), view.length(), Nd4j.dataType()));

    INDArray expView1 = view.get(interval(0, 0, true), interval(0, 10 * 9 + 9 + 2 * 9));
    assertEquals(expView1, l.get(0));

    long start2 = (10 * 9 + 9 + 2 * 9) + 2 * 9;
    long length2 = 9 * 8 + 8 + 2 * 8;
    INDArray expView2 = view.get(interval(0, 0, true), interval(start2, start2 + length2));
    assertEquals(expView2, l.get(1));

    long start3 = start2 + length2 + 2 * 8;
    long length3 = 8 * 7 + 7;
    INDArray expView3 = view.get(interval(0, 0, true), interval(start3, start3 + length3));
    assertEquals(expView3, l.get(2));
}
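The offset arithmetic in this test follows directly from the parameter layout: each segment starts where the previous one ended, plus the skipped mean/variance parameters of the intervening batch normalization layer. Spelled out as a sketch (pure arithmetic, matching the assertions above):

long seg1  = 10 * 9 + 9 + 2 * 9;       //0_W, 0_b, 1_gamma, 1_beta   = 117
long skip1 = 2 * 9;                    //1_mean, 1_var (not divided) =  18
long seg2  = 9 * 8 + 8 + 2 * 8;        //2_W, 2_b, 3_gamma, 3_beta   =  96
long skip2 = 2 * 8;                    //3_mean, 3_var (not divided) =  16

long start2 = seg1 + skip1;            // = 135, the test's start2
long start3 = start2 + seg2 + skip2;   // = 247, the test's start3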
Example 5
Source File: TestUpdaters.java, from deeplearning4j (Apache License 2.0)
@Test
public void testDivisionByMinibatch3() throws Exception {
    //With batch norm - should be multiple 'division by minibatch' array segments
    //i.e., exclude batch norm mean/variance
    MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
            .list()
            .layer(new BatchNormalization.Builder().nOut(6).build())
            .layer(new ConvolutionLayer.Builder().nIn(6).nOut(5).kernelSize(2, 2).build())
            .layer(new BatchNormalization.Builder().nOut(5).build())
            .layer(new ConvolutionLayer.Builder().nIn(5).nOut(4).kernelSize(2, 2).build())
            .layer(new BatchNormalization.Builder().nOut(4).build())
            .build();

    MultiLayerNetwork net = new MultiLayerNetwork(conf);
    net.init();

    BaseMultiLayerUpdater u = (BaseMultiLayerUpdater) net.getUpdater();

    Method m = BaseMultiLayerUpdater.class.getDeclaredMethod("divideByMinibatch", boolean.class, Gradient.class, int.class);
    m.setAccessible(true);
    m.invoke(u, false, null, 32);

    List<INDArray> l = u.getGradientsForMinibatchDivision();
    assertNotNull(l);
    assertEquals(3, l.size());    //3 segments

    //First subset: 0_gamma, 0_beta                Size 2x6
    //Then excluding 0_mean, 0_var
    //Second subset: 1_b, 1_W, 2_gamma, 2_beta     Size (6x5x2x2) + 5 + 2x5
    //Then excluding 2_mean, 2_var
    //Third subset: 3_b, 3_W, 4_gamma, 4_beta      Size (5x4x2x2) + 4 + 2x4
    //Then excluding 4_mean, 4_var

    assertEquals(2 * 6, l.get(0).length());
    assertEquals(6 * 5 * 2 * 2 + 5 + 2 * 5, l.get(1).length());
    assertEquals(5 * 4 * 2 * 2 + 4 + 2 * 4, l.get(2).length());

    INDArray view = ((BaseMultiLayerUpdater) net.getUpdater()).getFlattenedGradientsView();
    view.assign(Nd4j.linspace(1, view.length(), view.length(), Nd4j.dataType()));

    INDArray expView1 = view.get(interval(0, 0, true), interval(0, 2 * 6));
    assertEquals(expView1, l.get(0));

    long start2 = 2 * 6 + 2 * 6;
    long length2 = 6 * 5 * 2 * 2 + 5 + 2 * 5;
    INDArray expView2 = view.get(interval(0, 0, true), interval(start2, start2 + length2));
    assertEquals(expView2, l.get(1));

    long start3 = start2 + length2 + 2 * 5;
    long length3 = 5 * 4 * 2 * 2 + 4 + 2 * 4;
    INDArray expView3 = view.get(interval(0, 0, true), interval(start3, start3 + length3));
    assertEquals(expView3, l.get(2));
}