org.nd4j.linalg.learning.config.AMSGrad Java Examples

The following examples show how to use org.nd4j.linalg.learning.config.AMSGrad. Each example notes its original project and source file.
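Before the individual examples, here is a minimal, self-contained sketch of the most common use of AMSGrad: passing it as the updater of a network configuration. The class name, hyperparameter values, and single-layer topology are illustrative assumptions, not taken from any project below.

import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.learning.config.AMSGrad;
import org.nd4j.linalg.lossfunctions.LossFunctions;

public class AMSGradConfigSketch {
    public static void main(String[] args) {
        // AMSGrad(learningRate): beta1, beta2 and epsilon fall back to the class defaults.
        MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
                .seed(12345)
                .updater(new AMSGrad(0.001))
                .list()
                .layer(new OutputLayer.Builder()
                        .nIn(784).nOut(10)
                        .activation(Activation.SOFTMAX)
                        .lossFunction(LossFunctions.LossFunction.MCXENT)
                        .build())
                .build();

        MultiLayerNetwork net = new MultiLayerNetwork(conf);
        net.init();
    }
}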
Example #1
Source File: AMSGradUpdater.java    From nd4j with Apache License 2.0
// Store the AMSGrad configuration (learning rate, betas, epsilon); the updater state is attached separately.
public AMSGradUpdater(AMSGrad config) {
    this.config = config;
}
 
Example #2
Source File: AMSGradUpdater.java    From deeplearning4j with Apache License 2.0
public AMSGradUpdater(AMSGrad config) {
    this.config = config;
}
 
Example #3
Source File: GradientSharingTrainingTest.java    From deeplearning4j with Apache License 2.0
@Test @Ignore
public void testEpochUpdating() throws Exception {
    //Ensure that epoch counter is incremented properly on the workers

    File temp = testDir.newFolder();

    //TODO this probably won't work everywhere...
    String controller = Inet4Address.getLocalHost().getHostAddress();
    String networkMask = controller.substring(0, controller.lastIndexOf('.')) + ".0" + "/16";

    VoidConfiguration voidConfiguration = VoidConfiguration.builder()
            .unicastPort(40123) // Should be open for IN/OUT communications on all Spark nodes
            .networkMask(networkMask) // Local network mask
            .controllerAddress(controller)
            .meshBuildMode(MeshBuildMode.PLAIN) // everyone is connected to the master
            .build();
    SharedTrainingMaster tm = new SharedTrainingMaster.Builder(voidConfiguration, 2, new AdaptiveThresholdAlgorithm(1e-3), 16)
            .rngSeed(12345)
            .collectTrainingStats(false)
            .batchSizePerWorker(16) // Minibatch size for each worker
            .workersPerNode(2) // Workers per node
            .exportDirectory("file:///" + temp.getAbsolutePath().replaceAll("\\\\", "/"))
            .build();


    ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder()
            .seed(12345)
            .updater(new AMSGrad(0.001))
            .graphBuilder()
            .addInputs("in")
            .layer("out", new OutputLayer.Builder().nIn(784).nOut(10).activation(Activation.SOFTMAX)
                    .lossFunction(LossFunctions.LossFunction.MCXENT).build(), "in")
            .setOutputs("out")
            .build();


    SparkComputationGraph sparkNet = new SparkComputationGraph(sc, conf, tm);
    sparkNet.setListeners(new TestListener());

    DataSetIterator iter = new MnistDataSetIterator(16, true, 12345);
    int count = 0;
    List<String> paths = new ArrayList<>();
    List<DataSet> ds = new ArrayList<>();
    File f = testDir.newFolder();
    while (iter.hasNext() && count++ < 8) {
        DataSet d = iter.next();
        File out = new File(f, count + ".bin");
        d.save(out);
        String path = "file:///" + out.getAbsolutePath().replaceAll("\\\\", "/");
        paths.add(path);
        ds.add(d);
    }

    JavaRDD<String> pathRdd = sc.parallelize(paths);
    for (int i = 0; i < 3; i++) {
        ThresholdAlgorithm ta = tm.getThresholdAlgorithm();
        sparkNet.fitPaths(pathRdd);
        //Check also that threshold algorithm was updated/averaged
        ThresholdAlgorithm taAfter = tm.getThresholdAlgorithm();
        assertTrue("Threshold algorithm should have been updated with different instance after averaging", ta != taAfter);
        AdaptiveThresholdAlgorithm ataAfter = (AdaptiveThresholdAlgorithm) taAfter;
        assertFalse(Double.isNaN(ataAfter.getLastSparsity()));
        assertFalse(Double.isNaN(ataAfter.getLastThreshold()));
    }

    Set<Integer> expectedEpochs = new HashSet<>(Arrays.asList(0, 1, 2));
    assertEquals(expectedEpochs, TestListener.epochs);
}
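Example #3 only selects AMSGrad via new AMSGrad(0.001); the arithmetic it configures is performed element-wise by AMSGradUpdater. For reference, here is a scalar, illustration-only sketch of that update rule. The running max over the second-moment estimate is what distinguishes AMSGrad from Adam; nd4j's real implementation also applies Adam-style bias correction and may differ in details from this sketch.

public class AmsGradScalarSketch {
    // Illustrative Adam/AMSGrad-style default hyperparameters (assumptions, not nd4j constants).
    static final double BETA1 = 0.9, BETA2 = 0.999, EPS = 1e-8, LR = 0.001;

    double m = 0.0;     // first-moment estimate
    double v = 0.0;     // second-moment estimate
    double vMax = 0.0;  // running max of the second moment (the AMSGrad twist)

    double step(double param, double grad, int t) {
        m = BETA1 * m + (1 - BETA1) * grad;
        v = BETA2 * v + (1 - BETA2) * grad * grad;
        vMax = Math.max(vMax, v);                        // never let the denominator shrink
        double mHat = m / (1 - Math.pow(BETA1, t + 1));  // bias correction, as in Adam
        return param - LR * mHat / (Math.sqrt(vMax) + EPS);
    }

    public static void main(String[] args) {
        AmsGradScalarSketch opt = new AmsGradScalarSketch();
        double x = 5.0; // minimize f(x) = x^2, so grad = 2x
        for (int t = 0; t < 5000; t++) {
            x = opt.step(x, 2 * x, t);
        }
        // With lr = 0.001 the normalized step is roughly 0.001 per iteration,
        // so after 5000 steps x should have moved most of the way toward 0.
        System.out.println("x after 5000 steps: " + x);
    }
}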