Java Code Examples for org.nd4j.linalg.api.ndarray.INDArray#isScalar()
The following examples show how to use org.nd4j.linalg.api.ndarray.INDArray#isScalar().
Each example is taken from an open-source project; the source file, originating project, and license are listed above the code.
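As background for the examples below: isScalar() returns true for an ndarray that holds a single element (for instance one produced by Nd4j.scalar(...)), and callers commonly use it as an early-exit or special-case check before looping or dispatching to native code. The following is a minimal, self-contained sketch of that pattern; the class name IsScalarDemo is illustrative and not taken from any of the projects below.

import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;

public class IsScalarDemo {
    public static void main(String[] args) {
        INDArray scalar = Nd4j.scalar(3.0);                     // single-element array
        INDArray vector = Nd4j.create(new double[]{1, 2, 3});   // length-3 vector

        System.out.println(scalar.isScalar());   // true
        System.out.println(vector.isScalar());   // false

        // Typical guard pattern seen in the examples below:
        // a scalar needs no iteration, so handle it directly and return early.
        if (scalar.isScalar()) {
            System.out.println(scalar.getDouble(0));   // 3.0
        }
    }
}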
Example 1
Source File: RegressionMetrics.java From konduit-serving with Apache License 2.0 | 6 votes |
private void handleNdArray(INDArray output) {
    if (output.isVector()) {
        for (int i = 0; i < output.length(); i++) {
            statCounters.get(i).add(output.getDouble(i));
        }
    } else if (output.isMatrix() && output.length() > 1) {
        for (int i = 0; i < output.rows(); i++) {
            for (int j = 0; j < output.columns(); j++) {
                statCounters.get(i).add(output.getDouble(i, j));
            }
        }
    } else if (output.isScalar()) {
        statCounters.get(0).add(output.sumNumber().doubleValue());
    } else {
        throw new IllegalArgumentException("Only vectors and matrices supported right now");
    }
}
Example 2
Source File: BaseComplexNDArray.java From nd4j with Apache License 2.0 | 6 votes |
/**
 * In place (element wise) multiplication of two matrices
 *
 * @param other  the second ndarray to multiply
 * @param result the result ndarray
 * @return the result of the multiplication
 */
@Override
public IComplexNDArray muli(INDArray other, INDArray result) {
    IComplexNDArray cOther = (IComplexNDArray) other;
    IComplexNDArray cResult = (IComplexNDArray) result;
    IComplexNDArray linear = linearView();
    IComplexNDArray cOtherLinear = cOther.linearView();
    IComplexNDArray cResultLinear = cResult.linearView();

    if (other.isScalar())
        return muli(cOther.getComplex(0), result);

    IComplexNumber c = Nd4j.createComplexNumber(0, 0);
    IComplexNumber d = Nd4j.createComplexNumber(0, 0);

    for (int i = 0; i < length(); i++)
        cResultLinear.putScalar(i, linear.getComplex(i, c).muli(cOtherLinear.getComplex(i, d)));
    return cResult;
}
Example 3
Source File: CpuNDArrayFactory.java From deeplearning4j with Apache License 2.0 | 6 votes |
@Override
public INDArray sort(INDArray x, boolean descending, int... dimension) {
    if (x.isScalar())
        return x;

    Arrays.sort(dimension);
    Pair<DataBuffer, DataBuffer> tadBuffers = Nd4j.getExecutioner().getTADManager().getTADOnlyShapeInfo(x, dimension);

    NativeOpsHolder.getInstance().getDeviceNativeOps().sortTad(null,
            x.data().addressPointer(),
            (LongPointer) x.shapeInfoDataBuffer().addressPointer(),
            null,
            null,
            (IntPointer) Nd4j.getConstantHandler().getConstantBuffer(dimension, DataType.INT).addressPointer(),
            dimension.length,
            (LongPointer) tadBuffers.getFirst().addressPointer(),
            new LongPointerWrapper(tadBuffers.getSecond().addressPointer()),
            descending);

    return x;
}
Example 4
Source File: BaseNDArrayFactory.java From nd4j with Apache License 2.0 | 6 votes |
/**
 * Generate a linearly spaced vector
 *
 * @param lower the lower bound
 * @param upper the upper bound
 * @param num   the number of points in the vector
 * @return the linearly spaced vector
 */
@Override
public INDArray linspace(int lower, int upper, int num) {
    double[] data = new double[num];
    for (int i = 0; i < num; i++) {
        double t = (double) i / (num - 1);
        data[i] = lower * (1 - t) + t * upper;
    }

    //edge case for scalars
    INDArray ret = Nd4j.create(data.length);
    if (ret.isScalar())
        return ret;

    for (int i = 0; i < ret.length(); i++)
        ret.putScalar(i, data[i]);
    return ret;
}
Example 5
Source File: NDArrayPreconditionsFormat.java From deeplearning4j with Apache License 2.0 | 5 votes |
@Override
public String format(String tag, Object arg) {
    if (arg == null)
        return "null";
    INDArray arr = (INDArray) arg;
    switch (tag) {
        case "%ndRank":
            return String.valueOf(arr.rank());
        case "%ndShape":
            return Arrays.toString(arr.shape());
        case "%ndStride":
            return Arrays.toString(arr.stride());
        case "%ndLength":
            return String.valueOf(arr.length());
        case "%ndSInfo":
            return arr.shapeInfoToString().replaceAll("\n", "");
        case "%nd10":
            if (arr.isScalar() || arr.isEmpty()) {
                return arr.toString();
            }
            INDArray sub = arr.reshape(arr.length()).get(NDArrayIndex.interval(0, Math.min(arr.length(), 10)));
            return sub.toString();
        default:
            //Should never happen
            throw new IllegalStateException("Unknown format tag: " + tag);
    }
}
Example 6
Source File: FirstAxisIterator.java From deeplearning4j with Apache License 2.0 | 5 votes |
@Override
public Object next() {
    INDArray s = iterateOver.slice(i++);
    if (s.isScalar()) {
        return s.getDouble(0);
    } else {
        return s;
    }
}
Example 7
Source File: BaseComplexNDArray.java From nd4j with Apache License 2.0 | 5 votes |
/**
 * Copy imaginary numbers to the given ndarray
 *
 * @param arr the array to copy imaginary numbers to
 */
protected void copyImagTo(INDArray arr) {
    INDArray linear = arr.linearView();
    IComplexNDArray thisLinear = linearView();
    if (arr.isScalar())
        arr.putScalar(0, getReal(0));
    else
        for (int i = 0; i < linear.length(); i++) {
            arr.putScalar(i, thisLinear.getImag(i));
        }
}
Example 8
Source File: BaseComplexNDArray.java From nd4j with Apache License 2.0 | 5 votes |
/**
 * Copy real numbers to arr
 *
 * @param arr the arr to copy to
 */
protected void copyRealTo(INDArray arr) {
    INDArray linear = arr.linearView();
    IComplexNDArray thisLinear = linearView();
    if (arr.isScalar())
        arr.putScalar(0, getReal(0));
    else
        for (int i = 0; i < linear.length(); i++) {
            arr.putScalar(i, thisLinear.getReal(i));
        }
}
Example 9
Source File: JCublasNDArrayFactory.java From deeplearning4j with Apache License 2.0 | 5 votes |
@Override
public INDArray sort(INDArray x, boolean descending, int... dimension) {
    if (x.isScalar())
        return x;

    Arrays.sort(dimension);

    Nd4j.getExecutioner().push();

    val tadBuffers = Nd4j.getExecutioner().getTADManager().getTADOnlyShapeInfo(x, dimension);

    val context = AtomicAllocator.getInstance().getFlowController().prepareAction(x);

    val extraz = new PointerPointer(AtomicAllocator.getInstance().getHostPointer(x.shapeInfoDataBuffer()), // not used
            context.getOldStream(), AtomicAllocator.getInstance().getDeviceIdPointer());

    val dimensionPointer = AtomicAllocator.getInstance()
            .getHostPointer(AtomicAllocator.getInstance().getConstantBuffer(dimension));

    nativeOps.sortTad(extraz, null,
            (LongPointer) x.shapeInfoDataBuffer().addressPointer(),
            AtomicAllocator.getInstance().getPointer(x, context),
            (LongPointer) AtomicAllocator.getInstance().getPointer(x.shapeInfoDataBuffer(), context),
            (IntPointer) dimensionPointer,
            dimension.length,
            (LongPointer) AtomicAllocator.getInstance().getPointer(tadBuffers.getFirst(), context),
            new LongPointerWrapper(AtomicAllocator.getInstance().getPointer(tadBuffers.getSecond(), context)),
            descending);

    if (nativeOps.lastErrorCode() != 0)
        throw new RuntimeException(nativeOps.lastErrorMessage());

    AtomicAllocator.getInstance().getFlowController().registerAction(context, x);

    return x;
}
Example 10
Source File: CpuNDArrayFactory.java From nd4j with Apache License 2.0 | 5 votes |
@Override
public INDArray sort(INDArray x, boolean descending) {
    if (x.isScalar())
        return x;

    if (x.data().dataType() == DataBuffer.Type.FLOAT) {
        NativeOpsHolder.getInstance().getDeviceNativeOps().sortFloat(null,
                (FloatPointer) x.data().addressPointer(),
                (LongPointer) x.shapeInfoDataBuffer().addressPointer(),
                descending);
    } else if (x.data().dataType() == DataBuffer.Type.DOUBLE) {
        NativeOpsHolder.getInstance().getDeviceNativeOps().sortDouble(null,
                (DoublePointer) x.data().addressPointer(),
                (LongPointer) x.shapeInfoDataBuffer().addressPointer(),
                descending);
    } else {
        throw new UnsupportedOperationException("Unknown datatype " + x.data().dataType());
    }

    return x;
}
Example 11
Source File: CpuSparseNDArrayFactory.java From nd4j with Apache License 2.0 | 5 votes |
@Override
public INDArray sort(INDArray x, boolean descending, int... dimension) {
    if (x.isScalar())
        return x;

    Arrays.sort(dimension);
    Pair<DataBuffer, DataBuffer> tadBuffers = Nd4j.getExecutioner().getTADManager().getTADOnlyShapeInfo(x, dimension);

    if (x.data().dataType() == DataBuffer.Type.FLOAT) {
        NativeOpsHolder.getInstance().getDeviceNativeOps().sortTadFloat(null,
                (FloatPointer) x.data().addressPointer(),
                (LongPointer) x.shapeInfoDataBuffer().addressPointer(),
                new IntPointer(dimension),
                dimension.length,
                (LongPointer) tadBuffers.getFirst().addressPointer(),
                new LongPointerWrapper(tadBuffers.getSecond().addressPointer()),
                descending);
    } else if (x.data().dataType() == DataBuffer.Type.DOUBLE) {
        NativeOpsHolder.getInstance().getDeviceNativeOps().sortTadDouble(null,
                (DoublePointer) x.data().addressPointer(),
                (LongPointer) x.shapeInfoDataBuffer().addressPointer(),
                new IntPointer(dimension),
                dimension.length,
                (LongPointer) tadBuffers.getFirst().addressPointer(),
                new LongPointerWrapper(tadBuffers.getSecond().addressPointer()),
                descending);
    } else {
        throw new UnsupportedOperationException("Unknown datatype " + x.data().dataType());
    }

    return x;
}
Example 12
Source File: CpuSparseNDArrayFactory.java From nd4j with Apache License 2.0 | 5 votes |
@Override
public INDArray sort(INDArray x, boolean descending) {
    if (x.isScalar())
        return x;

    if (x.data().dataType() == DataBuffer.Type.FLOAT) {
        NativeOpsHolder.getInstance().getDeviceNativeOps().sortFloat(null,
                (FloatPointer) x.data().addressPointer(),
                (LongPointer) x.shapeInfoDataBuffer().addressPointer(),
                descending);
    } else if (x.data().dataType() == DataBuffer.Type.DOUBLE) {
        NativeOpsHolder.getInstance().getDeviceNativeOps().sortDouble(null,
                (DoublePointer) x.data().addressPointer(),
                (LongPointer) x.shapeInfoDataBuffer().addressPointer(),
                descending);
    } else {
        throw new UnsupportedOperationException("Unknown datatype " + x.data().dataType());
    }

    return x;
}
Example 13
Source File: CoverageModelEMWorkspaceMathUtils.java From gatk-protected with BSD 3-Clause "New" or "Revised" License | 5 votes |
/**
 * Solves the linear system [mat].[x] = [vec] using Apache Commons Math
 *
 * @param mat the coefficients matrix (must be square and full-rank)
 * @param vec the right hand side vector
 * @param singularityThreshold a threshold for detecting singularity
 * @return solution of the linear system
 */
public static INDArray linsolve(@Nonnull final INDArray mat, @Nonnull final INDArray vec,
                                final double singularityThreshold) {
    if (mat.isScalar()) {
        return vec.div(mat.getDouble(0));
    }
    if (!mat.isSquare()) {
        throw new IllegalArgumentException("invalid array: must be a square matrix");
    }
    final RealVector sol = new LUDecomposition(Nd4jApacheAdapterUtils.convertINDArrayToApacheMatrix(mat),
            singularityThreshold).getSolver().solve(Nd4jApacheAdapterUtils.convertINDArrayToApacheVector(vec));
    return Nd4j.create(sol.toArray(), vec.shape());
}
Example 14
Source File: CoverageModelEMWorkspaceMathUtils.java From gatk-protected with BSD 3-Clause "New" or "Revised" License | 5 votes |
/**
 * Calculates log abs determinant of a matrix via LU decomposition.
 *
 * @param mat a square matrix
 * @return log abs determinant of {@code mat}
 */
public static double logdet(@Nonnull final INDArray mat) {
    if (mat.isScalar()) {
        return FastMath.log(FastMath.abs(mat.getDouble(0)));
    }
    if (!mat.isSquare()) {
        throw new IllegalArgumentException("Invalid array: must be square matrix");
    }
    final LUDecomposition decomp = new LUDecomposition(Nd4jApacheAdapterUtils.convertINDArrayToApacheMatrix(mat),
            DEFAULT_LU_DECOMPOSITION_SINGULARITY_THRESHOLD);
    final double[] diagL = diag(decomp.getL());
    final double[] diagU = diag(decomp.getU());
    return Arrays.stream(diagL).map(FastMath::abs).map(FastMath::log).sum()
            + Arrays.stream(diagU).map(FastMath::abs).map(FastMath::log).sum();
}
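The identity this method relies on can be checked in isolation: for an LU factorization PA = LU with permutation matrix P, |det(A)| = |det(L)|·|det(U)|, so log|det(A)| equals the sum of the logs of the absolute diagonal entries of L and U. The following minimal sketch verifies that against Apache Commons Math's own determinant; the class name LogDetCheck and the sample matrix are illustrative and not taken from gatk-protected.

import org.apache.commons.math3.linear.Array2DRowRealMatrix;
import org.apache.commons.math3.linear.LUDecomposition;
import org.apache.commons.math3.linear.RealMatrix;
import org.apache.commons.math3.util.FastMath;

public class LogDetCheck {
    public static void main(String[] args) {
        RealMatrix m = new Array2DRowRealMatrix(new double[][]{{4, 3}, {6, 3}});   // det = -6
        LUDecomposition lu = new LUDecomposition(m);

        // Sum of log|diagonal| over L and U, as in logdet() above.
        double fromDiagonals = 0.0;
        for (int i = 0; i < m.getRowDimension(); i++) {
            fromDiagonals += FastMath.log(FastMath.abs(lu.getL().getEntry(i, i)))
                    + FastMath.log(FastMath.abs(lu.getU().getEntry(i, i)));
        }

        // Direct computation of log|det(A)| for comparison.
        double direct = FastMath.log(FastMath.abs(lu.getDeterminant()));

        System.out.println(fromDiagonals + " ~= " + direct);   // both ~ log(6) = 1.7918
    }
}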
Example 15
Source File: CpuNDArrayFactory.java From deeplearning4j with Apache License 2.0 | 5 votes |
@Override
public INDArray sort(INDArray x, boolean descending) {
    if (x.isScalar())
        return x;

    NativeOpsHolder.getInstance().getDeviceNativeOps().sort(null,
            x.data().addressPointer(),
            (LongPointer) x.shapeInfoDataBuffer().addressPointer(),
            null,
            null,
            descending);

    return x;
}
Example 16
Source File: NativeOpExecutioner.java From deeplearning4j with Apache License 2.0 | 4 votes |
public INDArray exec(IndexAccumulation op, OpContext oc) {
    checkForCompression(op);

    INDArray x = getX(op, oc);
    INDArray z = getZ(op, oc);

    if (extraz.get() == null)
        extraz.set(new PointerPointer(32));

    val dimension = Shape.normalizeAxis(x.rank(), op.dimensions().toIntVector());

    if (x.isEmpty()) {
        for (val d : dimension) {
            Preconditions.checkArgument(x.shape()[d] != 0, "IndexReduce can't be issued along axis with 0 in shape");
        }
    }

    boolean keepDims = op.isKeepDims();
    long[] retShape = Shape.reductionShape(x, dimension, true, keepDims);

    if (z == null || x == z) {
        val ret = Nd4j.createUninitialized(DataType.LONG, retShape);
        setZ(ret, op, oc);
        z = ret;
    } else if (!Arrays.equals(retShape, z.shape())) {
        throw new IllegalStateException("Z array shape does not match expected return type for op " + op
                + ": expected shape " + Arrays.toString(retShape) + ", z.shape()=" + Arrays.toString(z.shape()));
    }

    op.validateDataTypes();

    Pointer dimensionAddress = constantHandler.getConstantBuffer(dimension, DataType.INT).addressPointer();

    Pair<DataBuffer, DataBuffer> tadBuffers = tadManager.getTADOnlyShapeInfo(x, dimension);

    Pointer hostTadShapeInfo = tadBuffers.getFirst().addressPointer();

    DataBuffer offsets = tadBuffers.getSecond();
    Pointer hostTadOffsets = offsets == null ? null : offsets.addressPointer();

    PointerPointer dummy = extraz.get().put(hostTadShapeInfo, hostTadOffsets);

    long st = profilingConfigurableHookIn(op, tadBuffers.getFirst());

    val xb = ((BaseCpuDataBuffer) x.data()).getOpaqueDataBuffer();
    val zb = ((BaseCpuDataBuffer) z.data()).getOpaqueDataBuffer();

    if (z.isScalar()) {
        loop.execIndexReduceScalar(dummy, op.opNum(),
                xb, (LongPointer) x.shapeInfoDataBuffer().addressPointer(), null,
                getPointerForExtraArgs(op, x.dataType()),
                zb, (LongPointer) z.shapeInfoDataBuffer().addressPointer(), null);
    } else {
        loop.execIndexReduce(dummy, op.opNum(),
                xb, (LongPointer) x.shapeInfoDataBuffer().addressPointer(), null,
                getPointerForExtraArgs(op, x.dataType()),
                zb, (LongPointer) z.shapeInfoDataBuffer().addressPointer(), null,
                ((BaseCpuDataBuffer) op.dimensions().data()).getOpaqueDataBuffer(),
                (LongPointer) op.dimensions().shapeInfoDataBuffer().addressPointer(), null);
    }

    if (loop.lastErrorCode() != 0)
        throw new RuntimeException(loop.lastErrorMessage());

    profilingConfigurableHookOut(op, oc, st);

    return getZ(op, oc);
}
Example 17
Source File: JCublasNDArrayFactory.java From nd4j with Apache License 2.0 | 4 votes |
@Override
public INDArray sort(INDArray x, boolean descending, int... dimension) {
    if (x.isScalar())
        return x;

    Arrays.sort(dimension);

    Nd4j.getExecutioner().push();

    Pair<DataBuffer, DataBuffer> tadBuffers = Nd4j.getExecutioner().getTADManager().getTADOnlyShapeInfo(x, dimension);

    CudaContext context = AtomicAllocator.getInstance().getFlowController().prepareAction(x);

    PointerPointer extraz = new PointerPointer(AtomicAllocator.getInstance().getHostPointer(x.shapeInfoDataBuffer()), // not used
            context.getOldStream(), AtomicAllocator.getInstance().getDeviceIdPointer());

    Pointer dimensionPointer = AtomicAllocator.getInstance()
            .getPointer(AtomicAllocator.getInstance().getConstantBuffer(dimension), context);

    if (x.data().dataType() == DataBuffer.Type.FLOAT) {
        nativeOps.sortTadFloat(extraz,
                (FloatPointer) AtomicAllocator.getInstance().getPointer(x, context),
                (LongPointer) AtomicAllocator.getInstance().getPointer(x.shapeInfoDataBuffer(), context),
                (IntPointer) dimensionPointer,
                dimension.length,
                (LongPointer) AtomicAllocator.getInstance().getPointer(tadBuffers.getFirst(), context),
                new LongPointerWrapper(AtomicAllocator.getInstance().getPointer(tadBuffers.getSecond(), context)),
                descending);
    } else if (x.data().dataType() == DataBuffer.Type.DOUBLE) {
        nativeOps.sortTadDouble(extraz,
                (DoublePointer) AtomicAllocator.getInstance().getPointer(x, context),
                (LongPointer) AtomicAllocator.getInstance().getPointer(x.shapeInfoDataBuffer(), context),
                (IntPointer) dimensionPointer,
                dimension.length,
                (LongPointer) AtomicAllocator.getInstance().getPointer(tadBuffers.getFirst(), context),
                new LongPointerWrapper(AtomicAllocator.getInstance().getPointer(tadBuffers.getSecond(), context)),
                descending);
    } else if (x.data().dataType() == DataBuffer.Type.HALF) {
        nativeOps.sortTadHalf(extraz,
                (ShortPointer) AtomicAllocator.getInstance().getPointer(x, context),
                (LongPointer) AtomicAllocator.getInstance().getPointer(x.shapeInfoDataBuffer(), context),
                (IntPointer) dimensionPointer,
                dimension.length,
                (LongPointer) AtomicAllocator.getInstance().getPointer(tadBuffers.getFirst(), context),
                new LongPointerWrapper(AtomicAllocator.getInstance().getPointer(tadBuffers.getSecond(), context)),
                descending);
    } else {
        throw new UnsupportedOperationException("Unknown dataType " + x.data().dataType());
    }

    AtomicAllocator.getInstance().getFlowController().registerAction(context, x);

    return x;
}
Example 18
Source File: JCublasNDArrayFactory.java From nd4j with Apache License 2.0 | 4 votes |
@Override
public INDArray sort(INDArray x, boolean descending) {
    if (x.isScalar())
        return x;

    Nd4j.getExecutioner().push();

    CudaContext context = AtomicAllocator.getInstance().getFlowController().prepareAction(x);

    Pointer ptr = AtomicAllocator.getInstance().getHostPointer(x.shapeInfoDataBuffer());

    PointerPointer extraz = new PointerPointer(ptr, // 0
            context.getOldStream(), // 1
            AtomicAllocator.getInstance().getDeviceIdPointer(), // 2
            context.getBufferAllocation(), // 3
            context.getBufferReduction(), // 4
            context.getBufferScalar(), // 5
            context.getBufferSpecial(), // 6
            ptr, // 7
            AtomicAllocator.getInstance().getHostPointer(x.shapeInfoDataBuffer()), // 8
            ptr, // 9
            ptr, // 10
            ptr, // 11
            ptr, // 12
            ptr, // 13
            ptr, // 14
            ptr, // special pointer for IsMax // 15
            ptr, // special pointer for IsMax // 16
            ptr, // special pointer for IsMax // 17
            new CudaPointer(0));

    // we're sending > 10m elements to radixSort
    boolean isRadix = !x.isView() && (x.lengthLong() > 1024 * 1024 * 10);
    INDArray tmpX = x;

    // we need to guarantee all threads are finished here
    if (isRadix)
        Nd4j.getExecutioner().commit();

    if (x.data().dataType() == DataBuffer.Type.FLOAT) {
        nativeOps.sortFloat(extraz,
                (FloatPointer) AtomicAllocator.getInstance().getPointer(tmpX, context),
                (LongPointer) AtomicAllocator.getInstance().getPointer(tmpX.shapeInfoDataBuffer(), context),
                descending);
    } else if (x.data().dataType() == DataBuffer.Type.DOUBLE) {
        nativeOps.sortDouble(extraz,
                (DoublePointer) AtomicAllocator.getInstance().getPointer(tmpX, context),
                (LongPointer) AtomicAllocator.getInstance().getPointer(tmpX.shapeInfoDataBuffer(), context),
                descending);
    } else if (x.data().dataType() == DataBuffer.Type.HALF) {
        nativeOps.sortHalf(extraz,
                (ShortPointer) AtomicAllocator.getInstance().getPointer(tmpX, context),
                (LongPointer) AtomicAllocator.getInstance().getPointer(tmpX.shapeInfoDataBuffer(), context),
                descending);
    } else {
        throw new UnsupportedOperationException("Unknown dataType " + x.data().dataType());
    }

    AtomicAllocator.getInstance().getFlowController().registerAction(context, x);

    return x;
}
Example 19
Source File: BaseComplexNDArray.java From nd4j with Apache License 2.0 | 4 votes |
/**
 * Perform a copy matrix multiplication
 *
 * @param other  the other matrix to perform matrix multiply with
 * @param result the result ndarray
 * @return the result of the matrix multiplication
 */
@Override
public IComplexNDArray mmuli(INDArray other, INDArray result) {
    IComplexNDArray otherArray = (IComplexNDArray) other;
    IComplexNDArray resultArray = (IComplexNDArray) result;

    if (other.shape().length > 2) {
        for (int i = 0; i < other.slices(); i++) {
            resultArray.putSlice(i, slice(i).mmul(otherArray.slice(i)));
        }
        return resultArray;
    }

    LinAlgExceptions.assertMultiplies(this, other);

    if (other.isScalar()) {
        return muli(otherArray.getComplex(0), resultArray);
    }
    if (isScalar()) {
        return otherArray.muli(getComplex(0), resultArray);
    }

    /* check sizes and resize if necessary */
    //assertMultipliesWith(other);

    if (result == this || result == other) {
        /* actually, blas cannot do multiplications in-place. Therefore, we will fake by
         * allocating a temporary object on the side and copy the result later.
         */
        IComplexNDArray temp = Nd4j.createComplex(resultArray.shape());

        if (otherArray.columns() == 1) {
            Nd4j.getBlasWrapper().level2().gemv(BlasBufferUtil.getCharForTranspose(temp),
                    BlasBufferUtil.getCharForTranspose(this), Nd4j.UNIT, this, otherArray, Nd4j.ZERO, temp);
        } else {
            Nd4j.getBlasWrapper().level3().gemm(BlasBufferUtil.getCharForTranspose(temp),
                    BlasBufferUtil.getCharForTranspose(this), BlasBufferUtil.getCharForTranspose(other),
                    Nd4j.UNIT, this, otherArray, Nd4j.ZERO, temp);
        }

        Nd4j.getBlasWrapper().copy(temp, resultArray);
    } else {
        if (otherArray.columns() == 1) {
            Nd4j.getBlasWrapper().level2().gemv(BlasBufferUtil.getCharForTranspose(resultArray),
                    BlasBufferUtil.getCharForTranspose(this), Nd4j.UNIT, this, otherArray, Nd4j.ZERO, resultArray);
        } else {
            Nd4j.getBlasWrapper().level3().gemm(BlasBufferUtil.getCharForTranspose(resultArray),
                    BlasBufferUtil.getCharForTranspose(this), BlasBufferUtil.getCharForTranspose(other),
                    Nd4j.UNIT, this, otherArray, Nd4j.ZERO, resultArray);
        }
    }
    return resultArray;
}
Example 20
Source File: MtcnnService.java From mtcnn-java with Apache License 2.0 | 4 votes |
/**
 * STAGE 2
 *
 * @param image
 * @param totalBoxes
 * @param padResult
 * @return
 * @throws IOException
 */
private INDArray refinementStage(INDArray image, INDArray totalBoxes, MtcnnUtil.PadResult padResult) throws IOException {

    // num_boxes = total_boxes.shape[0]
    int numBoxes = totalBoxes.isEmpty() ? 0 : (int) totalBoxes.shape()[0];

    // if num_boxes == 0:
    //     return total_boxes, stage_status
    if (numBoxes == 0) {
        return totalBoxes;
    }

    INDArray tempImg1 = computeTempImage(image, numBoxes, padResult, 24);

    //this.refineNetGraph.associateArrayWithVariable(tempImg1, this.refineNetGraph.variableMap().get("rnet/input"));
    //List<DifferentialFunction> refineNetResults = this.refineNetGraph.exec().getRight();
    //INDArray out0 = refineNetResults.stream().filter(df -> df.getOwnName().equalsIgnoreCase("rnet/fc2-2/fc2-2"))
    //        .findFirst().get().outputVariable().getArr();
    //INDArray out1 = refineNetResults.stream().filter(df -> df.getOwnName().equalsIgnoreCase("rnet/prob1"))
    //        .findFirst().get().outputVariable().getArr();

    Map<String, INDArray> resultMap = this.refineNetGraphRunner.run(Collections.singletonMap("rnet/input", tempImg1));

    //INDArray out0 = resultMap.get("rnet/fc2-2/fc2-2"); // for ipazc/mtcnn model
    INDArray out0 = resultMap.get("rnet/conv5-2/conv5-2");
    INDArray out1 = resultMap.get("rnet/prob1");

    // score = out1[1, :]
    INDArray score = out1.get(all(), point(1)).transposei();

    // ipass = np.where(score > self.__steps_threshold[1])
    INDArray ipass = MtcnnUtil.getIndexWhereVector(score.transpose(), s -> s > stepsThreshold[1]);
    //INDArray ipass = MtcnnUtil.getIndexWhereVector2(score.transpose(), Conditions.greaterThan(stepsThreshold[1]));

    if (ipass.isEmpty()) {
        totalBoxes = Nd4j.empty();
        return totalBoxes;
    }

    // total_boxes = np.hstack([total_boxes[ipass[0], 0:4].copy(), np.expand_dims(score[ipass].copy(), 1)])
    INDArray b1 = totalBoxes.get(new SpecifiedIndex(ipass.toLongVector()), interval(0, 4));
    INDArray b2 = ipass.isScalar() ? score.get(ipass).reshape(1, 1) : Nd4j.expandDims(score.get(ipass), 1);
    totalBoxes = Nd4j.hstack(b1, b2);

    // mv = out0[:, ipass[0]]
    INDArray mv = out0.get(new SpecifiedIndex(ipass.toLongVector()), all()).transposei();

    // if total_boxes.shape[0] > 0:
    if (!totalBoxes.isEmpty() && totalBoxes.shape()[0] > 0) {

        // pick = self.__nms(total_boxes, 0.7, 'Union')
        INDArray pick = MtcnnUtil.nonMaxSuppression(totalBoxes.dup(), 0.7, MtcnnUtil.NonMaxSuppressionType.Union).transpose();

        // total_boxes = total_boxes[pick, :]
        totalBoxes = totalBoxes.get(new SpecifiedIndex(pick.toLongVector()), all());

        // total_boxes = self.__bbreg(total_boxes.copy(), np.transpose(mv[:, pick]))
        totalBoxes = MtcnnUtil.bbreg(totalBoxes, mv.get(all(), new SpecifiedIndex(pick.toLongVector())).transpose());

        // total_boxes = self.__rerec(total_boxes.copy())
        totalBoxes = MtcnnUtil.rerec(totalBoxes, false);
    }

    return totalBoxes;
}