Java Code Examples for org.nd4j.autodiff.samediff.SameDiff#getOutputsForOp()
The following examples show how to use
org.nd4j.autodiff.samediff.SameDiff#getOutputsForOp().
You can vote up the examples you like or vote down the ones you don't like,
and you can go to the original project or source file by following the links above each example. You can also check the related API usage in the sidebar.
Example 1
Source File: TensorMmul.java From deeplearning4j with Apache License 2.0 | 6 votes |
/**
 * Creates a tensor matrix-multiply op over the two given variables.
 *
 * <p>The contraction axes are recorded both in {@code axes} and as integer arguments, appended in
 * this order: the number of axes for the first operand, the first operand's axes, the number of
 * axes for the second operand, and the second operand's axes.
 *
 * @param sameDiff      the graph this op belongs to
 * @param i_v1          first input variable
 * @param i_v2          second input variable
 * @param dimensions    two axis arrays; {@code dimensions[0]} for the first input and
 *                      {@code dimensions[1]} for the second
 * @param mMulTranspose transpose configuration stored on the op
 */
public TensorMmul(SameDiff sameDiff, SDVariable i_v1, SDVariable i_v2, int[][] dimensions, MMulTranspose mMulTranspose) {
    super(null, sameDiff, new SDVariable[]{i_v1, i_v2});
    this.sameDiff = sameDiff;
    this.mMulTranspose = mMulTranspose;
    this.axes = dimensions;

    // The graph is only queried when the flag has not been set yet.
    if (!addedEdges) {
        if (sameDiff.getOutputsForOp(this) == null) {
            addedEdges = true;
        }
    }

    // For each operand, emit its axis count followed by the axes themselves.
    for (int operand = 0; operand < 2; operand++) {
        int[] operandAxes = dimensions[operand];
        addIArgument(operandAxes.length);
        addIArgument(operandAxes);
    }
}
Example 2
Source File: TFGraphTestAllHelper.java From deeplearning4j with Apache License 2.0 | 4 votes |
public static void checkIntermediate(Map<String, INDArray> inputs, String modelName, String baseDir, String modelFileName, ExecuteWith execType, BiFunction<File,String,SameDiff> loader, Double maxRelErrorOverride, Double minAbsErrorOverride, File localTestDir, boolean printArraysDebugging) throws IOException { Preconditions.checkArgument((maxRelErrorOverride == null) == (minAbsErrorOverride == null), "Both maxRelErrorOverride and minAbsErrorOverride" + " must be null or both must be provided"); Nd4j.EPS_THRESHOLD = 1e-3; OpExecOrderListener listener = new OpExecOrderListener(); //Used to collect exec order Pair<SameDiff, Map<String,INDArray>> p = getGraphAfterExec(baseDir, modelFileName, modelName, inputs, execType, loader, Collections.singletonList(listener), null, printArraysDebugging); SameDiff graph = p.getFirst(); Map<String,INDArray> sdPredictions = p.getSecond(); //Collect coverage info about ops OpValidation.collectTensorflowImportCoverage(graph); if (!execType.equals(ExecuteWith.JUST_PRINT)) { int count = 0; //Evaluate the nodes in their execution order - this is useful for debugging (as we want the *first* failure // to be detected before later failures) List<String> varNames = new ArrayList<>(); Map<String,SameDiffOp> fns = graph.getOps(); List<String> execOrder = listener.getOpNamesList(); for(String opName : execOrder){ String[] outputs = graph.getOutputsForOp(fns.get(opName).getOp()); Collections.addAll(varNames, outputs); } for (String varName : varNames) { if (!inputs.containsKey(varName)) { //avoiding placeholders INDArray tfValue = intermediateVars(modelName, baseDir, varName, localTestDir); if (tfValue == null) { continue; } log.info("Starting check: variable {}", varName); if (skipNode(modelName, varName)) { log.info("\n\tFORCING no check on " + varName); } else { assertArrayEquals("Shape not equal on node " + varName, tfValue.shape(), graph.getVariable(varName).getShape()); INDArray sdVal = sdPredictions.get(varName); if(maxRelErrorOverride != 
null){ INDArray diff = Transforms.abs(tfValue.sub(sdVal), false); INDArray absErrorMask = diff.gte(minAbsErrorOverride); //value 1 if x[i] > minAbsError; value 0 otherwise. Used to get rid of 1e-30 vs. 1e-29 type failures INDArray sumAbs = Transforms.abs(tfValue, true).addi(Transforms.abs(sdVal, true)); BooleanIndexing.replaceWhere(sumAbs, 1.0, Conditions.equals(0.0)); //Can only get 0.0 if both are zeros - need to avoid 0/0=NaN INDArray relError = diff.divi(sumAbs); relError.muli(absErrorMask); int countExceeds = Nd4j.getExecutioner().exec(new MatchCondition(relError, Conditions.greaterThan(maxRelErrorOverride))).getInt(0); double maxRE = -1; //Mainly used for analysis in debugger: DifferentialFunction op = null; String[] opInputs = null; if(countExceeds > 0){ maxRE = relError.maxNumber().doubleValue(); //Find the op that this variable is produced by op = graph.getVariableOutputOp(varName); opInputs = graph.getInputsForOp(op); } assertEquals( varName + ": " + countExceeds + " values exceed maxRelError=" + maxRelErrorOverride + " with minAbsError=" + minAbsErrorOverride + "; largest observed relError=" + maxRE, 0, countExceeds); } else { // assertEquals("Value not equal on node " + varName, tfValue, sdVal); if(tfValue.equals(sdVal)){ System.out.println("Pass: " + varName); } else { System.out.println("FAIL: " + varName); System.out.println("TF:\n" + tfValue); System.out.println("SD:\n" + sdVal); } } log.info("Values and shapes equal for {}", varName); count++; } } } assertTrue("No intermediate variables were checked", count > 0); } Nd4j.EPS_THRESHOLD = 1e-5; }