Java Code Examples for org.nd4j.imports.graphmapper.tf.TFGraphMapper#importGraph()

The following examples show how to use org.nd4j.imports.graphmapper.tf.TFGraphMapper#importGraph(). You can vote up the examples you like or vote down the ones you don't, and go to the original project or source file by following the links above each example. You can also check out related API usage in the sidebar.
Example 1
Source File: TensorFlowImportTest.java    From deeplearning4j with Apache License 2.0 6 votes vote down vote up
/**
 * Imports the "simplewhile_1" frozen model, binds arrays to "input_0" and "input_1",
 * and checks that the first graph output is a 2x2 array filled with -1.
 */
@Test
public void testWhileDualMapping1() throws Exception {
    // presumably forces ND4J backend initialisation before the import — TODO confirm
    Nd4j.create(1);

    val modelStream = new ClassPathResource("tf_graphs/examples/simplewhile_1/frozen_model.pb").getInputStream();
    val sameDiff = TFGraphMapper.importGraph(modelStream);
    assertNotNull(sameDiff);

    // Inputs: a 2x2 array filled with -4 and a scalar 1.
    val matrixInput = Nd4j.create(2, 2).assign(-4.0);
    val scalarInput = Nd4j.scalar(1.0);
    sameDiff.associateArrayWithVariable(matrixInput, sameDiff.getVariable("input_0"));
    sameDiff.associateArrayWithVariable(scalarInput, sameDiff.getVariable("input_1"));

    // Debug helpers kept from the original test:
    //sameDiff.asFlatFile(new File("../../../libnd4j/tests_cpu/resources/simplewhile_1.fb"));
    //log.info("{}", sameDiff.asFlatPrint());

    // Run the whole graph, then pick the array registered under the first output name.
    Map<String, INDArray> outputsByName = sameDiff.outputAll(null);
    INDArray actual = outputsByName.get(sameDiff.outputs().get(0));
    val expected = Nd4j.create(2, 2).assign(-1);

    assertNotNull(actual);
    assertEquals(expected, actual);
}
 
Example 2
Source File: TensorFlowImportTest.java    From deeplearning4j with Apache License 2.0 6 votes vote down vote up
/**
 * Imports the "simplewhile_0" frozen model, binds a 2x2 array of ones to "input_0",
 * and checks that the first graph output equals a 2x2 array of ones.
 */
@Test
public void testWhileMapping1() throws Exception {
    // presumably forces ND4J backend initialisation before the import — TODO confirm
    Nd4j.create(1);

    val modelStream = new ClassPathResource("tf_graphs/examples/simplewhile_0/frozen_model.pb").getInputStream();
    val sameDiff = TFGraphMapper.importGraph(modelStream);
    assertNotNull(sameDiff);

    val onesInput = Nd4j.create(2, 2).assign(1);
    sameDiff.associateArrayWithVariable(onesInput, sameDiff.getVariable("input_0"));

    // Debug helpers kept from the original test:
    //sameDiff.asFlatFile(new File("../../../libnd4j/tests_cpu/resources/simplewhile_0_3.fb"));
    //log.info("{}", sameDiff.asFlatPrint());

    // Run the whole graph, then pick the array registered under the first output name.
    val outputsByName = sameDiff.outputAll(null);
    val actual = outputsByName.get(sameDiff.outputs().get(0));
    val expected = Nd4j.create(2, 2).assign(1);

    assertNotNull(actual);
    assertEquals(expected, actual);
}
 
Example 3
Source File: TensorFlowImportTest.java    From deeplearning4j with Apache License 2.0 6 votes vote down vote up
/**
 * Imports the text-format "tensor_array" graph, serialises it to FlatBuffers and checks
 * the variable and node counts of the resulting {@code FlatGraph}.
 *
 * Currently ignored.
 */
@Test
@Ignore
public void testIntermediateTensorArraySimple1() throws Exception {
    Nd4j.create(1);
    val tg = TFGraphMapper.importGraph(new ClassPathResource("tf_graphs/tensor_array.pb.txt").getInputStream());

    // Verify the import result BEFORE dereferencing it; previously the null check ran
    // only after tg had already been used, so it could never fail meaningfully.
    assertNotNull(tg);

    tg.setArrayForVariable("input_matrix", Nd4j.ones(3, 2));

    val fb = tg.asFlatBuffers(true);
    assertNotNull(fb);

    val graph = FlatGraph.getRootAsFlatGraph(fb);
    assertEquals(36, graph.variablesLength());

    assertTrue(graph.nodesLength() > 1);

    // Disabled checks kept from the original test:
    /*   assertEquals("strided_slice", graph.nodes(0).name());
         assertEquals("TensorArray", graph.nodes(1).name());
    */
    //   assertEquals(4, graph.nodes(0).inputPairedLength());

    //tg.asFlatFile(new File("../../../libnd4j/tests_cpu/resources/tensor_array.fb"));
}
 
Example 4
Source File: TensorFlowImportTest.java    From deeplearning4j with Apache License 2.0 6 votes vote down vote up
/**
 * Imports the "simplewhile_1" frozen model, binds a 2x2 array of -9 to "input_0" and a
 * scalar 1 to "input_1", and checks that the first graph output is a 2x2 array of -3.
 */
@Test
public void testWhileDualMapping2() throws Exception {
    // presumably forces ND4J backend initialisation before the import — TODO confirm
    Nd4j.create(1);

    val modelStream = new ClassPathResource("tf_graphs/examples/simplewhile_1/frozen_model.pb").getInputStream();
    val sameDiff = TFGraphMapper.importGraph(modelStream);
    assertNotNull(sameDiff);

    val matrixInput = Nd4j.create(2, 2).assign(-9.0);
    val scalarInput = Nd4j.scalar(1.0);
    sameDiff.associateArrayWithVariable(matrixInput, sameDiff.getVariable("input_0"));
    sameDiff.associateArrayWithVariable(scalarInput, sameDiff.getVariable("input_1"));

    // Debug helpers kept from the original test:
    //sameDiff.asFlatFile(new File("../../../libnd4j/tests_cpu/resources/simplewhile_1.fb"));
    //log.info("{}", sameDiff.asFlatPrint());

    // Run the whole graph, then pick the array registered under the first output name.
    val outputsByName = sameDiff.outputAll(null);
    val actual = outputsByName.get(sameDiff.outputs().get(0));
    val expected = Nd4j.create(2, 2).assign(-3);

    assertNotNull(actual);
    assertEquals(expected, actual);
}
 
Example 5
Source File: TensorFlowImportTest.java    From deeplearning4j with Apache License 2.0 6 votes vote down vote up
/**
 * Imports the "simplewhile_nested" frozen model, binds a 2x2 array of 1.0 to "input_0" and
 * a 3x3 array of 2.0 to "input_1", and checks that the first graph output is a 2x2 array
 * filled with 15.0.
 */
@Test
public void testMixedWhileCond1() throws Exception {
    // presumably forces ND4J backend initialisation before the import — TODO confirm
    Nd4j.create(1);

    val modelStream = new ClassPathResource("tf_graphs/examples/simplewhile_nested/frozen_model.pb").getInputStream();
    val sameDiff = TFGraphMapper.importGraph(modelStream);
    assertNotNull(sameDiff);

    val firstInput = Nd4j.create(2, 2).assign(1.0);
    val secondInput = Nd4j.create(3, 3).assign(2.0);
    sameDiff.associateArrayWithVariable(firstInput, sameDiff.getVariable("input_0"));
    sameDiff.associateArrayWithVariable(secondInput, sameDiff.getVariable("input_1"));

    // Debug helpers kept from the original test:
    //sameDiff.asFlatFile(new File("../../../libnd4j/tests_cpu/resources/simplewhile_nested.fb"));
    //log.info("{}", sameDiff.asFlatPrint());

    // Run the whole graph, then pick the array registered under the first output name.
    // (Disabled alternative from the original: sameDiff.getVariable("output").getArr())
    Map<String,INDArray> outputsByName = sameDiff.outputAll(null);
    val actual = outputsByName.get(sameDiff.outputs().get(0));
    val expected = Nd4j.create(2, 2).assign(15.0);

    assertNotNull(actual);
    assertEquals(expected, actual);
}
 
Example 6
Source File: TensorFlowImportTest.java    From deeplearning4j with Apache License 2.0 6 votes vote down vote up
/**
 * Parses the "lenet_cnn" graph once as a raw {@code GraphDef} (printing its node names) and
 * once through {@code TFGraphMapper}, then checks that the "conv2d/kernel" variable has an
 * array attached and the shape [5, 5, 1, 32].
 */
@Test
public void testLenet() throws Exception {
    /*
     * The model file was produced with:
     * python  ~/anaconda2/lib/python2.7/site-packages/tensorflow/python/tools/freeze_graph.py  --input_graph=graph2.pb.txt  --input_checkpoint=test3.ckpt  --output_graph=graph_frozen2.pb  --output_node_name=output/BiasAdd --input_binary=False
     */

    Nd4j.create(1);

    // Raw protobuf view: only used to print the node names.
    val graphDef = GraphDef.parseFrom(new ClassPathResource("tf_graphs/lenet_cnn.pb").getInputStream());
    val names = graphDef.getNodeList()
            .stream()
            .map(node -> node.getName())
            .collect(Collectors.toList());
    System.out.println(names);

    // Imported view: the resource is opened a second time for the mapper.
    val sameDiff = TFGraphMapper.importGraph(new ClassPathResource("tf_graphs/lenet_cnn.pb").getInputStream());

    val kernel = sameDiff.getVariable("conv2d/kernel");
    assertNotNull(kernel.getArr());

    val kernelShape = kernel.getShape();
    System.out.println(Arrays.toString(kernelShape));

    // this is NHWC weights. will be changed soon.
    assertArrayEquals(new long[]{5,5,1,32}, kernelShape);
    System.out.println(kernel);
}
 
Example 7
Source File: TensorFlowImportTest.java    From deeplearning4j with Apache License 2.0 6 votes vote down vote up
/**
 * Imports the "cond_true" frozen model and checks the control dependencies recorded on
 * four of its variables.
 *
 * Expected control dependencies:
 *   - cond/LinSpace/start depends on cond/switch_t
 *   - cond/LinSpace/stop  depends on cond/switch_t
 *   - cond/LinSpace/num   depends on cond/switch_t
 *   - cond/ones           depends on cond/switch_f
 */
@Test
public void testControlDependencies1() throws Exception {
    SameDiff sd = TFGraphMapper.importGraph(new ClassPathResource("tf_graphs/examples/cond/cond_true/frozen_model.pb").getInputStream());

    Map<String,Variable> variables = sd.getVariables();

    // Compare each variable's control-dependency list (not the Variable object itself)
    // against the expected list. Arguments follow JUnit's (expected, actual) order so
    // that failure messages read correctly.
    assertEquals(Collections.singletonList("cond/switch_t"), variables.get("cond/LinSpace/start").getControlDeps());
    assertEquals(Collections.singletonList("cond/switch_t"), variables.get("cond/LinSpace/stop").getControlDeps());
    assertEquals(Collections.singletonList("cond/switch_t"), variables.get("cond/LinSpace/num").getControlDeps());
    assertEquals(Collections.singletonList("cond/switch_f"), variables.get("cond/ones").getControlDeps());
}
 
Example 8
Source File: TensorFlowImportTest.java    From deeplearning4j with Apache License 2.0 6 votes vote down vote up
/**
 * Imports the text-format "max_add_2" graph and checks that it holds exactly two variables,
 * "zeros" and "ones", whose arrays sum to 0.0 and 12.0 respectively.
 *
 * Currently ignored.
 */
@Test
@Ignore
public void importGraph1() throws Exception {
    final ClassPathResource graphResource = new ClassPathResource("tf_graphs/max_add_2.pb.txt");
    final SameDiff graph = TFGraphMapper.importGraph(graphResource.getInputStream());

    assertNotNull(graph);

    assertEquals(2, graph.variableMap().size());

    final SDVariable zeros = graph.variableMap().get("zeros");
    final SDVariable ones = graph.variableMap().get("ones");

    // Both variables must exist ...
    assertNotNull(zeros);
    assertNotNull(ones);

    // ... and both must have arrays attached.
    assertNotNull(zeros.getArr());
    assertNotNull(ones.getArr());

    final double tolerance = 1e-5;
    assertEquals(0.0, zeros.getArr().sumNumber().doubleValue(), tolerance);
    assertEquals(12.0, ones.getArr().sumNumber().doubleValue(), tolerance);
}
 
Example 9
Source File: TensorFlowImportTest.java    From deeplearning4j with Apache License 2.0 6 votes vote down vote up
/**
 * Imports the "ae_00" frozen model and checks, for every op in the graph, that the op's
 * own name is also the name of a graph variable and equals the name of the op's first
 * output variable.
 */
@Test
public void testImportMapping1() throws Exception {
    Nd4j.create(1);
    val tg = TFGraphMapper.importGraph(new ClassPathResource("tf_graphs/examples/ae_00/frozen_model.pb").getInputStream());

    // Index all graph variables by name for the lookups below.
    val variables = new HashMap<String, SDVariable>();
    for (val variable : tg.variables()) {
        variables.put(variable.name(), variable);
    }

    // The unused "functions" map from the original test was removed: it was allocated
    // but never populated or read.
    for (val func : tg.ops()) {
        val ownName = func.getOwnName();
        String outName = func.outputVariables()[0].name();

        assertTrue("Missing ownName: [" + ownName +"]",variables.containsKey(ownName));
        assertEquals(ownName, outName);
    }
}
 
Example 10
Source File: TensorFlowImportTest.java    From deeplearning4j with Apache License 2.0 5 votes vote down vote up
/**
 * Imports the "simpleif_0" frozen model and writes it out as a flat file under the
 * libnd4j test resources directory (path is relative to the working directory).
 */
@Test
public void testCondMapping1() throws Exception {
    Nd4j.create(1);

    val modelStream = new ClassPathResource("tf_graphs/examples/simpleif_0/frozen_model.pb").getInputStream();
    val sameDiff = TFGraphMapper.importGraph(modelStream);
    assertNotNull(sameDiff);

    final File flatFile = new File("../../../libnd4j/tests_cpu/resources/simpleif_0_1.fb");
    sameDiff.asFlatFile(flatFile);

    // Disabled execution check kept from the original test:
    /*
    //log.info("{}", sameDiff.asFlatPrint());
    val array = sameDiff.execAndEndResult();
    val exp = Nd4j.create(2, 2).assign(-2);
    assertNotNull(array);
    assertEquals(exp, array);*/
}
 
Example 11
Source File: TensorFlowImportTest.java    From deeplearning4j with Apache License 2.0 5 votes vote down vote up
/**
 * Smoke test: importing "train_iris.pb" must yield a non-null graph.
 *
 * Currently ignored.
 */
@Test
@Ignore
public void testImportIris() throws Exception  {
    final ClassPathResource irisModel = new ClassPathResource("tf_graphs/train_iris.pb");
    final SameDiff imported = TFGraphMapper.importGraph(irisModel.getInputStream());

    assertNotNull(imported);
}
 
Example 12
Source File: TensorFlowImportTest.java    From deeplearning4j with Apache License 2.0 5 votes vote down vote up
/**
 * Smoke test: importing the text-format "max_log_reg" graph must yield a non-null graph.
 *
 * Currently ignored.
 */
@Test
@Ignore
public void importGraph3() throws Exception {
    final ClassPathResource logRegModel = new ClassPathResource("tf_graphs/max_log_reg.pb.txt");
    final SameDiff imported = TFGraphMapper.importGraph(logRegModel.getInputStream());

    assertNotNull(imported);
}
 
Example 13
Source File: TensorFlowImportTest.java    From deeplearning4j with Apache License 2.0 5 votes vote down vote up
/**
 * Imports the text-format "reduce_dim_true" graph and writes it out as a flat file under
 * the libnd4j test resources directory, using an executor configuration with
 * {@code OutputMode.IMPLICIT}.
 */
@Test
//@Ignore
public void testCrash_119_reduce_dim_true() throws Exception {
    Nd4j.create(1);

    val graphStream = new ClassPathResource("tf_graphs/reduce_dim_true.pb.txt").getInputStream();
    val sameDiff = TFGraphMapper.importGraph(graphStream);
    assertNotNull(sameDiff);

    final File flatFile = new File("../../../libnd4j/tests_cpu/resources/reduce_dim_true.fb");
    final ExecutorConfiguration configuration = ExecutorConfiguration.builder()
            .outputMode(OutputMode.IMPLICIT)
            .build();

    sameDiff.asFlatFile(flatFile, configuration, true);
}
 
Example 14
Source File: TensorFlowImportTest.java    From deeplearning4j with Apache License 2.0 5 votes vote down vote up
/**
 * Imports an MNIST model from a file, binds a 1x28x28 array of ones to "flatten_1_input"
 * and runs the whole graph.
 *
 * Ignored by default: the model is read from an absolute path on one developer's machine,
 * so the test cannot pass anywhere else. Remove the annotation (and adjust the path) to
 * run it locally.
 */
@Test
@Ignore("Requires a machine-specific model file: C:\\Users\\raver\\Downloads\\mnist.pb")
public void testSingleExample_1() {
    val g = TFGraphMapper.importGraph(new File("C:\\Users\\raver\\Downloads\\mnist.pb"));

    val array = Nd4j.ones(1, 28, 28);
    g.associateArrayWithVariable(array, "flatten_1_input");

    //g.asFlatFile(new File("../../../libnd4j/tests_cpu/resources/mnist.fb"), ExecutorConfiguration.builder().outputMode(OutputMode.VARIABLE_SPACE).build());

    // No assertions: the test only checks that execution completes without throwing.
    g.outputAll(null);
}
 
Example 15
Source File: TensorFlowImportTest.java    From deeplearning4j with Apache License 2.0 5 votes vote down vote up
/**
 * Imports the "expand_dim" frozen model, binds a 3x4 array to "input_0" and writes the
 * graph out as a flat file under the libnd4j test resources directory.
 *
 * Currently ignored.
 */
@Test
@Ignore
public void testCrash_119_expand_dim() throws Exception {
    Nd4j.create(1);

    val modelStream = new ClassPathResource("tf_graphs/examples/expand_dim/frozen_model.pb").getInputStream();
    val sameDiff = TFGraphMapper.importGraph(modelStream);
    assertNotNull(sameDiff);

    // Twelve values laid out as a 3x4 array.
    final double[] values = new double[] {
            0.09753360, 0.76124972, 0.24693797, 0.13813169,
            0.33144656, 0.08299957, 0.67197708, 0.80659380,
            0.98274191, 0.63566073, 0.21592326, 0.54902743};
    final int[] shape = new int[] {3, 4};
    val inputArray = Nd4j.create(values, shape);

    sameDiff.associateArrayWithVariable(inputArray, sameDiff.getVariable("input_0"));

    final File flatFile = new File("../../../libnd4j/tests_cpu/resources/expand_dim.fb");
    sameDiff.asFlatFile(flatFile);
}
 
Example 16
Source File: TensorFlowImportTest.java    From deeplearning4j with Apache License 2.0 5 votes vote down vote up
/**
 * Imports and executes the "reduce_any/rank0" frozen model 1000 times in a row, checking
 * on every iteration that a non-null, non-empty output map is produced.
 */
@Test
public void testBoolImport_1() throws Exception {
    Nd4j.create(1);

    int iteration = 0;
    while (iteration < 1000) {
        // The model is re-imported from the classpath on every pass.
        val modelStream = new ClassPathResource("tf_graphs/examples/reduce_any/rank0/frozen_model.pb").getInputStream();
        val sameDiff = TFGraphMapper.importGraph(modelStream);

        // No placeholders are supplied; all graph outputs are requested.
        Map<String,INDArray> outputs = sameDiff.output(Collections.emptyMap(), sameDiff.outputs());

        assertNotNull(outputs);
        assertTrue(!outputs.isEmpty());

        iteration++;
    }
}
 
Example 17
Source File: TensorFlowImportTest.java    From deeplearning4j with Apache License 2.0 5 votes vote down vote up
/**
 * Imports the text-format "tensor_array_loop" graph, feeds a 5x2 float array holding
 * 1..10 as "input_matrix", and checks that the first graph output equals the 5x2 array
 * {3,6, 9,12, 15,18, 21,24, 27,30}.
 */
@Test
public void testTensorArray_119_4() throws Exception {
    val graphStream = new ClassPathResource("tf_graphs/tensor_array_loop.pb.txt").getInputStream();
    val sameDiff = TFGraphMapper.importGraph(graphStream);
    assertNotNull(sameDiff);

    val inputMatrix = Nd4j.linspace(1, 10, 10, DataType.FLOAT).reshape(5, 2);
    log.info("Graph: {}", sameDiff.asFlatPrint());

    // Feed the single placeholder and request only the first declared output.
    val placeholders = Collections.singletonMap("input_matrix", inputMatrix);
    val outputName = sameDiff.outputs().get(0);
    val actual = sameDiff.outputSingle(placeholders, outputName);

    final float[] expectedValues = new float[] {3,6,  9,12,  15,18,  21,24,  27,30};
    val expected = Nd4j.create(expectedValues, new int[]{5, 2});

    assertEquals(expected, actual);
}
 
Example 18
Source File: TensorFlowImportTest.java    From deeplearning4j with Apache License 2.0 4 votes vote down vote up
/**
 * Importing the "unfrozen_simple_ae" model is expected to fail with
 * {@code ND4JIllegalStateException}; the import result itself is not used.
 */
@Test(expected = ND4JIllegalStateException.class)
public void testNonFrozenGraph1() throws Exception {
    val unfrozenModel = new ClassPathResource("tf_graphs/examples/unfrozen_simple_ae.pb");
    TFGraphMapper.importGraph(unfrozenModel.getInputStream());
}
 
Example 19
Source File: TensorFlowImportTest.java    From deeplearning4j with Apache License 2.0 4 votes vote down vote up
/**
 * Imports the text-format "max_multiply" graph, supplies 10x10 arrays filled with 2.0 and
 * 3.0 for "Placeholder" and "Placeholder_1", and checks that the mean of the first graph
 * output is 6.0.
 *
 * Currently ignored.
 */
@Test
@Ignore
public void importGraph4() throws Exception {
    final ClassPathResource graphResource = new ClassPathResource("tf_graphs/max_multiply.pb.txt");
    final SameDiff graph = TFGraphMapper.importGraph(graphResource.getInputStream());

    assertNotNull(graph);

    val twos = Nd4j.create(10, 10).assign(2.0);
    val threes = Nd4j.create(10, 10).assign(3.0);

    // Attach the arrays to the existing placeholder variables ...
    val firstPlaceholder = graph.variableMap().get("Placeholder");
    graph.associateArrayWithVariable(twos, firstPlaceholder);
    val secondPlaceholder = graph.variableMap().get("Placeholder_1");
    graph.associateArrayWithVariable(threes, secondPlaceholder);

    // ... and additionally register them through var(), exactly as the original test did.
    graph.var("Placeholder", twos);
    graph.var("Placeholder_1", threes);

    // Run the whole graph, then pick the array registered under the first output name.
    val outputsByName = graph.outputAll(null);
    val result = outputsByName.get(graph.outputs().get(0));

    assertEquals(6.0, result.meanNumber().doubleValue(), 1e-5);
}
 
Example 20
Source File: TensorFlowImportTest.java    From deeplearning4j with Apache License 2.0 4 votes vote down vote up
/**
 * Imports the text-format "tensor_array_loop" graph, serialises it to FlatBuffers and
 * checks the variable count plus the paired inputs of the first node in the resulting
 * {@code FlatGraph}.
 *
 * Currently ignored.
 */
@Test
@Ignore
public void testIntermediateTensorArrayLoop1() throws Exception {
    val input = Nd4j.linspace(1, 10, 10, DataType.FLOAT).reshape(5, 2);
    val tg = TFGraphMapper.importGraph(new ClassPathResource("tf_graphs/tensor_array_loop.pb.txt").getInputStream());

    // Verify the import result BEFORE dereferencing it; previously the null check ran
    // only after tg had already been used.
    assertNotNull(tg);
    tg.setArrayForVariable("input_matrix", input);

    val fb = tg.asFlatBuffers(true);
    assertNotNull(fb);

    val graph = FlatGraph.getRootAsFlatGraph(fb);
    assertEquals(12, graph.variablesLength());

    val strided_slice = graph.nodes(0);

    // Disabled name checks kept from the original test:
    /*  assertEquals("strided_slice", strided_slice.name());
        assertEquals("TensorArray", graph.nodes(1).name());
    */
    assertEquals(4, strided_slice.inputPairedLength());

    // We expect the four paired inputs to be 2:0, 3:0, 4:0 and 5:0 respectively,
    // where the first number is a graph node id and :0 is the node output index
    // (0 because those are predefined variables).
    // NOTE(review): the original comment listed ids 1..4 while the assertions checked
    // 2..5; the comment now follows the assertions — confirm which numbering is intended.
    for (int i = 0; i < 4; i++) {
        val paired = strided_slice.inputPaired(i);
        assertEquals(i + 2, paired.first());
        assertEquals(0, paired.second());
    }
}