Java Code Examples for org.tensorflow.framework.NodeDef#getInput()
The following examples show how to use
org.tensorflow.framework.NodeDef#getInput().
You can vote up the ones you like or vote down the ones you don't like,
and go to the original project or source file by following the links above each example. You may check out the related API usage on the sidebar.
Example 1
Source File: TensorArrayV3.java From nd4j with Apache License 2.0 | 6 votes |
/**
 * Initializes this op from a TensorFlow node.
 *
 * <p>The last input of {@code nodeDef} names a graph node; element 0 of that node's
 * {@code "value"} tensor is added as an integer argument when the tensor can be resolved.
 *
 * @param nodeDef           the TensorFlow node being imported
 * @param initWith          the SameDiff instance being built (unused here)
 * @param attributesForNode attributes of the node (unused here)
 * @param graph             the enclosing graph, searched for the referenced node
 */
@Override
public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map<String, AttrValue> attributesForNode, GraphDef graph) {
    val idd = nodeDef.getInput(nodeDef.getInputCount() - 1);

    NodeDef iddNode = null;
    for (int i = 0; i < graph.getNodeCount(); i++) {
        NodeDef candidate = graph.getNode(i);
        if (candidate.getName().equals(idd)) {
            iddNode = candidate;
            // Node names are unique within a TensorFlow GraphDef, so there is nothing left to find.
            break;
        }
    }

    // NOTE(review): iddNode stays null when no node matches; confirm getNDArrayFromTensor tolerates null.
    val arr = TFGraphMapper.getInstance().getNDArrayFromTensor("value", iddNode, graph);
    if (arr != null) {
        int idx = arr.getInt(0);
        addIArgument(idx);
    }
}
Example 2
Source File: StridedSlice.java From deeplearning4j with Apache License 2.0 | 6 votes |
@Override public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map<String, AttrValue> attributesForNode, GraphDef graph) { val inputBegin = nodeDef.getInput(1); val inputEnd = nodeDef.getInput(2); val inputStrides = nodeDef.getInput(3); // bit masks for this slice val bm = nodeDef.getAttrOrThrow("begin_mask"); val xm = nodeDef.getAttrOrThrow("ellipsis_mask"); val em = nodeDef.getAttrOrThrow("end_mask"); val nm = nodeDef.getAttrOrThrow("new_axis_mask"); val sm = nodeDef.getAttrOrThrow("shrink_axis_mask"); beginMask = (int)bm.getI(); ellipsisMask = (int) xm.getI(); endMask = (int) em.getI(); newAxisMask = (int) nm.getI(); shrinkAxisMask = (int) sm.getI(); addIArgument(beginMask); addIArgument(ellipsisMask); addIArgument(endMask); addIArgument(newAxisMask); addIArgument(shrinkAxisMask); }
Example 3
Source File: TensorArray.java From deeplearning4j with Apache License 2.0 | 6 votes |
/**
 * Initializes this tensor-array op from a TensorFlow node.
 *
 * <p>The last input of {@code nodeDef} names a graph node; element 0 of its tensor is added as
 * an integer argument when the tensor can be resolved. The array's data type is taken from the
 * node's {@code "dtype"} attribute.
 *
 * @param nodeDef           the TensorFlow node being imported
 * @param initWith          the SameDiff instance being built (unused here)
 * @param attributesForNode attributes of the node; must contain {@code "dtype"}
 * @param graph             the enclosing graph, searched for the referenced node
 * @throws NullPointerException if {@code attributesForNode} has no {@code "dtype"} entry
 */
@Override
public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map<String, AttrValue> attributesForNode, GraphDef graph) {
    val idd = nodeDef.getInput(nodeDef.getInputCount() - 1);

    NodeDef iddNode = null;
    for (int i = 0; i < graph.getNodeCount(); i++) {
        NodeDef candidate = graph.getNode(i);
        if (candidate.getName().equals(idd)) {
            iddNode = candidate;
            // Node names are unique within a TensorFlow GraphDef, so there is nothing left to find.
            break;
        }
    }

    // NOTE(review): iddNode stays null when no node matches; confirm getNDArrayFromTensor tolerates null.
    val arr = TFGraphMapper.getNDArrayFromTensor(iddNode);
    if (arr != null) {
        int idx = arr.getInt(0);
        addIArgument(idx);
    }

    // Same exception type as before (NPE), but with a message naming the offending node.
    AttrValue dtypeAttr = java.util.Objects.requireNonNull(attributesForNode.get("dtype"),
            "TensorArray node '" + nodeDef.getName() + "' has no \"dtype\" attribute");
    this.tensorArrayDataType = TFGraphMapper.convertType(dtypeAttr.getType());
}
Example 4
Source File: Pow.java From nd4j with Apache License 2.0 | 5 votes |
@Override public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map<String, AttrValue> attributesForNode, GraphDef graph) { val weightsName = nodeDef.getInput(1); val variable = initWith.getVariable(weightsName); val tmp = initWith.getArrForVarName(weightsName); // if second argument is scalar - we should provide array of same shape if (tmp != null) { if (tmp.isScalar()) { this.pow = tmp.getDouble(0); } } }
Example 5
Source File: BaseTensorOp.java From nd4j with Apache License 2.0 | 5 votes |
/**
 * Initializes this tensor op from a TensorFlow node.
 *
 * <p>Input 1 names a graph node; if an array can be obtained from that node, its element 0 is
 * added as an integer argument.
 *
 * @param nodeDef           the TensorFlow node being imported
 * @param initWith          the SameDiff instance being built
 * @param attributesForNode attributes of the node (unused here)
 * @param graph             the enclosing graph, searched for the node named by input 1
 */
@Override
public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map<String, AttrValue> attributesForNode, GraphDef graph) {
    String indexInput = nodeDef.getInput(1);
    // Result unused; the lookup is retained so the call sequence matches the previous implementation.
    initWith.getVariable(indexInput);

    val indexNode = TFGraphMapper.getInstance().getNodeWithNameFromGraph(graph, indexInput);
    val indexArr = TFGraphMapper.getInstance().getArrayFrom(indexNode, graph);
    if (indexArr == null) {
        return;
    }

    int index = indexArr.getInt(0);
    addIArgument(index);
}
Example 6
Source File: BaseTensorOp.java From deeplearning4j with Apache License 2.0 | 5 votes |
/**
 * Initializes this tensor op from a TensorFlow node.
 *
 * <p>Input 1 names a graph node; if an array can be obtained from that node, its element 0 is
 * added as an integer argument.
 *
 * @param nodeDef           the TensorFlow node being imported
 * @param initWith          the SameDiff instance being built
 * @param attributesForNode attributes of the node (unused here)
 * @param graph             the enclosing graph, searched for the node named by input 1
 */
@Override
public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map<String, AttrValue> attributesForNode, GraphDef graph) {
    String indexInput = nodeDef.getInput(1);
    // Result unused; the lookup is retained so the call sequence matches the previous implementation.
    initWith.getVariable(indexInput);

    val indexNode = TFGraphMapper.getNodeWithNameFromGraph(graph, indexInput);
    val indexArr = TFGraphMapper.getArrayFrom(indexNode, graph);
    if (indexArr == null) {
        return;
    }

    int index = indexArr.getInt(0);
    addIArgument(index);
}
Example 7
Source File: SavedModel.java From jpmml-tensorflow with GNU Affero General Public License v3.0 | 4 votes |
private void initializeTables(){ Collection<String> tableInitializerNames = Collections.emptyList(); try { CollectionDef collectionDef = getCollectionDef("table_initializer"); CollectionDef.NodeList nodeList = collectionDef.getNodeList(); tableInitializerNames = nodeList.getValueList(); } catch(IllegalArgumentException iae){ // Ignored } for(String tableInitializerName : tableInitializerNames){ NodeDef tableInitializer = getNodeDef(tableInitializerName); String name = tableInitializer.getInput(0); List<?> keys; List<?> values; try(Tensor tensor = run(tableInitializer.getInput(1))){ keys = TensorUtil.getValues(tensor); } // End try try(Tensor tensor = run(tableInitializer.getInput(2))){ values = TensorUtil.getValues(tensor); } Map<Object, Object> table = new LinkedHashMap<>(); if(keys.size() != values.size()){ throw new IllegalArgumentException(); } for(int i = 0; i < keys.size(); i++){ table.put(keys.get(i), values.get(i)); } putTable(name, table); } }
Example 8
Source File: Concat.java From nd4j with Apache License 2.0 | 4 votes |
@Override public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map<String, AttrValue> attributesForNode, GraphDef graph) { int concatDimension = -1; String input = null; for(int i = 0; i < nodeDef.getInputCount(); i++) { if(nodeDef.getInput(i).contains("/concat_dim")) { input = nodeDef.getInput(i); break; } } //older versions may specify a concat_dim, usually it's the last argument if(input == null) { input = nodeDef.getInput(nodeDef.getInputCount() - 1); } val variable = initWith.getVariable(input); // concat dimension is only possible if (variable != null && variable.getArr() == null) { sameDiff.addPropertyToResolve(this, input); } else if (variable != null) { val arr = variable.getArr(); if (arr.length() == 1) { concatDimension = arr.getInt(0); } this.concatDimension = concatDimension; addIArgument(this.concatDimension); log.debug("Concat dimension: {}", concatDimension); } //don't pass both iArg and last axis down to libnd4j if(inputArguments().length == nodeDef.getInputCount()) { val inputArgs = inputArguments(); removeInputArgument(inputArgs[inputArguments().length - 1]); } sameDiff.removeArgFromFunction(input,this); }
Example 9
Source File: StridedSlice.java From nd4j with Apache License 2.0 | 4 votes |
@Override public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map<String, AttrValue> attributesForNode, GraphDef graph) { val inputBegin = nodeDef.getInput(1); val inputEnd = nodeDef.getInput(2); val inputStrides = nodeDef.getInput(3); NodeDef beginNode = null; NodeDef endNode = null; NodeDef strides = null; for(int i = 0; i < graph.getNodeCount(); i++) { if(graph.getNode(i).getName().equals(inputBegin)) { beginNode = graph.getNode(i); } if(graph.getNode(i).getName().equals(inputEnd)) { endNode = graph.getNode(i); } if(graph.getNode(i).getName().equals(inputStrides)) { strides = graph.getNode(i); } } // bit masks for this slice val bm = nodeDef.getAttrOrThrow("begin_mask"); val xm = nodeDef.getAttrOrThrow("ellipsis_mask"); val em = nodeDef.getAttrOrThrow("end_mask"); val nm = nodeDef.getAttrOrThrow("new_axis_mask"); val sm = nodeDef.getAttrOrThrow("shrink_axis_mask"); addIArgument((int) bm.getI()); addIArgument((int) xm.getI()); addIArgument((int) em.getI()); addIArgument((int) nm.getI()); addIArgument((int) sm.getI()); val beginArr = TFGraphMapper.getInstance().getNDArrayFromTensor("value",beginNode,graph); val endArr = TFGraphMapper.getInstance().getNDArrayFromTensor("value",endNode,graph); val stridesArr = TFGraphMapper.getInstance().getNDArrayFromTensor("value",strides,graph); if (beginArr != null && endArr != null && stridesArr != null) { for (int e = 0; e < beginArr.length(); e++) addIArgument(beginArr.getInt(e)); for (int e = 0; e < endArr.length(); e++) addIArgument(endArr.getInt(e)); for (int e = 0; e < stridesArr.length(); e++) addIArgument(stridesArr.getInt(e)); } else { // do nothing } }
Example 10
Source File: Slice.java From nd4j with Apache License 2.0 | 4 votes |
@Override public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map<String, AttrValue> attributesForNode, GraphDef graph) { /* strided slice typically takes 4 tensor arguments: 0) input, it's shape determines number of elements in other arguments 1) begin indices 2) end indices 3) strides */ val inputBegin = nodeDef.getInput(1); val inputEnd = nodeDef.getInput(2); NodeDef beginNode = null; NodeDef endNode = null; for(int i = 0; i < graph.getNodeCount(); i++) { if(graph.getNode(i).getName().equals(inputBegin)) { beginNode = graph.getNode(i); } if(graph.getNode(i).getName().equals(inputEnd)) { endNode = graph.getNode(i); } } val beginArr = TFGraphMapper.getInstance().getNDArrayFromTensor("value",beginNode,graph); val endArr = TFGraphMapper.getInstance().getNDArrayFromTensor("value",endNode,graph); if (beginArr != null && endArr != null) { for (int e = 0; e < beginArr.length(); e++) addIArgument(beginArr.getInt(e)); for (int e = 0; e < endArr.length(); e++) addIArgument(endArr.getInt(e)); } else { // do nothing } }