Java Code Examples for org.nd4j.linalg.api.ndarray.INDArray#getInt()
The following examples show how to use
org.nd4j.linalg.api.ndarray.INDArray#getInt().
You can vote up the ones you like or vote down the ones you don't like,
and go to the original project or source file by following the links above each example. You may check out the related API usage on the sidebar.
Example 1
Source File: ImageUtilities.java From Java-Machine-Learning-for-Computer-Vision with MIT License | 6 votes |
/** * Takes an INDArray containing an image loaded using the native image loader * libraries associated with DL4J, and converts it into a BufferedImage. * The INDArray contains the color values split up across three channels (RGB) * and in the integer range 0-255. * * @param array INDArray containing an image * @return BufferedImage */ public BufferedImage imageFromINDArray(INDArray array) { int[] shape = array.shape(); int height = shape[2]; int width = shape[3]; BufferedImage image = new BufferedImage(width, height, BufferedImage.TYPE_INT_RGB); for (int x = 0; x < width; x++) { for (int y = 0; y < height; y++) { int red = array.getInt(0, 2, y, x); int green = array.getInt(0, 1, y, x); int blue = array.getInt(0, 0, y, x); //handle out of bounds pixel values red = Math.min(red, 255); green = Math.min(green, 255); blue = Math.min(blue, 255); red = Math.max(red, 0); green = Math.max(green, 0); blue = Math.max(blue, 0); image.setRGB(x, y, new Color(red, green, blue).getRGB()); } } return image; }
Example 2
Source File: NeuralStyleTransfer.java From dl4j-tutorials with MIT License | 6 votes |
/** * Takes an INDArray containing an image loaded using the native image loader * libraries associated with DL4J, and converts it into a BufferedImage. * The INDArray contains the color values split up across three channels (RGB) * and in the integer range 0-255. * * @param array INDArray containing an image * @return BufferedImage */ private BufferedImage imageFromINDArray(INDArray array) { long[] shape = array.shape(); long height = shape[2]; long width = shape[3]; BufferedImage image = new BufferedImage((int)width, (int)height, BufferedImage.TYPE_INT_RGB); for (int x = 0; x < width; x++) { for (int y = 0; y < height; y++) { int red = array.getInt(0, 2, y, x); int green = array.getInt(0, 1, y, x); int blue = array.getInt(0, 0, y, x); //handle out of bounds pixel values red = Math.min(red, 255); green = Math.min(green, 255); blue = Math.min(blue, 255); red = Math.max(red, 0); green = Math.max(green, 0); blue = Math.max(blue, 0); image.setRGB(x, y, new Color(red, green, blue).getRGB()); } } return image; }
Example 3
Source File: UsingModelToPredict.java From dl4j-tutorials with MIT License | 6 votes |
/** * 将单通道的 INDArray 保存为灰度图 * * There's also NativeImageLoader.asMat(INDArray) and we can then use OpenCV to save it as an image file. * * @param array 输入 * @return 灰度图转化 */ private static BufferedImage imageFromINDArray(INDArray array) { long[] shape = array.shape(); int height = (int)shape[2]; int width = (int)shape[3]; BufferedImage image = new BufferedImage(width, height, BufferedImage.TYPE_BYTE_GRAY); for (int x = 0; x < width; x++) { for (int y = 0; y < height; y++) { int gray = array.getInt(0, 0, y, x); // handle out of bounds pixel values gray = Math.min(gray, 255); gray = Math.max(gray, 0); image.getRaster().setSample(x, y, 0, gray); } } return image; }
Example 4
Source File: QuadTree.java From deeplearning4j with Apache License 2.0 | 6 votes |
/** * * @param rowP a vector * @param colP * @param valP * @param N * @param posF */ public void computeEdgeForces(INDArray rowP, INDArray colP, INDArray valP, int N, INDArray posF) { if (!rowP.isVector()) throw new IllegalArgumentException("RowP must be a vector"); // Loop over all edges in the graph double D; for (int n = 0; n < N; n++) { for (int i = rowP.getInt(n); i < rowP.getInt(n + 1); i++) { // Compute pairwise distance and Q-value buf.assign(data.slice(n)).subi(data.slice(colP.getInt(i))); D = Nd4j.getBlasWrapper().dot(buf, buf); D = valP.getDouble(i) / D; // Sum positive force posF.slice(n).addi(buf.mul(D)); } } }
Example 5
Source File: NDArrayCreationUtil.java From deeplearning4j with Apache License 2.0 | 6 votes |
/** * Create an ndarray * of * @param seed * @param rank * @param numShapes * @return */ public static int[][] getRandomBroadCastShape(long seed, int rank, int numShapes) { Nd4j.getRandom().setSeed(seed); INDArray coinFlip = Nd4j.getDistributions().createBinomial(1, 0.5).sample(new int[] {numShapes, rank}); int[][] ret = new int[(int) coinFlip.rows()][(int) coinFlip.columns()]; for (int i = 0; i < coinFlip.rows(); i++) { for (int j = 0; j < coinFlip.columns(); j++) { int set = coinFlip.getInt(i, j); if (set > 0) ret[i][j] = set; else { //anything from 0 to 9 ret[i][j] = Nd4j.getRandom().nextInt(9) + 1; } } } return ret; }
Example 6
Source File: BaseLapack.java From deeplearning4j with Apache License 2.0 | 5 votes |
@Override public INDArray getPFactor(int M, INDArray ipiv) { // The simplest permutation is the identity matrix INDArray P = Nd4j.eye(M); // result is a square matrix with given size for (int i = 0; i < ipiv.length(); i++) { int pivot = ipiv.getInt(i) - 1; // Did we swap row #i with anything? if (pivot > i) { // don't reswap when we get lower down in the vector INDArray v1 = P.getColumn(i).dup(); // because of row vs col major order we'll ... INDArray v2 = P.getColumn(pivot); // ... make a transposed matrix immediately P.putColumn(i, v2); P.putColumn(pivot, v1); // note dup() above is required - getColumn() is a 'view' } } return P; // the permutation matrix - contains a single 1 in any row and column }
Example 7
Source File: DL4JSequenceRecommender.java From inception with Apache License 2.0 | 5 votes |
private <T extends Sample> List<Outcome<T>> predict(MultiLayerNetwork aClassifier, String[] aTagset, List<T> aData) throws IOException { if (aData.isEmpty()) { return Collections.emptyList(); } DataSet data = vectorize(aData); // Predict labels long predictionStart = System.currentTimeMillis(); INDArray predicted = aClassifier.output(data.getFeatures(), false, data.getFeaturesMaskArray(), data.getLabelsMaskArray()); log.trace("Prediction took {}ms", System.currentTimeMillis() - predictionStart); // This is a brute-force hack to ensue that argmax doesn't predict tags that are not // in the tagset. Actually, this should be necessary at all if the network is properly // configured... predicted = predicted.get(NDArrayIndex.all(), NDArrayIndex.interval(0, aTagset.length), NDArrayIndex.all()); List<Outcome<T>> outcomes = new ArrayList<>(); int sampleIdx = 0; for (Sample sample : aData) { INDArray argMax = Nd4j.argMax(predicted, 1); List<String> tokens = sample.getSentence(); String[] labels = new String[tokens.size()]; for (int tokenIdx = 0; tokenIdx < tokens.size(); tokenIdx ++) { labels[tokenIdx] = aTagset[argMax.getInt(sampleIdx, tokenIdx)]; } outcomes.add(new Outcome(sample, asList(labels))); sampleIdx ++; } return outcomes; }
Example 8
Source File: RandomProjectionLSH.java From deeplearning4j with Apache License 2.0 | 5 votes |
/**
 * Collects the rows of {@code indexData} selected by the bucket mask of {@code query}.
 *
 * <p>Row {@code i} of the result is the {@code i}-th row of {@code indexData} whose mask entry
 * equals 1. NOTE(review): assumes mask entries are 0/1 and that {@code mask.sum(0).getInt(0)}
 * yields the total number of set entries for the mask's shape. Confirm against {@code bucket()}.
 *
 * @param query query vector
 * @return an {@code nRes x inDimension} array of the selected rows
 */
INDArray bucketData(INDArray query){
    INDArray mask = bucket(query);
    int nRes = mask.sum(0).getInt(0);
    INDArray res = Nd4j.create(new int[] {nRes, inDimension});

    long last = mask.length() - 1;
    int j = 0;
    for (int i = 0; i < nRes; i++){
        // Skip unset entries. The bound is tested first so the mask is never read past its end.
        while (j < last && mask.getInt(j) == 0) {
            j += 1;
        }
        if (j <= last && mask.getInt(j) == 1)
            res.putRow(i, indexData.getRow(j));
        j += 1;
    }
    return res;
}
Example 9
Source File: BaseLapack.java From deeplearning4j with Apache License 2.0 | 5 votes |
/**
 * Computes the singular value decomposition of {@code A} via LAPACK {@code ?gesvd}
 * ({@code dgesvd} for DOUBLE data, {@code sgesvd} for FLOAT data).
 *
 * @param A  input matrix
 * @param S  output for the singular values
 * @param U  output for the left singular vectors, or {@code null} to skip computing them
 * @param VT output for the transposed right singular vectors, or {@code null} to skip them
 * @throws ND4JArraySizeException        if a dimension of A exceeds {@code Integer.MAX_VALUE}
 * @throws UnsupportedOperationException for data types other than DOUBLE and FLOAT
 * @throws Error                         if LAPACK reports an invalid argument
 */
@Override
public void gesvd(INDArray A, INDArray S, INDArray U, INDArray VT) {
    if (A.rows() > Integer.MAX_VALUE || A.columns() > Integer.MAX_VALUE)
        throw new ND4JArraySizeException();

    int m = (int) A.rows();
    int n = (int) A.columns();

    // Job flags: 'N' skips computing U / VT when the caller passed no output array.
    byte jobu = (byte) (U == null ? 'N' : 'A');
    byte jobvt = (byte) (VT == null ? 'N' : 'A');

    INDArray INFO = Nd4j.createArrayFromShapeBuffer(Nd4j.getDataBufferFactory().createInt(1),
                    Nd4j.getShapeInfoProvider().createShapeInformation(new long[] {1, 1}, DataType.INT).getFirst());

    if (A.data().dataType() == DataType.DOUBLE)
        dgesvd(jobu, jobvt, m, n, A, S, U, VT, INFO);
    else if (A.data().dataType() == DataType.FLOAT)
        sgesvd(jobu, jobvt, m, n, A, S, U, VT, INFO);
    else
        throw new UnsupportedOperationException();

    int info = INFO.getInt(0);
    if (info < 0) {
        // LAPACK reports an illegal i-th argument as INFO = -i.
        throw new Error("Parameter #" + (-info) + " to gesvd() was not valid");
    } else if (info > 0) {
        // LAPACK: INFO > 0 means the bidiagonal QR iteration did not converge.
        log.warn("gesvd() did not converge: " + info
                        + " superdiagonal(s) of the intermediate bidiagonal form did not converge to zero");
    }
}
Example 10
Source File: BaseLapack.java From nd4j with Apache License 2.0 | 5 votes |
@Override public INDArray getPFactor(int M, INDArray ipiv) { // The simplest permutation is the identity matrix INDArray P = Nd4j.eye(M); // result is a square matrix with given size for (int i = 0; i < ipiv.length(); i++) { int pivot = ipiv.getInt(i) - 1; // Did we swap row #i with anything? if (pivot > i) { // don't reswap when we get lower down in the vector INDArray v1 = P.getColumn(i).dup(); // because of row vs col major order we'll ... INDArray v2 = P.getColumn(pivot); // ... make a transposed matrix immediately P.putColumn(i, v2); P.putColumn(pivot, v1); // note dup() above is required - getColumn() is a 'view' } } return P; // the permutation matrix - contains a single 1 in any row and column }
Example 11
Source File: BaseLapack.java From nd4j with Apache License 2.0 | 5 votes |
@Override public void gesvd(INDArray A, INDArray S, INDArray U, INDArray VT) { // FIXME: int cast if (A.rows() > Integer.MAX_VALUE || A.columns() > Integer.MAX_VALUE) throw new ND4JArraySizeException(); int m = (int) A.rows(); int n = (int) A.columns(); byte jobu = (byte) (U == null ? 'N' : 'A'); byte jobvt = (byte) (VT == null ? 'N' : 'A'); INDArray INFO = Nd4j.createArrayFromShapeBuffer(Nd4j.getDataBufferFactory().createInt(1), Nd4j.getShapeInfoProvider().createShapeInformation(new int[] {1, 1}).getFirst()); if (A.data().dataType() == DataBuffer.Type.DOUBLE) dgesvd(jobu, jobvt, m, n, A, S, U, VT, INFO); else if (A.data().dataType() == DataBuffer.Type.FLOAT) sgesvd(jobu, jobvt, m, n, A, S, U, VT, INFO); else throw new UnsupportedOperationException(); if (INFO.getInt(0) < 0) { throw new Error("Parameter #" + INFO.getInt(0) + " to gesvd() was not valid"); } else if (INFO.getInt(0) > 0) { log.warn("The matrix contains singular elements. Check S matrix at row " + INFO.getInt(0)); } }
Example 12
Source File: TopK.java From deeplearning4j with Apache License 2.0 | 5 votes |
@Override public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map<String, AttrValue> attributesForNode, GraphDef graph) { String thisName = nodeDef.getName(); // FIXME: ???? String inputName = thisName + "/k"; NodeDef kNode = null; for(int i = 0; i < graph.getNodeCount(); i++) { if(graph.getNode(i).getName().equals(inputName)){ kNode = graph.getNode(i); break; } } this.sorted = nodeDef.getAttrOrThrow("sorted").getB(); if (kNode != null) { Preconditions.checkState(kNode != null, "Could not find 'k' parameter node for op: %s", thisName); INDArray arr = TFGraphMapper.getNDArrayFromTensor(kNode); this.k = arr.getInt(0); addIArgument(ArrayUtil.fromBoolean(sorted), k); } else addIArgument(ArrayUtil.fromBoolean(sorted)); }
Example 13
Source File: BaseLapack.java From nd4j with Apache License 2.0 | 5 votes |
@Override public INDArray getrf(INDArray A) { // FIXME: int cast if (A.rows() > Integer.MAX_VALUE || A.columns() > Integer.MAX_VALUE) throw new ND4JArraySizeException(); int m = (int) A.rows(); int n = (int) A.columns(); INDArray INFO = Nd4j.createArrayFromShapeBuffer(Nd4j.getDataBufferFactory().createInt(1), Nd4j.getShapeInfoProvider().createShapeInformation(new int[] {1, 1}).getFirst()); int mn = Math.min(m, n); INDArray IPIV = Nd4j.createArrayFromShapeBuffer(Nd4j.getDataBufferFactory().createInt(mn), Nd4j.getShapeInfoProvider().createShapeInformation(new int[] {1, mn}).getFirst()); if (A.data().dataType() == DataBuffer.Type.DOUBLE) dgetrf(m, n, A, IPIV, INFO); else if (A.data().dataType() == DataBuffer.Type.FLOAT) sgetrf(m, n, A, IPIV, INFO); else throw new UnsupportedOperationException(); if (INFO.getInt(0) < 0) { throw new Error("Parameter #" + INFO.getInt(0) + " to getrf() was not valid"); } else if (INFO.getInt(0) > 0) { log.warn("The matrix is singular - cannot be used for inverse op. Check L matrix at row " + INFO.getInt(0)); } return IPIV; }
Example 14
Source File: BaseLapack.java From deeplearning4j with Apache License 2.0 | 5 votes |
@Override public void potrf(INDArray A, boolean lower) { if (A.columns() > Integer.MAX_VALUE) throw new ND4JArraySizeException(); byte uplo = (byte) (lower ? 'L' : 'U'); // upper or lower part of the factor desired ? int n = (int) A.columns(); INDArray INFO = Nd4j.createArrayFromShapeBuffer(Nd4j.getDataBufferFactory().createInt(1), Nd4j.getShapeInfoProvider().createShapeInformation(new long[] {1, 1}, A.dataType()).getFirst()); if (A.data().dataType() == DataType.DOUBLE) dpotrf(uplo, n, A, INFO); else if (A.data().dataType() == DataType.FLOAT) spotrf(uplo, n, A, INFO); else throw new UnsupportedOperationException(); if (INFO.getInt(0) < 0) { throw new Error("Parameter #" + INFO.getInt(0) + " to potrf() was not valid"); } else if (INFO.getInt(0) > 0) { throw new Error("The matrix is not positive definite! (potrf fails @ order " + INFO.getInt(0) + ")"); } return; }
Example 15
Source File: BaseLapack.java From deeplearning4j with Apache License 2.0 | 5 votes |
/**
 * Computes the LU factorization of {@code A} in place via LAPACK {@code ?getrf}
 * ({@code dgetrf} for DOUBLE data, {@code sgetrf} for FLOAT data).
 *
 * @param A matrix to factorize; overwritten by the LAPACK routine
 * @return the pivot index vector IPIV ({@code 1 x min(m, n)}) filled by LAPACK
 * @throws ND4JArraySizeException        if a dimension of A exceeds {@code Integer.MAX_VALUE}
 * @throws UnsupportedOperationException for data types other than DOUBLE and FLOAT
 * @throws Error                         if LAPACK reports an invalid argument
 */
@Override
public INDArray getrf(INDArray A) {
    if (A.rows() > Integer.MAX_VALUE || A.columns() > Integer.MAX_VALUE)
        throw new ND4JArraySizeException();

    int m = (int) A.rows();
    int n = (int) A.columns();

    // INFO and IPIV are backed by int buffers, so describe them as INT rather than A's type.
    INDArray INFO = Nd4j.createArrayFromShapeBuffer(Nd4j.getDataBufferFactory().createInt(1),
                    Nd4j.getShapeInfoProvider().createShapeInformation(new long[] {1, 1}, DataType.INT).getFirst());

    int mn = Math.min(m, n);
    INDArray IPIV = Nd4j.createArrayFromShapeBuffer(Nd4j.getDataBufferFactory().createInt(mn),
                    Nd4j.getShapeInfoProvider().createShapeInformation(new long[] {1, mn}, DataType.INT).getFirst());

    if (A.data().dataType() == DataType.DOUBLE)
        dgetrf(m, n, A, IPIV, INFO);
    else if (A.data().dataType() == DataType.FLOAT)
        sgetrf(m, n, A, IPIV, INFO);
    else
        throw new UnsupportedOperationException();

    int info = INFO.getInt(0);
    if (info < 0) {
        // LAPACK reports an illegal i-th argument as INFO = -i.
        throw new Error("Parameter #" + (-info) + " to getrf() was not valid");
    } else if (info > 0) {
        // LAPACK: INFO = i > 0 means U(i,i) is exactly zero (the U factor is singular).
        log.warn("The matrix is singular - cannot be used for inverse op. Check U matrix at row " + info);
    }

    return IPIV;
}
Example 16
Source File: TestImageLoader.java From DataVec with Apache License 2.0 | 5 votes |
/**
 * Verifies that {@code ImageLoader.toINDArrayBGR} yields a {@code [channels, height, width]}
 * array whose planes 0, 1 and 2 hold the blue, green and red bytes of each source pixel.
 */
@Test
public void testToINDArrayBGR() throws Exception {
    BufferedImage source = makeRandomBufferedImage(false);
    int width = source.getWidth();
    int height = source.getHeight();
    int channels = 3;

    ImageLoader loader = new ImageLoader(0, 0, channels);
    INDArray bgr = loader.toINDArrayBGR(source);

    long[] dims = bgr.shape();
    assertEquals(3, dims.length);
    assertEquals(channels, dims[0]);
    assertEquals(height, dims[1]);
    assertEquals(width, dims[2]);

    // Reassemble an opaque ARGB value from the R (2), G (1), B (0) planes and compare.
    for (int row = 0; row < height; ++row) {
        for (int col = 0; col < width; ++col) {
            int expected = source.getRGB(col, row);
            int actual = (0xff << 24)
                    | (bgr.getInt(2, row, col) << 16)
                    | (bgr.getInt(1, row, col) << 8)
                    | (bgr.getInt(0, row, col) & 0xff);
            assertEquals(expected, actual);
        }
    }
}
Example 17
Source File: ValidateZooModelPredictions.java From deeplearning4j with Apache License 2.0 | 4 votes |
/**
 * Runs the imported MobileNet V1 (0.5, 128px) TF model on a stored golden retriever image and
 * checks that the top-1 prediction is "golden retriever".
 */
@Test
public void testMobilenetV1() throws Exception {
    if (TFGraphTestZooModels.isPPC()) {
        // Temporary workaround: skip on PPC CI because of known JVM crashes, tracked in
        // https://github.com/deeplearning4j/deeplearning4j/issues/7657. Re-enable once fixed;
        // the remaining tests keep guarding against regressions meanwhile.
        log.warn("TEMPORARILY SKIPPING TEST ON PPC ARCHITECTURE DUE TO KNOWN JVM CRASH ISSUES - SEE https://github.com/deeplearning4j/deeplearning4j/issues/7657");
        OpValidationSuite.ignoreFailing();
    }
    TFGraphTestZooModels.currentTestDir = testDir.newFolder();

    // Load the model
    String modelPath = "tf_graphs/zoo_models/mobilenet_v1_0.5_128/tf_model.txt";
    File modelFile = new ClassPathResource(modelPath).getFile();
    SameDiff sd = TFGraphTestZooModels.LOADER.apply(modelFile, "mobilenet_v1_0.5_128");

    // DataVec's NativeImageLoader is unavailable in ND4J tests (circular dependency), so a
    // previously serialized NCHW image tensor is loaded instead.
    File imgFile = new ClassPathResource("deeplearning4j-zoo/goldenretriever_rgb128_unnormalized_nchw_INDArray.bin").getFile();
    INDArray img = Nd4j.readBinary(imgFile)
            .castTo(DataType.FLOAT)
            .permute(0, 2, 3, 1) // NCHW -> NHWC
            .dup();

    // Pixels start in 0..255. MobileNet V1 probably uses inception-style preprocessing,
    // i.e. rescaling to the (-1, 1) range.
    img.divi(255).subi(0.5).muli(2);

    double min = img.minNumber().doubleValue();
    double max = img.maxNumber().doubleValue();
    assertTrue(min >= -1 && min <= -0.6);
    assertTrue(max <= 1 && max >= 0.6);

    // Perform inference
    List<String> inputs = sd.inputs();
    assertEquals(1, inputs.size());

    String out = "MobilenetV1/Predictions/Softmax";
    Map<String, INDArray> results = sd.output(Collections.singletonMap(inputs.get(0), img), out);
    INDArray outArr = results.get(out);

    System.out.println("SHAPE: " + Arrays.toString(outArr.shape()));
    System.out.println(outArr);

    INDArray argmax = outArr.argMax(1);

    // Translate the top-1 index into its label
    List<String> labels = labels();
    int classIdx = argmax.getInt(0);
    String className = labels.get(classIdx);
    String expClass = "golden retriever";
    double prob = outArr.getDouble(classIdx);

    System.out.println("Predicted class: \"" + className + "\" - probability = " + prob);
    assertEquals(expClass, className);
}
Example 18
Source File: JcublasLapack.java From nd4j with Apache License 2.0 | 4 votes |
public int dsyev( char _jobz, char _uplo, int N, INDArray A, INDArray R ) { int status = -1 ; int jobz = _jobz == 'V' ? CUSOLVER_EIG_MODE_VECTOR : CUSOLVER_EIG_MODE_NOVECTOR ; int uplo = _uplo == 'L' ? CUBLAS_FILL_MODE_LOWER : CUBLAS_FILL_MODE_UPPER ; if (Nd4j.dataType() != DataBuffer.Type.DOUBLE) log.warn("DOUBLE dsyev called in FLOAT environment"); INDArray a = A; if (A.ordering() == 'c') a = A.dup('f'); // FIXME: int cast int M = (int) A.rows() ; if (Nd4j.getExecutioner() instanceof GridExecutioner) ((GridExecutioner) Nd4j.getExecutioner()).flushQueue(); // Get context for current thread CudaContext ctx = (CudaContext) allocator.getDeviceContext().getContext(); // setup the solver handles for cuSolver calls cusolverDnHandle_t handle = ctx.getSolverHandle(); cusolverDnContext solverDn = new cusolverDnContext(handle); // synchronized on the solver synchronized (handle) { status = cusolverDnSetStream(new cusolverDnContext(handle), new CUstream_st(ctx.getOldStream())); if( status == 0 ) { // transfer the INDArray into GPU memory CublasPointer xAPointer = new CublasPointer(a, ctx); CublasPointer xRPointer = new CublasPointer(R, ctx); // this output - indicates how much memory we'll need for the real operation DataBuffer worksizeBuffer = Nd4j.getDataBufferFactory().createInt(1); status = cusolverDnDsyevd_bufferSize( solverDn, jobz, uplo, M, (DoublePointer) xAPointer.getDevicePointer(), M, (DoublePointer) xRPointer.getDevicePointer(), (IntPointer)worksizeBuffer.addressPointer() ) ; if (status == CUSOLVER_STATUS_SUCCESS) { int worksize = worksizeBuffer.getInt(0); // allocate memory for the workspace, the non-converging row buffer and a return code Pointer workspace = new Workspace(worksize * Nd4j.sizeOfDataType()); INDArray INFO = Nd4j.createArrayFromShapeBuffer(Nd4j.getDataBufferFactory().createInt(1), Nd4j.getShapeInfoProvider().createShapeInformation(new int[] {1, 1})); // Do the actual decomp status = cusolverDnDsyevd(solverDn, jobz, uplo, M, (DoublePointer) 
xAPointer.getDevicePointer(), M, (DoublePointer) xRPointer.getDevicePointer(), new CudaPointer(workspace).asDoublePointer(), worksize, new CudaPointer(allocator.getPointer(INFO, ctx)).asIntPointer()); allocator.registerAction(ctx, INFO); if( status == 0 ) status = INFO.getInt(0) ; } } } if( status == 0 ) { allocator.registerAction(ctx, R); allocator.registerAction(ctx, a); if (a != A) A.assign(a); } return status ; }
Example 19
Source File: ValidateZooModelPredictions.java From deeplearning4j with Apache License 2.0 | 4 votes |
/**
 * Runs the imported ResNet V2 frozen TF graph on a stored golden retriever image and checks
 * that the top-1 prediction is "golden retriever".
 */
@Test
public void testResnetV2() throws Exception {
    if (TFGraphTestZooModels.isPPC()) {
        // Temporary workaround: skip on PPC CI because of known JVM crashes, tracked in
        // https://github.com/deeplearning4j/deeplearning4j/issues/7657. Re-enable once fixed;
        // the remaining tests keep guarding against regressions meanwhile.
        log.warn("TEMPORARILY SKIPPING TEST ON PPC ARCHITECTURE DUE TO KNOWN JVM CRASH ISSUES - SEE https://github.com/deeplearning4j/deeplearning4j/issues/7657");
        OpValidationSuite.ignoreFailing();
    }
    TFGraphTestZooModels.currentTestDir = testDir.newFolder();

    // Load the model
    String modelPath = "tf_graphs/zoo_models/resnetv2_imagenet_frozen_graph/tf_model.txt";
    File modelFile = new ClassPathResource(modelPath).getFile();
    SameDiff sd = TFGraphTestZooModels.LOADER.apply(modelFile, "resnetv2_imagenet_frozen_graph");

    // DataVec's NativeImageLoader is unavailable in ND4J tests (circular dependency), so a
    // previously serialized NCHW image tensor is loaded instead.
    File imgFile = new ClassPathResource("deeplearning4j-zoo/goldenretriever_rgb224_unnormalized_nchw_INDArray.bin").getFile();
    INDArray img = Nd4j.readBinary(imgFile)
            .castTo(DataType.FLOAT)
            .permute(0, 2, 3, 1) // NCHW -> NHWC
            .dup();

    // ResNet V2 applies no external normalization, just resize and center crop. See:
    // https://github.com/tensorflow/models/blob/d32d957a02f5cffb745a4da0d78f8432e2c52fd4/research/tensorrt/tensorrt.py#L70
    // https://github.com/tensorflow/models/blob/1af55e018eebce03fb61bba9959a04672536107d/official/resnet/imagenet_preprocessing.py#L253-L256

    // Perform inference
    List<String> inputs = sd.inputs();
    assertEquals(1, inputs.size());

    String out = "softmax_tensor";
    Map<String, INDArray> results = sd.output(Collections.singletonMap(inputs.get(0), img), out);
    INDArray outArr = results.get(out);

    System.out.println("SHAPE: " + Arrays.toString(outArr.shape()));
    System.out.println(outArr);

    INDArray argmax = outArr.argMax(1);

    // Translate the top-1 index into its label
    List<String> labels = labels();
    int classIdx = argmax.getInt(0);
    String className = labels.get(classIdx);
    String expClass = "golden retriever";
    double prob = outArr.getDouble(classIdx);

    System.out.println("Predicted class: " + classIdx + " - \"" + className + "\" - probability = " + prob);
    assertEquals(expClass, className);
}
Example 20
Source File: Eye.java From deeplearning4j with Apache License 2.0 | 4 votes |
/**
 * Creates an Eye op whose row count is read from a scalar array.
 *
 * @param rows scalar INDArray holding the number of rows
 * @throws IllegalArgumentException if {@code rows} is not a scalar
 */
public Eye(@NonNull INDArray rows){
    // this(...) must be the first statement, so validation happens inside the argument
    // expression. That way the scalar check runs before any value is read from the array.
    this(scalarRows(rows));
}

/** Checks that {@code rows} is a scalar and returns its value as an int. */
private static int scalarRows(INDArray rows) {
    Preconditions.checkArgument(rows.isScalar(), "Rows INDArray must be a scalar");
    return rows.getInt(0);
}