org.nd4j.linalg.indexing.SpecifiedIndex Java Examples
The following examples show how to use
org.nd4j.linalg.indexing.SpecifiedIndex.
You can vote up the ones you like or vote down the ones you don't like,
and go to the original project or source file by following the links above each example. You may check out the related API usage on the sidebar.
Example #1
Source File: SlicingTestsC.java From nd4j with Apache License 2.0 | 6 votes |
/**
 * Checks that getRow and a point/all index select the same row, that writes through the indexed
 * view are visible, and that a SpecifiedIndex on rows {1, 2} selects the last two rows.
 */
@Test
public void testGetRow() {
    INDArray matrix = Nd4j.linspace(1, 6, 6).reshape(2, 3);

    // Two ways of addressing the second row must agree.
    INDArray viaGetRow = matrix.getRow(1);
    INDArray viaIndex = matrix.get(NDArrayIndex.point(1), NDArrayIndex.all());
    INDArray expectedRow = Nd4j.create(new double[] {4, 5, 6});
    assertEquals(expectedRow, viaGetRow);
    assertEquals(viaGetRow, viaIndex);

    // Assigning into the indexed row overwrites its contents.
    viaIndex.assign(Nd4j.linspace(1, 3, 3));
    assertEquals(Nd4j.linspace(1, 3, 3), viaIndex);

    // Row selection by explicit indices on a 3x3 matrix.
    INDArray square = Nd4j.linspace(1, 9, 9).reshape(3, 3);
    INDArray lastTwoRows = square.get(new SpecifiedIndex(1, 2), NDArrayIndex.all());
    INDArray expectedRows = Nd4j.create(new double[][] {{4, 5, 6}, {7, 8, 9}});
    assertEquals(expectedRows, lastTwoRows);
}
Example #2
Source File: SlicingTestsC.java From deeplearning4j with Apache License 2.0 | 6 votes |
/**
 * Double-precision variant: getRow and a point/all index select the same row, writes through the
 * indexed view are visible, and a SpecifiedIndex on rows {1, 2} selects the last two rows.
 */
@Test
public void testGetRow() {
    INDArray matrix = Nd4j.linspace(1, 6, 6, DataType.DOUBLE).reshape(2, 3);

    // Two ways of addressing the second row must agree.
    INDArray viaGetRow = matrix.getRow(1);
    INDArray viaIndex = matrix.get(NDArrayIndex.point(1), NDArrayIndex.all());
    INDArray expectedRow = Nd4j.create(new double[] {4, 5, 6});
    assertEquals(expectedRow, viaGetRow);
    assertEquals(viaGetRow, viaIndex);

    // Assigning into the indexed row overwrites its contents.
    viaIndex.assign(Nd4j.linspace(1, 3, 3, DataType.DOUBLE));
    assertEquals(Nd4j.linspace(1, 3, 3, DataType.DOUBLE), viaIndex);

    // Row selection by explicit indices on a 3x3 matrix.
    INDArray square = Nd4j.linspace(1, 9, 9, DataType.DOUBLE).reshape(3, 3);
    INDArray lastTwoRows = square.get(new SpecifiedIndex(1, 2), NDArrayIndex.all());
    INDArray expectedRows = Nd4j.create(new double[][] {{4, 5, 6}, {7, 8, 9}});
    assertEquals(expectedRows, lastTwoRows);
}
Example #3
Source File: IndexingTestsC.java From nd4j with Apache License 2.0 | 5 votes |
/**
 * Combines an interval on the leading axis of a 4x2x3 array with SpecifiedIndex selectors on the
 * two inner axes.
 */
@Test
public void testIntervalLowerBound() {
    INDArray cube = Nd4j.linspace(1, 24, 24).reshape(4, 2, 3);

    // Slices 1..2 of axis 0, position 0 of axis 1, positions {0, 2} of axis 2.
    INDArray picked = cube.get(
            interval(1, 3),
            new SpecifiedIndex(new int[] {0}),
            new SpecifiedIndex(new int[] {0, 2}));

    INDArray expected = Nd4j.create(new double[][] {{7, 9}, {13, 15}});
    assertEquals(expected, picked);
}
Example #4
Source File: IndexingTestsC.java From nd4j with Apache License 2.0 | 5 votes |
/**
 * A SpecifiedIndex on the column axis of a 4x4 matrix selects exactly the listed columns, in the
 * listed order, for every row.
 */
@Test
public void testSpecifiedIndexVector() {
    INDArray rootMatrix = Nd4j.linspace(1, 16, 16).reshape(4, 4);

    // Columns 0 and 2.
    INDArray get = rootMatrix.get(all(), new SpecifiedIndex(0, 2));
    INDArray assertion = Nd4j.create(new double[][] {{1, 3}, {5, 7}, {9, 11}, {13, 15}});
    assertEquals(assertion, get);

    // Columns 0, 2 and 3.
    INDArray assertion2 =
            Nd4j.create(new double[][] {{1, 3, 4}, {5, 7, 8}, {9, 11, 12}, {13, 15, 16}});
    INDArray get2 = rootMatrix.get(all(), new SpecifiedIndex(0, 2, 3));
    assertEquals(assertion2, get2);
}
Example #5
Source File: IndexingTestsC.java From nd4j with Apache License 2.0 | 5 votes |
/** Two SpecifiedIndex selectors pick the 2x2 block at rows {1, 2} and columns {0, 1}. */
@Test
public void testGetRows() {
    INDArray source = Nd4j.linspace(1, 9, 9).reshape(3, 3);
    INDArray expected = Nd4j.create(new double[][] {{4, 5}, {7, 8}});

    INDArray rows = new SpecifiedIndex(1, 2) == null ? null : null; // placeholder removed below
    assertEquals(expected, source.get(new SpecifiedIndex(1, 2), new SpecifiedIndex(0, 1)));
}
Example #6
Source File: IndexingTestsC.java From nd4j with Apache License 2.0 | 5 votes |
/**
 * Selects rows {1, 2} of column 0 from a 3x3 matrix using a SpecifiedIndex for rows and an
 * interval for columns.
 *
 * <p>NOTE(review): the expected array is built as a 1x2 row, while the selection is two rows by
 * one column. The test relies on assertEquals accepting the two vectors as equal; confirm this
 * holds for the nd4j version in use.
 */
@Test
public void testMultiRow() {
    INDArray square = Nd4j.linspace(1, 9, 9).reshape(3, 3);
    INDArray expected = Nd4j.create(new double[][] {{4, 7}});

    INDArrayIndex rowSelector = null;
    INDArray firstColumnOfLastRows =
            square.get(new SpecifiedIndex(1, 2), NDArrayIndex.interval(0, 1));
    assertEquals(expected, firstColumnOfLastRows);
}
Example #7
Source File: IndexingTests.java From nd4j with Apache License 2.0 | 5 votes |
/**
 * Two SpecifiedIndex selectors pick the 2x2 block at rows {1, 2} and columns {1, 2}.
 *
 * <p>NOTE(review): the expected values correspond to a column-major fill of the 3x3 matrix.
 * Presumably this test class runs with 'f' ordering; confirm.
 */
@Test
public void testGetRows() {
    INDArray source = Nd4j.linspace(1, 9, 9).reshape(3, 3);
    INDArray expected = Nd4j.create(new double[][] {{5, 8}, {6, 9}});

    INDArray bottomRightBlock = source.get(new SpecifiedIndex(1, 2), new SpecifiedIndex(1, 2));
    assertEquals(expected, bottomRightBlock);
}
Example #8
Source File: BaseSparseNDArray.java From nd4j with Apache License 2.0 | 5 votes |
/**
 * Writes {@code element} at the positions described by {@code indices}.
 *
 * <p>Each slice of {@code indices} is converted to an int array and wrapped in a
 * {@link SpecifiedIndex}. The resulting index array is delegated to
 * {@code put(INDArrayIndex[], INDArray)}.
 *
 * <p>NOTE(review): the number of index entries is taken from {@code indices.rank()}, not from
 * the number of slices of {@code indices}. For a 2-D index matrix this always yields two entries.
 * Confirm whether {@code indices.slices()} / {@code rows()} was intended.
 *
 * @param indices array whose slices hold the per-dimension index lists
 * @param element values to write
 * @return the result of the delegated {@code put}
 */
@Override
public INDArray put(INDArray indices, INDArray element) {
    int entryCount = indices.rank();
    INDArrayIndex[] perDimension = new INDArrayIndex[entryCount];
    for (int dim = 0; dim < entryCount; dim++) {
        // dup() detaches the slice so data() reflects only this slice's values.
        int[] positions = indices.slice(dim).dup().data().asInt();
        perDimension[dim] = new SpecifiedIndex(positions);
    }
    return put(perDimension, element);
}
Example #9
Source File: IndexingTestsC.java From deeplearning4j with Apache License 2.0 | 5 votes |
/** Double-precision: two SpecifiedIndex selectors pick rows {1, 2} and columns {0, 1}. */
@Test
public void testGetRows() {
    INDArray source = Nd4j.linspace(1, 9, 9, DataType.DOUBLE).reshape(3, 3);
    INDArray expected = Nd4j.create(new double[][] {{4, 5}, {7, 8}});

    INDArray lowerLeftBlock = source.get(new SpecifiedIndex(1, 2), new SpecifiedIndex(0, 1));
    assertEquals(expected, lowerLeftBlock);
}
Example #10
Source File: IndexingTestsC.java From deeplearning4j with Apache License 2.0 | 5 votes |
/**
 * Double-precision: selects rows {1, 2} of column 0 from a 3x3 matrix using a SpecifiedIndex for
 * rows and an interval for columns.
 *
 * <p>NOTE(review): the expected array is 1x2, while the selection is two rows by one column. The
 * test relies on assertEquals accepting the two vectors as equal; confirm.
 */
@Test
public void testMultiRow() {
    INDArray square = Nd4j.linspace(1, 9, 9, DataType.DOUBLE).reshape(3, 3);
    INDArray expected = Nd4j.create(new double[][] {{4, 7}});

    INDArray firstColumnOfLastRows =
            square.get(new SpecifiedIndex(1, 2), NDArrayIndex.interval(0, 1));
    assertEquals(expected, firstColumnOfLastRows);
}
Example #11
Source File: IndexingTests.java From deeplearning4j with Apache License 2.0 | 5 votes |
/**
 * Double-precision: two SpecifiedIndex selectors pick the 2x2 block at rows {1, 2} and
 * columns {1, 2}.
 *
 * <p>NOTE(review): the expected values correspond to a column-major fill of the 3x3 matrix.
 * Presumably this test class runs with 'f' ordering; confirm.
 */
@Test
public void testGetRows() {
    INDArray source = Nd4j.linspace(1, 9, 9, DataType.DOUBLE).reshape(3, 3);
    INDArray expected = Nd4j.create(new double[][] {{5, 8}, {6, 9}});

    INDArray bottomRightBlock = source.get(new SpecifiedIndex(1, 2), new SpecifiedIndex(1, 2));
    assertEquals(expected, bottomRightBlock);
}
Example #12
Source File: MtcnnService.java From mtcnn-java with Apache License 2.0 | 4 votes |
/** * STAGE 2 * * @param image * @param totalBoxes * @param padResult * @return * @throws IOException */ private INDArray refinementStage(INDArray image, INDArray totalBoxes, MtcnnUtil.PadResult padResult) throws IOException { // num_boxes = total_boxes.shape[0] int numBoxes = totalBoxes.isEmpty() ? 0 : (int) totalBoxes.shape()[0]; // if num_boxes == 0: // return total_boxes, stage_status if (numBoxes == 0) { return totalBoxes; } INDArray tempImg1 = computeTempImage(image, numBoxes, padResult, 24); //this.refineNetGraph.associateArrayWithVariable(tempImg1, this.refineNetGraph.variableMap().get("rnet/input")); //List<DifferentialFunction> refineNetResults = this.refineNetGraph.exec().getRight(); //INDArray out0 = refineNetResults.stream().filter(df -> df.getOwnName().equalsIgnoreCase("rnet/fc2-2/fc2-2")) // .findFirst().get().outputVariable().getArr(); //INDArray out1 = refineNetResults.stream().filter(df -> df.getOwnName().equalsIgnoreCase("rnet/prob1")) // .findFirst().get().outputVariable().getArr(); Map<String, INDArray> resultMap = this.refineNetGraphRunner.run(Collections.singletonMap("rnet/input", tempImg1)); //INDArray out0 = resultMap.get("rnet/fc2-2/fc2-2"); // for ipazc/mtcnn model INDArray out0 = resultMap.get("rnet/conv5-2/conv5-2"); INDArray out1 = resultMap.get("rnet/prob1"); // score = out1[1, :] INDArray score = out1.get(all(), point(1)).transposei(); // ipass = np.where(score > self.__steps_threshold[1]) INDArray ipass = MtcnnUtil.getIndexWhereVector(score.transpose(), s -> s > stepsThreshold[1]); //INDArray ipass = MtcnnUtil.getIndexWhereVector2(score.transpose(), Conditions.greaterThan(stepsThreshold[1])); if (ipass.isEmpty()) { totalBoxes = Nd4j.empty(); return totalBoxes; } // total_boxes = np.hstack([total_boxes[ipass[0], 0:4].copy(), np.expand_dims(score[ipass].copy(), 1)]) INDArray b1 = totalBoxes.get(new SpecifiedIndex(ipass.toLongVector()), interval(0, 4)); INDArray b2 = ipass.isScalar() ? 
score.get(ipass).reshape(1, 1) : Nd4j.expandDims(score.get(ipass), 1); totalBoxes = Nd4j.hstack(b1, b2); // mv = out0[:, ipass[0]] INDArray mv = out0.get(new SpecifiedIndex(ipass.toLongVector()), all()).transposei(); // if total_boxes.shape[0] > 0: if (!totalBoxes.isEmpty() && totalBoxes.shape()[0] > 0) { // pick = self.__nms(total_boxes, 0.7, 'Union') INDArray pick = MtcnnUtil.nonMaxSuppression(totalBoxes.dup(), 0.7, MtcnnUtil.NonMaxSuppressionType.Union).transpose(); // total_boxes = total_boxes[pick, :] totalBoxes = totalBoxes.get(new SpecifiedIndex(pick.toLongVector()), all()); // total_boxes = self.__bbreg(total_boxes.copy(), np.transpose(mv[:, pick])) totalBoxes = MtcnnUtil.bbreg(totalBoxes, mv.get(all(), new SpecifiedIndex(pick.toLongVector())).transpose()); // total_boxes = self.__rerec(total_boxes.copy()) totalBoxes = MtcnnUtil.rerec(totalBoxes, false); } return totalBoxes; }
Example #13
Source File: MtcnnUtil.java From mtcnn-java with Apache License 2.0 | 4 votes |
/**
 * Wraps the values of {@code array} in a single {@link SpecifiedIndex}.
 *
 * @param array array whose values (read via {@code toLongVector()}) are the positions to select
 * @return a one-element index array holding that SpecifiedIndex
 */
private static INDArrayIndex[] toUpdateIndex(INDArray array) {
    long[] positions = array.toLongVector();
    INDArrayIndex only = new SpecifiedIndex(positions);
    return new INDArrayIndex[] {only};
}
Example #14
Source File: BaseSparseNDArray.java From nd4j with Apache License 2.0 | 4 votes |
/**
 * Writes {@code element} into this array at the positions described by {@code indices}.
 *
 * <p>When there is one index list per dimension ({@code indices.size() == rank()}), each list
 * becomes a {@link SpecifiedIndex}. Every coordinate produced by {@link SpecifiedIndex#iterate}
 * then receives the next scalar of {@code element}, taken in {@link NdIndexIterator} order over
 * {@code element.shape()}.
 *
 * <p>Otherwise every integer in every list is treated as a slice index. {@code element} is
 * assigned into each such {@code slice(i)} of this array.
 *
 * @param indices per-dimension index lists, or lists of slice indices
 * @param element the values to write (scalar-by-scalar, or assigned to each selected slice)
 * @return this array
 */
@Override
public INDArray put(List<List<Integer>> indices, INDArray element) {
    if (indices.size() == rank()) {
        NdIndexIterator ndIndexIterator = new NdIndexIterator(element.shape());
        INDArrayIndex[] indArrayIndices = new INDArrayIndex[indices.size()];
        for (int i = 0; i < indArrayIndices.length; i++) {
            indArrayIndices[i] = new SpecifiedIndex(Ints.toArray(indices.get(i)));
        }

        Generator<List<List<Long>>> iterate = SpecifiedIndex.iterate(indArrayIndices);
        boolean hasNext = true;
        while (hasNext) {
            try {
                List<List<Long>> next = iterate.next();
                // NOTE(review): each inner list is used as a complete coordinate for putScalar.
                // Confirm this matches the shape of what SpecifiedIndex.iterate yields; it may
                // instead yield per-dimension fragments that must be concatenated.
                for (List<Long> coordinate : next) {
                    int[] curr = Ints.toArray(coordinate);
                    putScalar(curr, element.getDouble(ndIndexIterator.next()));
                }
            } catch (NoSuchElementException e) {
                // The loop ends when iterate.next() (or ndIndexIterator.next()) throws.
                hasNext = false;
            }
        }
    } else {
        // Each listed index names a slice to overwrite. This covers the single-list case, which
        // previously only gathered slices into an unused list and never wrote anything.
        for (List<Integer> row : indices) {
            for (int sliceIndex : row) {
                INDArray slice = slice(sliceIndex);
                Nd4j.getExecutioner().exec(
                        new Assign(new INDArray[] {slice, element}, new INDArray[] {slice}));
            }
        }
    }
    return this;
}