Java Code Examples for org.nd4j.autodiff.samediff.SameDiff#setListeners()
The following examples show how to use org.nd4j.autodiff.samediff.SameDiff#setListeners().
You can vote up the examples you like or vote down the ones you don't, and follow the links above each example to the original project or source file. Related API usage is listed in the sidebar.
Example 1
Source file: CheckpointListenerTest.java, from deeplearning4j (Apache License 2.0) | 5 votes
@Test public void testCheckpointEveryEpoch() throws Exception { File dir = testDir.newFolder(); SameDiff sd = getModel(); CheckpointListener l = CheckpointListener.builder(dir) .saveEveryNEpochs(1) .build(); sd.setListeners(l); DataSetIterator iter = getIter(); sd.fit(iter, 3); File[] files = dir.listFiles(); String s1 = "checkpoint-0_epoch-0_iter-9"; //Note: epoch is 10 iterations, 0-9, 10-19, 20-29, etc String s2 = "checkpoint-1_epoch-1_iter-19"; String s3 = "checkpoint-2_epoch-2_iter-29"; boolean found1 = false; boolean found2 = false; boolean found3 = false; for(File f : files){ String s = f.getAbsolutePath(); if(s.contains(s1)) found1 = true; if(s.contains(s2)) found2 = true; if(s.contains(s3)) found3 = true; } assertEquals(4, files.length); //3 checkpoints and 1 text file (metadata) assertTrue(found1 && found2 && found3); }
Example 2
Source file: CheckpointListenerTest.java, from deeplearning4j (Apache License 2.0) | 5 votes
@Test public void testCheckpointEvery5Iter() throws Exception { File dir = testDir.newFolder(); SameDiff sd = getModel(); CheckpointListener l = CheckpointListener.builder(dir) .saveEveryNIterations(5) .build(); sd.setListeners(l); DataSetIterator iter = getIter(); sd.fit(iter, 2); //2 epochs = 20 iter File[] files = dir.listFiles(); List<String> names = Arrays.asList( "checkpoint-0_epoch-0_iter-4", "checkpoint-1_epoch-0_iter-9", "checkpoint-2_epoch-1_iter-14", "checkpoint-3_epoch-1_iter-19"); boolean[] found = new boolean[names.size()]; for(File f : files){ String s = f.getAbsolutePath(); // System.out.println(s); for( int i=0; i<names.size(); i++ ){ if(s.contains(names.get(i))){ found[i] = true; break; } } } assertEquals(5, files.length); //4 checkpoints and 1 text file (metadata) for( int i=0; i<found.length; i++ ){ assertTrue(names.get(i), found[i]); } }
Example 3
Source file: ExecDebuggingListenerTest.java, from deeplearning4j (Apache License 2.0) | 5 votes
@Test public void testExecDebugListener(){ SameDiff sd = SameDiff.create(); SDVariable in = sd.placeHolder("in", DataType.FLOAT, -1, 3); SDVariable label = sd.placeHolder("label", DataType.FLOAT, 1, 2); SDVariable w = sd.var("w", Nd4j.rand(DataType.FLOAT, 3, 2)); SDVariable b = sd.var("b", Nd4j.rand(DataType.FLOAT, 1, 2)); SDVariable sm = sd.nn.softmax("softmax", in.mmul(w).add(b)); SDVariable loss = sd.loss.logLoss("loss", label, sm); INDArray i = Nd4j.rand(DataType.FLOAT, 1, 3); INDArray l = Nd4j.rand(DataType.FLOAT, 1, 2); sd.setTrainingConfig(TrainingConfig.builder() .dataSetFeatureMapping("in") .dataSetLabelMapping("label") .updater(new Adam(0.001)) .build()); for(ExecDebuggingListener.PrintMode pm : ExecDebuggingListener.PrintMode.values()){ sd.setListeners(new ExecDebuggingListener(pm, -1, true)); // sd.output(m, "softmax"); sd.fit(new DataSet(i, l)); System.out.println("\n\n\n"); } }
Example 4
Source file: UIListenerTest.java, from deeplearning4j (Apache License 2.0) | 5 votes
@Test public void testUIListenerBasic() throws Exception { Nd4j.getRandom().setSeed(12345); IrisDataSetIterator iter = new IrisDataSetIterator(150, 150); SameDiff sd = getSimpleNet(); File dir = testDir.newFolder(); File f = new File(dir, "logFile.bin"); UIListener l = UIListener.builder(f) .plotLosses(1) .trainEvaluationMetrics("softmax", 0, Evaluation.Metric.ACCURACY, Evaluation.Metric.F1) .updateRatios(1) .build(); sd.setListeners(l); sd.setTrainingConfig(TrainingConfig.builder() .dataSetFeatureMapping("in") .dataSetLabelMapping("label") .updater(new Adam(1e-1)) .weightDecay(1e-3, true) .build()); sd.fit(iter, 20); //Test inference after training with UI Listener still around Map<String, INDArray> m = new HashMap<>(); iter.reset(); m.put("in", iter.next().getFeatures()); INDArray out = sd.outputSingle(m, "softmax"); assertNotNull(out); assertArrayEquals(new long[]{150, 3}, out.shape()); }
Example 5
Source file: ProfilingListenerTest.java, from deeplearning4j (Apache License 2.0) | 4 votes
@Test public void testProfilingListenerSimple() throws Exception { SameDiff sd = SameDiff.create(); SDVariable in = sd.placeHolder("in", DataType.FLOAT, -1, 3); SDVariable label = sd.placeHolder("label", DataType.FLOAT, 1, 2); SDVariable w = sd.var("w", Nd4j.rand(DataType.FLOAT, 3, 2)); SDVariable b = sd.var("b", Nd4j.rand(DataType.FLOAT, 1, 2)); SDVariable sm = sd.nn.softmax("predictions", in.mmul("matmul", w).add("addbias", b)); SDVariable loss = sd.loss.logLoss("loss", label, sm); INDArray i = Nd4j.rand(DataType.FLOAT, 1, 3); INDArray l = Nd4j.rand(DataType.FLOAT, 1, 2); File dir = testDir.newFolder(); File f = new File(dir, "test.json"); ProfilingListener listener = ProfilingListener.builder(f) .recordAll() .warmup(5) .build(); sd.setListeners(listener); Map<String,INDArray> ph = new HashMap<>(); ph.put("in", i); for( int x=0; x<10; x++ ) { sd.outputSingle(ph, "predictions"); } String content = FileUtils.readFileToString(f, StandardCharsets.UTF_8); // System.out.println(content); assertFalse(content.isEmpty()); //Should be 2 begins and 2 ends for each entry //5 warmup iterations, 5 profile iterations, x2 for both the op name and the op "instance" name String[] opNames = {"mmul", "add", "softmax"}; for(String s : opNames){ assertEquals(s, 10, StringUtils.countMatches(content, s)); } System.out.println("///////////////////////////////////////////"); ProfileAnalyzer.summarizeProfile(f, ProfileAnalyzer.ProfileFormat.SAMEDIFF); }
Example 6
Source file: CheckpointListenerTest.java, from deeplearning4j (Apache License 2.0) | 4 votes
@Test public void testCheckpointListenerEveryTimeUnit() throws Exception { File dir = testDir.newFolder(); SameDiff sd = getModel(); CheckpointListener l = new CheckpointListener.Builder(dir) .keepLast(2) .saveEvery(4, TimeUnit.SECONDS) .build(); sd.setListeners(l); DataSetIterator iter = getIter(15, 150); for(int i=0; i<5; i++ ){ //10 iterations total sd.fit(iter, 1); Thread.sleep(5000); } //Expect models saved at iterations: 10, 20, 30, 40 //But: keep only 30, 40 File[] files = dir.listFiles(); assertEquals(3, files.length); //2 files, 1 metadata file List<String> names = Arrays.asList( "checkpoint-2_epoch-3_iter-30", "checkpoint-3_epoch-4_iter-40"); boolean[] found = new boolean[names.size()]; for(File f : files){ String s = f.getAbsolutePath(); // System.out.println(s); for( int i=0; i<names.size(); i++ ){ if(s.contains(names.get(i))){ found[i] = true; break; } } } for( int i=0; i<found.length; i++ ){ assertTrue(names.get(i), found[i]); } }
Example 7
Source file: CheckpointListenerTest.java, from deeplearning4j (Apache License 2.0) | 4 votes
@Test public void testCheckpointListenerKeepLast3AndEvery3() throws Exception { File dir = testDir.newFolder(); SameDiff sd = getModel(); CheckpointListener l = new CheckpointListener.Builder(dir) .keepLastAndEvery(3, 3) .saveEveryNEpochs(2) .fileNamePrefix("myFilePrefix") .build(); sd.setListeners(l); DataSetIterator iter = getIter(); sd.fit(iter, 20); //Expect models saved at end of epochs: 1, 3, 5, 7, 9, 11, 13, 15, 17, 19 //But: keep only 5, 11, 15, 17, 19 File[] files = dir.listFiles(); int count = 0; Set<Integer> cpNums = new HashSet<>(); Set<Integer> epochNums = new HashSet<>(); for(File f2 : files){ if(!f2.getPath().endsWith(".bin")){ continue; } count++; int idx = f2.getName().indexOf("epoch-"); int end = f2.getName().indexOf("_", idx); int num = Integer.parseInt(f2.getName().substring(idx + "epoch-".length(), end)); epochNums.add(num); int start = f2.getName().indexOf("checkpoint-"); end = f2.getName().indexOf("_", start + "checkpoint-".length()); int epochNum = Integer.parseInt(f2.getName().substring(start + "checkpoint-".length(), end)); cpNums.add(epochNum); } assertEquals(cpNums.toString(), 5, cpNums.size()); Assert.assertTrue(cpNums.toString(), cpNums.containsAll(Arrays.asList(2, 5, 7, 8, 9))); Assert.assertTrue(epochNums.toString(), epochNums.containsAll(Arrays.asList(5, 11, 15, 17, 19))); assertEquals(5, l.availableCheckpoints().size()); }
Example 8
Source file: ImportModelDebugger.java, from deeplearning4j (Apache License 2.0) | 4 votes
/**
 * Imports a TensorFlow frozen graph, attaches an {@link ImportDebugListener}, and runs the graph
 * once for all of its outputs using placeholders loaded from the root directory.
 *
 * <p>The listener is configured for shape-only checks, eps 1e-5, logging passes, and throwing
 * an exception on failure.
 *
 * @param args optional overrides:
 *             {@code args[0]} is the frozen model file;
 *             {@code args[1]} is the root directory, which defaults to the model file's parent
 *             directory when only {@code args[0]} is given.
 *             With no arguments, the original hard-coded locations are used.
 */
public static void main(String[] args) {
    final String defaultModel = "C:\\Temp\\TF_Graphs\\cifar10_gan_85\\tf_model.pb";
    final String defaultRoot = "C:\\Temp\\TF_Graphs\\cifar10_gan_85";

    File modelFile = new File(args.length > 0 ? args[0] : defaultModel);
    File rootDir;
    if (args.length > 1) {
        rootDir = new File(args[1]);
    } else if (args.length == 1) {
        rootDir = modelFile.getAbsoluteFile().getParentFile();
    } else {
        rootDir = new File(defaultRoot);
    }

    SameDiff sd = TFGraphMapper.importGraph(modelFile);

    ImportDebugListener l = ImportDebugListener.builder(rootDir)
            .checkShapesOnly(true)
            .floatingPointEps(1e-5)
            .onFailure(ImportDebugListener.OnFailure.EXCEPTION)
            .logPass(true)
            .build();
    sd.setListeners(l);

    Map<String, INDArray> ph = loadPlaceholders(rootDir);

    // The returned arrays are not needed here; the attached listener does the checking.
    List<String> outputs = sd.outputs();
    sd.output(ph, outputs);
}