Java Code Examples for org.nd4j.autodiff.samediff.SameDiff#outputSingle()
The following examples show how to use org.nd4j.autodiff.samediff.SameDiff#outputSingle().
You can vote up the ones you like or vote down the ones you don't like,
and go to the original project or source file by following the links above each example. You may check out the related API usage on the sidebar.
Example 1
Source File: TestSameDiffRemoteModel.java From konduit-serving with Apache License 2.0 | 6 votes |
/**
 * Saves a SameDiff model under {@code httpDir}, loads it through a pipeline via the
 * URI returned by {@code uriFor(...)}, and checks that the pipeline's "out" array
 * equals the array produced by calling {@code outputSingle} on the model directly.
 */
@Test
public void testRemote() {
    final String modelPath = "tests/samediff_model.fb";
    final SameDiff model = TestSameDiffServing.getModel();

    // Write the model file below httpDir, creating the "tests" sub-directory first.
    new File(httpDir, "tests").mkdirs();
    final File modelFile = new File(httpDir, modelPath);
    model.save(modelFile, true);
    final String modelUri = uriFor(modelPath);

    // Reference output computed directly on the SameDiff instance.
    final INDArray features = Nd4j.rand(DataType.FLOAT, 3, 784);
    final INDArray expected =
            model.outputSingle(Collections.singletonMap("in", features), "out");

    // Same inference, this time through a single-step pipeline.
    final Pipeline pipeline = SequencePipeline.builder()
            .add(SameDiffStep.builder()
                    .modelUri(modelUri)
                    .outputNames(Collections.singletonList("out"))
                    .build())
            .build();
    final PipelineExecutor executor = pipeline.executor();

    final Data request = Data.singleton("in", NDArray.create(features));
    final INDArray actual = executor.exec(request).getNDArray("out").getAs(INDArray.class);

    assertEquals(expected, actual);
}
Example 2
Source File: TestSameDiffServing.java From konduit-serving with Apache License 2.0 | 5 votes |
/**
 * Checks that a pipeline with one {@code SameDiffStep} reproduces the "out" array that
 * {@code outputSingle} returns on the model directly, both for the original pipeline
 * and for a copy rebuilt from its JSON form.
 */
@Test
public void testSameDiff() throws Exception {
    Nd4j.getRandom().setSeed(12345);

    // Reference output computed directly on the SameDiff instance.
    final INDArray features = Nd4j.rand(DataType.FLOAT, 3, 784);
    final SameDiff model = getModel();
    final INDArray expected =
            model.outputSingle(Collections.singletonMap("in", features), "out");

    // Save the model to a fresh temporary folder.
    final File folder = testDir.newFolder();
    final File modelFile = new File(folder, "samediff.bin");
    model.save(modelFile, false);

    final Pipeline pipeline = SequencePipeline.builder()
            .add(SameDiffStep.builder()
                    .modelUri(modelFile.toURI().toString())
                    .outputNames(Collections.singletonList("out"))
                    .build())
            .build();
    final PipelineExecutor executor = pipeline.executor();

    // Run the pipeline as built.
    final Data request = Data.singleton("in", NDArray.create(features));
    final Data response = executor.exec(request);
    final INDArray actual = response.getNDArray("out").getAs(INDArray.class);
    assertEquals(expected, actual);

    // Run a pipeline rebuilt from the JSON form on the same input.
    final String serialized = pipeline.toJson();
    final Pipeline restored = Pipeline.fromJson(serialized);
    final Data restoredResponse = restored.executor().exec(request);
    final INDArray restoredActual = restoredResponse.getNDArray("out").getAs(INDArray.class);
    assertEquals(expected, restoredActual);
}
Example 3
Source File: UIListenerTest.java From deeplearning4j with Apache License 2.0 | 5 votes |
@Test public void testUIListenerBasic() throws Exception { Nd4j.getRandom().setSeed(12345); IrisDataSetIterator iter = new IrisDataSetIterator(150, 150); SameDiff sd = getSimpleNet(); File dir = testDir.newFolder(); File f = new File(dir, "logFile.bin"); UIListener l = UIListener.builder(f) .plotLosses(1) .trainEvaluationMetrics("softmax", 0, Evaluation.Metric.ACCURACY, Evaluation.Metric.F1) .updateRatios(1) .build(); sd.setListeners(l); sd.setTrainingConfig(TrainingConfig.builder() .dataSetFeatureMapping("in") .dataSetLabelMapping("label") .updater(new Adam(1e-1)) .weightDecay(1e-3, true) .build()); sd.fit(iter, 20); //Test inference after training with UI Listener still around Map<String, INDArray> m = new HashMap<>(); iter.reset(); m.put("in", iter.next().getFeatures()); INDArray out = sd.outputSingle(m, "softmax"); assertNotNull(out); assertArrayEquals(new long[]{150, 3}, out.shape()); }
Example 4
Source File: ProfilingListenerTest.java From deeplearning4j with Apache License 2.0 | 4 votes |
@Test public void testProfilingListenerSimple() throws Exception { SameDiff sd = SameDiff.create(); SDVariable in = sd.placeHolder("in", DataType.FLOAT, -1, 3); SDVariable label = sd.placeHolder("label", DataType.FLOAT, 1, 2); SDVariable w = sd.var("w", Nd4j.rand(DataType.FLOAT, 3, 2)); SDVariable b = sd.var("b", Nd4j.rand(DataType.FLOAT, 1, 2)); SDVariable sm = sd.nn.softmax("predictions", in.mmul("matmul", w).add("addbias", b)); SDVariable loss = sd.loss.logLoss("loss", label, sm); INDArray i = Nd4j.rand(DataType.FLOAT, 1, 3); INDArray l = Nd4j.rand(DataType.FLOAT, 1, 2); File dir = testDir.newFolder(); File f = new File(dir, "test.json"); ProfilingListener listener = ProfilingListener.builder(f) .recordAll() .warmup(5) .build(); sd.setListeners(listener); Map<String,INDArray> ph = new HashMap<>(); ph.put("in", i); for( int x=0; x<10; x++ ) { sd.outputSingle(ph, "predictions"); } String content = FileUtils.readFileToString(f, StandardCharsets.UTF_8); // System.out.println(content); assertFalse(content.isEmpty()); //Should be 2 begins and 2 ends for each entry //5 warmup iterations, 5 profile iterations, x2 for both the op name and the op "instance" name String[] opNames = {"mmul", "add", "softmax"}; for(String s : opNames){ assertEquals(s, 10, StringUtils.countMatches(content, s)); } System.out.println("///////////////////////////////////////////"); ProfileAnalyzer.summarizeProfile(f, ProfileAnalyzer.ProfileFormat.SAMEDIFF); }
Example 5
Source File: TensorFlowImportTest.java From deeplearning4j with Apache License 2.0 | 4 votes |
/**
 * Imports the frozen "bias_add" TensorFlow graph from the classpath, runs it on a
 * 10x4 input, and checks that the output equals the input plus the row vector built
 * by {@code Nd4j.linspace(1, 4, 4)}.
 *
 * @throws IOException if the model resource cannot be opened or closed
 */
@Test
public void testInferShape() throws IOException {
    /*
     * Text form of the imported graph:
     *
     * node {
     *   name: "input"
     *   op: "Placeholder"
     *   attr { key: "dtype" value { type: DT_FLOAT } }
     *   attr { key: "shape" value { shape { dim { size: -1 } dim { size: 4 } } } }
     * }
     * node {
     *   name: "bias"
     *   op: "Const"
     *   attr { key: "dtype" value { type: DT_FLOAT } }
     *   attr {
     *     key: "value"
     *     value {
     *       tensor {
     *         dtype: DT_FLOAT
     *         tensor_shape { dim { size: 4 } }
     *         tensor_content: "\000\000\200?\000\000\000@\000\000@@\000\000\200@"
     *       }
     *     }
     *   }
     * }
     * node {
     *   name: "bias/read"
     *   op: "Identity"
     *   input: "bias"
     *   attr { key: "_class" value { list { s: "loc:@bias" } } }
     *   attr { key: "T" value { type: DT_FLOAT } }
     * }
     * node {
     *   name: "output"
     *   op: "BiasAdd"
     *   input: "input"
     *   input: "bias/read"
     *   attr { key: "data_format" value { s: "NHWC" } }
     *   attr { key: "T" value { type: DT_FLOAT } }
     * }
     * library { }
     */

    // Close the resource stream once the graph is imported instead of leaking it.
    final SameDiff graph;
    try (java.io.InputStream modelStream =
            new ClassPathResource("tf_graphs/examples/bias_add/frozen_model.pb").getInputStream()) {
        graph = TFGraphMapper.importGraph(modelStream);
    }
    assertNotNull(graph);

    INDArray input = Nd4j.linspace(1, 40, 40, DataType.FLOAT).reshape(10, 4);
    INDArray expectedOutput = Nd4j.linspace(1, 40, 40, DataType.FLOAT)
            .reshape(10, 4)
            .addRowVector(Nd4j.linspace(1, 4, 4, DataType.FLOAT));

    INDArray actual = graph.outputSingle(
            Collections.singletonMap("input", input), graph.outputs().get(0));

    // The array stored on the "input" variable after inference must equal what was fed in.
    assertEquals(input, graph.getVariable("input").getArr());
    assertEquals(expectedOutput, actual);
}