Java Code Examples for org.nd4j.linalg.factory.Nd4j#argMax()
The following examples show how to use
org.nd4j.linalg.factory.Nd4j#argMax().
You can vote up the ones you like or vote down the ones you don't like,
and go to the original project or source file by following the links above each example. You may check out the related API usage on the sidebar.
Example 1
Source File: CudaIndexReduceTests.java From nd4j with Apache License 2.0 | 6 votes |
/**
 * Checks, row by row, that argMax along dimension 1 of a seeded random 10x2 'f'-ordered matrix
 * picks the column holding the larger of the row's two values.
 */
@Test
public void testIMaxF1() throws Exception {
    Nd4j.getRandom().setSeed(12345);
    INDArray matrix = Nd4j.rand('f', 10, 2);
    for (int rowIdx = 0; rowIdx < 10; rowIdx++) {
        INDArray currentRow = matrix.getRow(rowIdx);
        int expected = currentRow.getDouble(0) > currentRow.getDouble(1) ? 0 : 1;
        INDArray actualArr = Nd4j.argMax(currentRow, 1);
        double actual = actualArr.getDouble(0);
        assertEquals(expected, (int) actual);
        System.out.println(currentRow);
        System.out.println(actualArr);
        System.out.println("exp: " + expected + ", act: " + actual);
    }
}
Example 2
Source File: CudaIndexReduceTests.java From nd4j with Apache License 2.0 | 6 votes |
/**
 * Verifies argMax along each of the three dimensions of a 4x3x2 linspace tensor. The values
 * increase monotonically, so the maximum along any dimension sits at that dimension's last
 * index (3, 2 and 1 respectively).
 */
@Test
public void testIMaxDimensional() throws Exception {
    INDArray source = Nd4j.linspace(1, 24, 24).reshape(4, 3, 2);

    INDArray expectedAlongDim1 = Nd4j.valueArrayOf(new int[]{4, 2}, 2.0);
    INDArray expectedAlongDim0 = Nd4j.valueArrayOf(new int[]{3, 2}, 3.0);
    INDArray expectedAlongDim2 = Nd4j.valueArrayOf(new int[]{4, 3}, 1.0);

    assertEquals(expectedAlongDim1, Nd4j.argMax(source, 1));
    assertEquals(expectedAlongDim0, Nd4j.argMax(source, 0));
    assertEquals(expectedAlongDim2, Nd4j.argMax(source, 2));
}
Example 3
Source File: ArgmaxAdapter.java From deeplearning4j with Apache License 2.0 | 6 votes |
/**
 * Converts the single model output into class indices: for a 2D output, the position of the
 * highest value in each row; otherwise, the position of the highest value in the whole array.
 * For example, {0.25, 0.1, 0.5, 0.15} yields a one-element array {2}.
 *
 * @param outputs model outputs; exactly one array of rank below 3 is accepted
 * @return one index per row for a 2D output, or a single index otherwise
 */
@Override
public int[] apply(INDArray... outputs) {
    Preconditions.checkArgument(outputs.length == 1, "Argmax adapter can have only 1 output");
    INDArray output = outputs[0];
    Preconditions.checkArgument(output.rank() < 3, "Argmax adapter requires 2D or 1D output");

    if (output.rank() != 2) {
        // Integer.MAX_VALUE requests an argMax over the whole array rather than one dimension
        return new int[] {(int) Nd4j.argMax(output, Integer.MAX_VALUE).getDouble(0)};
    }

    int[] indices = new int[(int) output.size(0)];
    INDArray perRow = Nd4j.argMax(output, 1);
    for (int k = 0; k < perRow.length(); k++) {
        indices[k] = (int) perRow.getDouble(k);
    }
    return indices;
}
Example 4
Source File: EndlessTests.java From nd4j with Apache License 2.0 | 5 votes |
/**
 * Soak test: runs argMax along dimension 0 of a 100x100 array of ones RUN_LIMIT times,
 * discarding each result.
 */
@Test
public void testIndexAccumForeverAlongDimension() {
    INDArray ones = Nd4j.ones(100, 100);
    for (long run = 0; run < RUN_LIMIT; run++) {
        Nd4j.argMax(ones, 0);
    }
}
Example 5
Source File: EndlessTests.java From nd4j with Apache License 2.0 | 5 votes |
/**
 * Repeats an argMax reduction along dimension 0 of a 100x100 array of ones, RUN_LIMIT times;
 * the results are intentionally ignored.
 */
@Test
public void testIndexAccumForeverAlongDimension() {
    final INDArray input = Nd4j.ones(100, 100);
    final int reduceDim = 0;
    for (int iteration = 0; iteration < RUN_LIMIT; ++iteration) {
        Nd4j.argMax(input, reduceDim);
    }
}
Example 6
Source File: CudaIndexReduceTests.java From nd4j with Apache License 2.0 | 5 votes |
/**
 * Times an argMax over dimensions 0 and 1 of a 128x1000 linspace matrix. It then checks that the
 * reported index is 127999, the last element, where the largest value lives.
 */
@Test
public void testIMax4() {
    INDArray matrix = Nd4j.linspace(1, 1000, 128000).reshape(128, 1000);

    long start = System.currentTimeMillis();
    INDArray result = Nd4j.argMax(matrix, 0, 1);
    long elapsed = System.currentTimeMillis() - start;

    System.out.println("Execution time: " + elapsed);
    assertEquals(127999f, result.getFloat(0), 0.001f);
}
Example 7
Source File: CudaIndexReduceTests.java From nd4j with Apache License 2.0 | 5 votes |
/**
 * Times a per-row argMax (dimension 1) of a 128x1000 linspace matrix. It checks that every row
 * reports column 999, since values increase along each row.
 */
@Test
public void testIMax2() {
    INDArray matrix = Nd4j.linspace(1, 1000, 128000).reshape(128, 1000);

    long before = System.currentTimeMillis();
    INDArray perRowMax = Nd4j.argMax(matrix, 1);
    long after = System.currentTimeMillis();
    System.out.println("Execution time: " + (after - before));

    final int rows = 128;
    for (int r = 0; r < rows; r++) {
        assertEquals(999f, perRowMax.getFloat(r), 0.0001f);
    }
}
Example 8
Source File: EvaluationCalibrationTest.java From deeplearning4j with Apache License 2.0 | 5 votes |
/**
 * Builds 50 random 3-class predictions (each row divided by its sum) with random one-hot labels
 * and feeds them to EvaluationCalibration. It then checks the per-class label counts and the
 * per-class predicted counts (argMax along dimension 1) against counts computed here.
 */
@Test
public void testLabelAndPredictionCounts() {
    final int minibatch = 50;
    final int nClasses = 3;

    INDArray predictions = Nd4j.rand(minibatch, nClasses);
    predictions.diviColumnVector(predictions.sum(1));

    INDArray labels = Nd4j.zeros(minibatch, nClasses);
    Random rng = new Random(12345);
    for (int row = 0; row < minibatch; row++) {
        labels.putScalar(row, rng.nextInt(nClasses), 1.0);
    }

    EvaluationCalibration calibration = new EvaluationCalibration(5, 5);
    calibration.eval(labels, predictions);

    int[] expectedLabelCounts = labels.sum(0).data().asInt();
    int[] expectedPredictionCounts = new int[(int) labels.size(1)];
    INDArray predictedClasses = Nd4j.argMax(predictions, 1);
    for (int k = 0; k < predictedClasses.length(); k++) {
        expectedPredictionCounts[predictedClasses.getInt(k)] += 1;
    }

    assertArrayEquals(expectedLabelCounts, calibration.getLabelCountsEachClass());
    assertArrayEquals(expectedPredictionCounts, calibration.getPredictionCountsEachClass());
}
Example 9
Source File: ClassifierOutputAdapter.java From konduit-serving with Apache License 2.0 | 5 votes |
/**
 * Wraps a model output as a ClassifierOutput. The decisions are the argMax indices along
 * dimension -1 (presumably the last axis, per ND4J's negative-index convention). The raw values
 * are passed through as probabilities: a single row for a vector, otherwise the whole matrix.
 */
@Override
public ClassifierOutput adapt(INDArray array, RoutingContext routingContext) {
    INDArray decisionIndices = Nd4j.argMax(array, -1);
    double[][] probabilities;
    if (array.isVector()) {
        probabilities = new double[][]{array.toDoubleVector()};
    } else {
        probabilities = array.toDoubleMatrix();
    }
    return ClassifierOutput.builder()
            .labels(getLabels())
            .decisions(decisionIndices.data().asInt())
            .probabilities(probabilities)
            .build();
}
Example 10
Source File: EndlessTests.java From nd4j with Apache License 2.0 | 5 votes |
/**
 * Soak test: runs a whole-array argMax (dimension Integer.MAX_VALUE) on a 100x100 array of ones
 * RUN_LIMIT times, discarding each result.
 */
@Test
public void testIndexAccumForeverFull() {
    INDArray ones = Nd4j.ones(100, 100);
    for (long run = 0; run < RUN_LIMIT; run++) {
        Nd4j.argMax(ones, Integer.MAX_VALUE);
    }
}
Example 11
Source File: NativeOpExecutionerTest.java From nd4j with Apache License 2.0 | 5 votes |
/**
 * Checks argMax with no dimension on a two-element vector (expected index 1) and on a
 * single-element vector (expected index 0).
 */
@Test
public void testArgMax1() {
    INDArray twoElements = Nd4j.create(new float[]{-1.0f, 2.0f});
    INDArray oneElement = Nd4j.create(new float[]{2.0f});

    // The larger value, 2.0, is at index 1
    INDArray result = Nd4j.argMax(twoElements);
    System.out.println("Res length: " + result.length());
    assertEquals(1.0f, result.getFloat(0), 0.01f);

    System.out.println("--------------------");

    // A single element can only be at index 0
    result = Nd4j.argMax(oneElement);
    System.out.println("Res length: " + result.length());
    assertEquals(0.0f, result.getFloat(0), 0.01f);
}
Example 12
Source File: DL4JSequenceRecommender.java From inception with Apache License 2.0 | 5 votes |
private <T extends Sample> List<Outcome<T>> predict(MultiLayerNetwork aClassifier, String[] aTagset, List<T> aData) throws IOException { if (aData.isEmpty()) { return Collections.emptyList(); } DataSet data = vectorize(aData); // Predict labels long predictionStart = System.currentTimeMillis(); INDArray predicted = aClassifier.output(data.getFeatures(), false, data.getFeaturesMaskArray(), data.getLabelsMaskArray()); log.trace("Prediction took {}ms", System.currentTimeMillis() - predictionStart); // This is a brute-force hack to ensue that argmax doesn't predict tags that are not // in the tagset. Actually, this should be necessary at all if the network is properly // configured... predicted = predicted.get(NDArrayIndex.all(), NDArrayIndex.interval(0, aTagset.length), NDArrayIndex.all()); List<Outcome<T>> outcomes = new ArrayList<>(); int sampleIdx = 0; for (Sample sample : aData) { INDArray argMax = Nd4j.argMax(predicted, 1); List<String> tokens = sample.getSentence(); String[] labels = new String[tokens.size()]; for (int tokenIdx = 0; tokenIdx < tokens.size(); tokenIdx ++) { labels[tokenIdx] = aTagset[argMax.getInt(sampleIdx, tokenIdx)]; } outcomes.add(new Outcome(sample, asList(labels))); sampleIdx ++; } return outcomes; }
Example 13
Source File: EndlessTests.java From nd4j with Apache License 2.0 | 5 votes |
/**
 * Soak test: runs argMax over dimensions 0 and 1 of a 10x10x100 linspace tensor RUN_LIMIT
 * times, discarding each result.
 */
@Test
public void testIndexAccumForeverAlongDimensions() {
    INDArray cube = Nd4j.linspace(1, 10000, 10000).reshape(10, 10, 100);
    int[] reduceDims = {0, 1};
    for (long run = 0; run < RUN_LIMIT; run++) {
        Nd4j.argMax(cube, reduceDims);
    }
}
Example 14
Source File: MultiLabelMetrics.java From konduit-serving with Apache License 2.0 | 5 votes |
/**
 * Adds 1.0 to the tracker of every predicted class. A prediction is the argMax of the first
 * output along dimension -1 (presumably the last axis, per ND4J's negative-index convention).
 */
private void incrementClassificationCounters(INDArray[] outputs) {
    INDArray predictedClasses = Nd4j.argMax(outputs[0], -1);
    long count = predictedClasses.length();
    for (int idx = 0; idx < count; idx++) {
        int classIndex = predictedClasses.getInt(idx);
        classTrackerCounts.get(classIndex).increment(1.0);
    }
}
Example 15
Source File: ClassificationMetrics.java From konduit-serving with Apache License 2.0 | 5 votes |
/**
 * Takes the argMax of {@code array} along dimension -1 and adds 1.0 to the tracker of each
 * resulting class index.
 */
private void handleNdArray(INDArray array) {
    INDArray argMax = Nd4j.argMax(array, -1);
    int i = 0;
    while (i < argMax.length()) {
        CurrentClassTrackerCount tracker = classTrackerCounts.get(argMax.getInt(i));
        tracker.increment(1.0);
        i++;
    }
}
Example 16
Source File: RnnSequenceClassifier.java From wekaDeeplearning4j with GNU General Public License v3.0 | 4 votes |
/** * The method to use when making predictions for test instances. * * @param insts the instances to get predictions for * @return the class probability estimates (if the class is nominal) or the numeric predictions * (if it is numeric) * @throws Exception if something goes wrong at prediction time */ @Override public double[][] distributionsForInstances(Instances insts) throws Exception { log.info("Calc. dist for {} instances", insts.numInstances()); // Do we only have a ZeroR model? if (zeroR != null) { return zeroR.distributionsForInstances(insts); } // Process input data to have the same filters applied as the training data insts = applyFilters(insts); // Get predictions final DataSetIterator it = getDataSetIterator(insts, CacheMode.NONE); double[][] preds = new double[insts.numInstances()][insts.numClasses()]; if (it.resetSupported()) { it.reset(); } int offset = 0; boolean next = it.hasNext(); // Get predictions batch-wise while (next) { final DataSet ds = Utils.getNext(it); final INDArray features = ds.getFeatures(); final INDArray labelsMask = ds.getLabelsMaskArray(); INDArray lastTimeStepIndices; if (labelsMask != null) { lastTimeStepIndices = Nd4j.argMax(labelsMask, 1); } else { lastTimeStepIndices = Nd4j.zeros(features.size(0), 1); } INDArray predBatch = model.outputSingle(features); int currentBatchSize = (int) predBatch.size(0); for (int i = 0; i < currentBatchSize; i++) { int thisTimeSeriesLastIndex = lastTimeStepIndices.getInt(i); INDArray thisExampleProbabilities = predBatch.get( NDArrayIndex.point(i), NDArrayIndex.all(), NDArrayIndex.point(thisTimeSeriesLastIndex)); for (int j = 0; j < insts.numClasses(); j++) { preds[i + offset][j] = thisExampleProbabilities.getDouble(j); } } offset += currentBatchSize; // add batchsize as offset boolean iteratorHasInstancesLeft = offset < insts.numInstances(); next = it.hasNext() || iteratorHasInstancesLeft; } // Fix classes for (int i = 0; i < preds.length; i++) { if (preds[i].length > 1) { 
weka.core.Utils.normalize(preds[i]); } else { // Rescale numeric classes with the computed coefficients in the initialization phase preds[i][0] = preds[i][0] * x1 + x0; } } return preds; }
Example 17
Source File: DoubleDQN.java From deeplearning4j with Apache License 2.0 | 4 votes |
/**
 * Runs the base initialisation and then records, for every next observation, the index of the
 * highest value in {@code qNetworkNextObservation} along ACTION_DIMENSION_IDX.
 * NOTE(review): assumes {@code qNetworkNextObservation} is populated by
 * {@code super.initComputation} — confirm.
 */
@Override
protected void initComputation(INDArray observations, INDArray nextObservations) {
    super.initComputation(observations, nextObservations);
    INDArray bestActions = Nd4j.argMax(qNetworkNextObservation, ACTION_DIMENSION_IDX);
    maxActionsFromQNetworkNextObservation = bestActions;
}
Example 18
Source File: SameDiffTests.java From nd4j with Apache License 2.0 | 4 votes |
@Test public void testArgMin() { Nd4j.getRandom().setSeed(12345); for (val dim : new int[][]{{0}, {1}, {Integer.MAX_VALUE}, {0, 1}, {}}) { INDArray inArr = Nd4j.rand(3, 4); SameDiff sd = SameDiff.create(); SDVariable in = sd.var("in", inArr); SDVariable argmin = sd.argmin("argmin", in, dim); INDArray out = sd.execAndEndResult(); INDArray exp = Nd4j.argMax(inArr.neg(), dim); //argmin(x) == argmax(-x) assertEquals(exp, out); } }
Example 19
Source File: LastTimeStepVertex.java From deeplearning4j with Apache License 2.0 | 4 votes |
@Override public INDArray doForward(boolean training, LayerWorkspaceMgr workspaceMgr) { //First: get the mask arrays for the given input, if any INDArray[] inputMaskArrays = graph.getInputMaskArrays(); INDArray mask = (inputMaskArrays != null ? inputMaskArrays[inputIdx] : null); //Then: work out, from the mask array, which time step of activations we want, extract activations //Also: record where they came from (so we can do errors later) fwdPassShape = inputs[0].shape(); INDArray out; if (mask == null) { //No mask array -> extract same (last) column for all long lastTS = inputs[0].size(2) - 1; out = inputs[0].get(NDArrayIndex.all(), NDArrayIndex.all(), NDArrayIndex.point(lastTS)); out = workspaceMgr.dup(ArrayType.ACTIVATIONS, out); fwdPassTimeSteps = null; //Null -> last time step for all examples } else { val outShape = new long[] {inputs[0].size(0), inputs[0].size(1)}; out = workspaceMgr.create(ArrayType.ACTIVATIONS, inputs[0].dataType(), outShape); //Want the index of the last non-zero entry in the mask array. //Check a little here by using mulRowVector([0,1,2,3,...]) and argmax long maxTsLength = fwdPassShape[2]; INDArray row = Nd4j.linspace(0, maxTsLength - 1, maxTsLength, mask.dataType()); INDArray temp = mask.mulRowVector(row); INDArray lastElementIdx = Nd4j.argMax(temp, 1); fwdPassTimeSteps = new int[(int)fwdPassShape[0]]; for (int i = 0; i < fwdPassTimeSteps.length; i++) { fwdPassTimeSteps[i] = (int) lastElementIdx.getDouble(i); } //Now, get and assign the corresponding subsets of 3d activations: for (int i = 0; i < fwdPassTimeSteps.length; i++) { out.putRow(i, inputs[0].get(NDArrayIndex.point(i), NDArrayIndex.all(), NDArrayIndex.point(fwdPassTimeSteps[i]))); } } return out; }
Example 20
Source File: BaseNDArray.java From nd4j with Apache License 2.0 | 2 votes |
/**
 * Returns the index of the highest value along the given dimension(s) by delegating to
 * {@code Nd4j.argMax} with this array.
 *
 * @param dimension the dimension(s) to reduce along
 * @return the indices of the maximum values along the requested dimension(s)
 */
@Override
public INDArray argMax(int... dimension) {
    INDArray indices = Nd4j.argMax(this, dimension);
    return indices;
}