Java Code Examples for org.nd4j.linalg.api.buffer.DataType#INT
The following examples show how to use
org.nd4j.linalg.api.buffer.DataType#INT.
You can vote up the examples you like or vote down the ones you don't,
and follow the links above each example to the original project or source file. Related API usage is listed in the sidebar.
Example 1
Source File: TestNativeImageLoader.java From deeplearning4j with Apache License 2.0 | 6 votes |
/**
 * Checks that {@code NativeImageLoader.asMatrix(File)} returns an array whose data type equals the
 * global data type set through {@code Nd4j.setDataType(...)}, for each of several types.
 * The previous global data type is restored in a {@code finally} block, so a failing
 * iteration cannot leak its data type into other tests.
 */
@Test
public void testDataTypes_2() throws Exception {
    val dtypes = new DataType[]{DataType.FLOAT, DataType.HALF, DataType.SHORT, DataType.INT};
    val dt = Nd4j.dataType();
    try {
        for (val dtype : dtypes) {
            Nd4j.setDataType(dtype);
            int w3 = 123, h3 = 77;
            val loader = new NativeImageLoader(h3, w3, 1);
            File f3 = new ClassPathResource("datavec-data-image/testimages/class0/2.jpg").getFile();
            val array = loader.asMatrix(f3);
            assertEquals(dtype, array.dataType());
        }
    } finally {
        // Always restore the global default, even when an assertion above fails.
        Nd4j.setDataType(dt);
    }
}
Example 2
Source File: SpecialTests.java From deeplearning4j with Apache License 2.0 | 6 votes |
@Test public void reproduceWorkspaceCrash_3(){ val conf = WorkspaceConfiguration.builder().build(); val ws = Nd4j.getWorkspaceManager().getWorkspaceForCurrentThread(conf, "WS"); val dtypes = new DataType[]{DataType.DOUBLE, DataType.FLOAT, DataType.HALF, DataType.LONG, DataType.INT, DataType.SHORT, DataType.BYTE, DataType.UBYTE, DataType.BOOL}; for (val dX : dtypes) { for (val dZ: dtypes) { try(val ws2 = ws.notifyScopeEntered()) { val array = Nd4j.create(dX, 2, 5).assign(1); // log.info("Trying to cast {} to {}", dX, dZ); val casted = array.castTo(dZ); val exp = Nd4j.create(dZ, 2, 5).assign(1); assertEquals(exp, casted); Nd4j.getExecutioner().commit(); } } } }
Example 3
Source File: MixedDataTypesTests.java From deeplearning4j with Apache License 2.0 | 6 votes |
@Test public void testSimple(){ Nd4j.create(1); for(DataType dt : new DataType[]{DataType.DOUBLE, DataType.FLOAT, DataType.HALF, DataType.INT, DataType.LONG}) { // System.out.println("----- " + dt + " -----"); INDArray arr = Nd4j.ones(dt,1, 5); // System.out.println("Ones: " + arr); arr.assign(1.0); // System.out.println("assign(1.0): " + arr); // System.out.println("DIV: " + arr.div(8)); // System.out.println("MUL: " + arr.mul(8)); // System.out.println("SUB: " + arr.sub(8)); // System.out.println("ADD: " + arr.add(8)); // System.out.println("RDIV: " + arr.rdiv(8)); // System.out.println("RSUB: " + arr.rsub(8)); arr.div(8); arr.mul(8); arr.sub(8); arr.add(8); arr.rdiv(8); arr.rsub(8); } }
Example 4
Source File: TADTests.java From deeplearning4j with Apache License 2.0 | 6 votes |
/**
 * Compares the rank, shape and stride stored in two shape-info buffers.
 *
 * @param shapeA first shape buffer; must have data type {@link DataType#INT}
 * @param shapeB second shape buffer
 * @return {@code true} if both buffers have the same data type, the same rank at index 0, and
 *         identical values at indices {@code 1 .. 2 * rank}; {@code false} otherwise
 * @throws IllegalStateException if {@code shapeA} is not an INT buffer
 */
protected boolean compareShapes(@NonNull DataBuffer shapeA, @NonNull DataBuffer shapeB) {
    if (shapeA.dataType() != DataType.INT) {
        throw new IllegalStateException("ShapeBuffer should have dataType of INT");
    }
    if (shapeA.dataType() != shapeB.dataType()) {
        return false;
    }

    final int rank = shapeA.getInt(0);
    if (rank != shapeB.getInt(0)) {
        return false;
    }

    // Presumably entries 1..rank hold the shape and rank+1..2*rank the stride.
    final int lastIndex = 2 * rank;
    int idx = 1;
    while (idx <= lastIndex) {
        if (shapeA.getInt(idx) != shapeB.getInt(idx)) {
            return false;
        }
        idx++;
    }
    return true;
}
Example 5
Source File: ToStringTest.java From deeplearning4j with Apache License 2.0 | 6 votes |
/**
 * Checks {@code toString()} of a value-1 array for several data types at ranks 0 to 4.
 * The shape is built with {@code ArrayUtil.nTimes(rank, 1L)}. The expected output is the type's
 * scalar representation wrapped in one pair of brackets per dimension.
 */
@Test
public void testToStringScalars() {
    DataType[] dataTypes = new DataType[]{DataType.FLOAT, DataType.DOUBLE, DataType.BOOL, DataType.INT, DataType.UINT32};
    String[] strs = new String[]{"1.0000", "1.0000", "true", "1", "1"};
    // The two arrays are parallel; fail fast if one is extended without the other.
    assertEquals(dataTypes.length, strs.length);

    final int numRanks = 5;
    for (int dt = 0; dt < dataTypes.length; dt++) {
        for (int i = 0; i < numRanks; i++) {
            long[] shape = ArrayUtil.nTimes(i, 1L);
            INDArray scalar = Nd4j.scalar(1.0f).castTo(dataTypes[dt]).reshape(shape);
            String str = scalar.toString();

            StringBuilder sb = new StringBuilder();
            for (int j = 0; j < i; j++) {
                sb.append("[");
            }
            sb.append(strs[dt]);
            for (int j = 0; j < i; j++) {
                sb.append("]");
            }
            String exp = sb.toString();

            assertEquals("Rank: " + i + ", DT: " + dataTypes[dt], exp, str);
        }
    }
}
Example 6
Source File: IndexedTail.java From deeplearning4j with Apache License 2.0 | 6 votes |
/**
 * Applies an update to {@code target}.
 * <p>
 * If {@code encoded} is compressed, or its buffer has data type INT, it is treated as an encoded
 * message. The encoding id is read from buffer index 3, and the matching executioner decode op
 * (threshold or bitmap) writes into {@code target}. Otherwise {@code encoded} is added to
 * {@code target} in place.
 *
 * @param encoded the encoded or dense update
 * @param target  array receiving the update; also the return value
 * @return {@code target}
 * @throws ND4JIllegalStateException if the encoding id is neither flexible nor bitmap
 */
protected INDArray smartDecompress(INDArray encoded, @NonNull INDArray target) {
    final boolean isEncodedMessage = encoded.isCompressed() || encoded.data().dataType() == DataType.INT;
    if (!isEncodedMessage) {
        target.addi(encoded);
        return target;
    }

    final int encodingMode = encoded.data().getInt(3);
    if (encodingMode == ThresholdCompression.FLEXIBLE_ENCODING) {
        Nd4j.getExecutioner().thresholdDecode(encoded, target);
    } else if (encodingMode == ThresholdCompression.BITMAP_ENCODING) {
        Nd4j.getExecutioner().bitmapDecode(encoded, target);
    } else {
        throw new ND4JIllegalStateException("Unknown encoding mode: [" + encodingMode + "]");
    }
    return target;
}
Example 7
Source File: ArrayOptionsHelper.java From deeplearning4j with Apache License 2.0 | 6 votes |
public static DataType dataType(long opt) { if (hasBitSet(opt, DTYPE_COMPRESSED_BIT)) return DataType.COMPRESSED; else if (hasBitSet(opt, DTYPE_HALF_BIT)) return DataType.HALF; else if (hasBitSet(opt, DTYPE_BFLOAT16_BIT)) return DataType.BFLOAT16; else if (hasBitSet(opt, DTYPE_FLOAT_BIT)) return DataType.FLOAT; else if (hasBitSet(opt, DTYPE_DOUBLE_BIT)) return DataType.DOUBLE; else if (hasBitSet(opt, DTYPE_INT_BIT)) return hasBitSet(opt, DTYPE_UNSIGNED_BIT) ? DataType.UINT32 : DataType.INT; else if (hasBitSet(opt, DTYPE_LONG_BIT)) return hasBitSet(opt, DTYPE_UNSIGNED_BIT) ? DataType.UINT64 : DataType.LONG; else if (hasBitSet(opt, DTYPE_BOOL_BIT)) return DataType.BOOL; else if (hasBitSet(opt, DTYPE_BYTE_BIT)) { return hasBitSet(opt, DTYPE_UNSIGNED_BIT) ? DataType.UBYTE : DataType.BYTE; //Byte bit set for both UBYTE and BYTE } else if (hasBitSet(opt, DTYPE_SHORT_BIT)) return hasBitSet(opt, DTYPE_UNSIGNED_BIT) ? DataType.UINT16 : DataType.SHORT; else if (hasBitSet(opt, DTYPE_UTF8_BIT)) return DataType.UTF8; else throw new ND4JUnknownDataTypeException("Unknown extras set: [" + opt + "]"); }
Example 8
Source File: SmartFancyBlockingQueue.java From deeplearning4j with Apache License 2.0 | 6 votes |
/**
 * Applies an update to {@code target}, or to a freshly allocated array when {@code target} is null.
 * <p>
 * A fresh array is created from {@code paramsShape}/{@code paramsOrder} before anything else.
 * If {@code encoded} is compressed, or its buffer has data type INT, it is treated as an encoded
 * message. The encoding id is read from buffer index 3, and the matching executioner decode op
 * (threshold or bitmap) writes into the result. Otherwise {@code encoded} is added to the result
 * in place.
 *
 * @param encoded the encoded or dense update
 * @param target  array receiving the update, or {@code null} to allocate one
 * @return the array that received the update
 * @throws ND4JIllegalStateException if the encoding id is neither flexible nor bitmap
 */
protected INDArray smartDecompress(INDArray encoded, INDArray target) {
    final INDArray destination;
    if (target != null) {
        destination = target;
    } else {
        destination = Nd4j.create(paramsShape, paramsOrder);
    }

    final boolean isEncodedMessage = encoded.isCompressed() || encoded.data().dataType() == DataType.INT;
    if (!isEncodedMessage) {
        destination.addi(encoded);
        return destination;
    }

    final int encodingMode = encoded.data().getInt(3);
    if (encodingMode == ThresholdCompression.FLEXIBLE_ENCODING) {
        Nd4j.getExecutioner().thresholdDecode(encoded, destination);
    } else if (encodingMode == ThresholdCompression.BITMAP_ENCODING) {
        Nd4j.getExecutioner().bitmapDecode(encoded, destination);
    } else {
        throw new ND4JIllegalStateException("Unknown encoding mode: [" + encodingMode + "]");
    }
    return destination;
}
Example 9
Source File: TensorflowConversion.java From deeplearning4j with Apache License 2.0 | 6 votes |
/**
 * Maps a TensorFlow C-API data type code to the corresponding nd4j {@link DataType}.
 *
 * @param tensorflowType TensorFlow {@code DT_*} type code
 * @return matching nd4j data type
 * @throws IllegalArgumentException if the code is not one of the handled types
 */
private DataType typeFor(int tensorflowType) {
    switch (tensorflowType) {
        // Floating point
        case DT_HALF:
            return DataType.HALF;
        case DT_BFLOAT16:
            return DataType.BFLOAT16;
        case DT_FLOAT:
            return DataType.FLOAT;
        case DT_DOUBLE:
            return DataType.DOUBLE;
        // Signed integers
        case DT_INT8:
            return DataType.BYTE;
        case DT_INT16:
            return DataType.SHORT;
        case DT_INT32:
            return DataType.INT;
        case DT_INT64:
            return DataType.LONG;
        // Unsigned integers
        case DT_UINT8:
            return DataType.UBYTE;
        case DT_UINT16:
            return DataType.UINT16;
        case DT_UINT32:
            return DataType.UINT32;
        case DT_UINT64:
            return DataType.UINT64;
        // Other
        case DT_BOOL:
            return DataType.BOOL;
        case DT_STRING:
            return DataType.UTF8;
        default:
            throw new IllegalArgumentException("Illegal type " + tensorflowType);
    }
}
Example 10
Source File: ArrowSerde.java From deeplearning4j with Apache License 2.0 | 6 votes |
/**
 * Creates the nd4j data buffer type from the given Arrow tensor type tag, as defined in
 * {@link Type}.
 *
 * @param type        the Arrow type tag to create the nd4j {@link DataType} from
 * @param elementSize the element size (4 or 8); only consulted for {@code Type.Int}
 * @return the data buffer type
 * @throws IllegalArgumentException if the type tag is not FloatingPoint, Decimal or Int, or if an
 *                                  Int type has an element size other than 4 or 8
 */
public static DataType typeFromTensorType(byte type, int elementSize) {
    if (type == Type.FloatingPoint) {
        return DataType.FLOAT;
    } else if (type == Type.Decimal) {
        // NOTE(review): Arrow Decimal is a fixed-point type; mapping it to DOUBLE is lossy — confirm intent.
        return DataType.DOUBLE;
    } else if (type == Type.Int) {
        if (elementSize == 4) {
            return DataType.INT;
        } else if (elementSize == 8) {
            return DataType.LONG;
        }
        throw new IllegalArgumentException("Unable to determine data type for Type.Int with element size " + elementSize);
    }
    throw new IllegalArgumentException("Only valid types are Type.FloatingPoint, Type.Decimal and Type.Int");
}
Example 11
Source File: PythonUtils.java From deeplearning4j with Apache License 2.0 | 5 votes |
/**
 * Builds a {@link NumpyArray} from the map representation of a numpy array.
 * <p>
 * Keys read from {@code map}:
 * <ul>
 *   <li>{@code "dtype"}: numpy dtype name; one of float64, float32, int16, int32, int64</li>
 *   <li>{@code "shape"}: list of integral dimension sizes</li>
 *   <li>{@code "strides"} (or {@code "stride"}): list of integral strides, passed to
 *       {@link NumpyArray} unchanged</li>
 *   <li>{@code "address"}: native address of the data</li>
 * </ul>
 * The array is constructed with {@code copy = true}.
 *
 * @param map map describing the array
 * @return the new {@link NumpyArray}
 * @throws RuntimeException if the dtype name is missing or unsupported
 */
public static NumpyArray mapToNumpyArray(Map map) {
    String dtypeName = (String) map.get("dtype");
    DataType dtype;
    int elementSize; // bytes per element; only used to synthesise strides when none are supplied
    // Constant-first equals so a missing dtype yields the "Unsupported" error instead of an NPE.
    if ("float64".equals(dtypeName)) {
        dtype = DataType.DOUBLE;
        elementSize = 8;
    } else if ("float32".equals(dtypeName)) {
        dtype = DataType.FLOAT;
        elementSize = 4;
    } else if ("int16".equals(dtypeName)) {
        dtype = DataType.SHORT;
        elementSize = 2;
    } else if ("int32".equals(dtypeName)) {
        dtype = DataType.INT;
        elementSize = 4;
    } else if ("int64".equals(dtypeName)) {
        dtype = DataType.LONG;
        elementSize = 8;
    } else {
        throw new RuntimeException("Unsupported array type " + dtypeName + ".");
    }

    long[] shape = listToLongArray((List<?>) map.get("shape"));

    // Strides must come from their own key; reading them from "shape" (as before) mis-describes
    // every array's memory layout. numpy names this attribute "strides".
    Object strideValue = map.containsKey("strides") ? map.get("strides") : map.get("stride");
    // NOTE(review): the fallback assumes NumpyArray expects numpy-style byte strides and that the
    // array is C-contiguous when no strides are supplied; confirm against the Python serializer.
    long[] stride = strideValue != null
            ? listToLongArray((List<?>) strideValue)
            : cOrderByteStrides(shape, elementSize);

    long address = ((Number) map.get("address")).longValue();
    return new NumpyArray(address, shape, stride, dtype, true);
}

/** Converts a list of integral {@link Number}s (Long, Integer, ...) to a {@code long[]}. */
private static long[] listToLongArray(List<?> values) {
    long[] result = new long[values.size()];
    for (int i = 0; i < result.length; i++) {
        result[i] = ((Number) values.get(i)).longValue();
    }
    return result;
}

/** Computes byte strides of a contiguous row-major (C-order) array with the given shape. */
private static long[] cOrderByteStrides(long[] shape, int elementSize) {
    long[] strides = new long[shape.length];
    long step = elementSize;
    for (int i = shape.length - 1; i >= 0; i--) {
        strides[i] = step;
        step *= shape[i];
    }
    return strides;
}
Example 12
Source File: SpecialTests.java From deeplearning4j with Apache License 2.0 | 5 votes |
@Test public void reproduceWorkspaceCrash_2(){ val dtypes = new DataType[]{DataType.DOUBLE, DataType.FLOAT, DataType.HALF, DataType.LONG, DataType.INT, DataType.SHORT, DataType.BYTE, DataType.UBYTE, DataType.BOOL}; for (val dX : dtypes) { for (val dZ: dtypes) { val array = Nd4j.create(dX, 2, 5).assign(1); // log.info("Trying to cast {} to {}", dX, dZ); val casted = array.castTo(dZ); val exp = Nd4j.create(dZ, 2, 5).assign(1); assertEquals(exp, casted); } } }
Example 13
Source File: ArrayOptionsHelper.java From deeplearning4j with Apache License 2.0 | 5 votes |
/**
 * Maps a TensorFlow protobuf {@code DataType} to the corresponding nd4j {@link DataType}.
 *
 * @param dataType TensorFlow data type
 * @return matching nd4j data type
 * @throws UnsupportedOperationException for TensorFlow types not handled here
 */
public static DataType convertToDataType(org.tensorflow.framework.DataType dataType) {
    switch (dataType) {
        // Floating point
        case DT_HALF:
            return DataType.HALF;
        case DT_BFLOAT16:
            return DataType.BFLOAT16;
        case DT_FLOAT:
            return DataType.FLOAT;
        case DT_DOUBLE:
            return DataType.DOUBLE;
        // Signed integers
        case DT_INT8:
            return DataType.BYTE;
        case DT_INT16:
            return DataType.SHORT;
        case DT_INT32:
            return DataType.INT;
        case DT_INT64:
            return DataType.LONG;
        // Unsigned integers
        case DT_UINT8:
            return DataType.UBYTE;
        case DT_UINT16:
            return DataType.UINT16;
        case DT_UINT32:
            return DataType.UINT32;
        case DT_UINT64:
            return DataType.UINT64;
        // Other
        case DT_BOOL:
            return DataType.BOOL;
        case DT_STRING:
            return DataType.UTF8;
        default:
            throw new UnsupportedOperationException("Unknown TF data type: [" + dataType.name() + "]");
    }
}
Example 14
Source File: IntBuffer.java From deeplearning4j with Apache License 2.0 | 4 votes |
/**
 * Sets this buffer's data type to {@link DataType#INT} and its element size to 4.
 */
@Override
protected void initTypeAndSize() {
    type = DataType.INT;
    elementSize = 4;
}
Example 15
Source File: EvaluationCalibrationTest.java From deeplearning4j with Apache License 2.0 | 4 votes |
@Test public void testReliabilityDiagram() { DataType dtypeBefore = Nd4j.defaultFloatingPointType(); EvaluationCalibration first = null; String sFirst = null; try { for (DataType globalDtype : new DataType[]{DataType.DOUBLE, DataType.FLOAT, DataType.HALF, DataType.INT}) { Nd4j.setDefaultDataTypes(globalDtype, globalDtype.isFPType() ? globalDtype : DataType.DOUBLE); for (DataType lpDtype : new DataType[]{DataType.DOUBLE, DataType.FLOAT, DataType.HALF}) { //Test using 5 bins - format: binary softmax-style output //Note: no values fall in fourth bin //[0, 0.2) INDArray bin0Probs = Nd4j.create(new double[][]{{1.0, 0.0}, {0.9, 0.1}, {0.85, 0.15}}).castTo(lpDtype); INDArray bin0Labels = Nd4j.create(new double[][]{{1.0, 0.0}, {1.0, 0.0}, {0.0, 1.0}}).castTo(lpDtype); //[0.2, 0.4) INDArray bin1Probs = Nd4j.create(new double[][]{{0.80, 0.20}, {0.7, 0.3}, {0.65, 0.35}}).castTo(lpDtype); INDArray bin1Labels = Nd4j.create(new double[][]{{1.0, 0.0}, {0.0, 1.0}, {1.0, 0.0}}).castTo(lpDtype); //[0.4, 0.6) INDArray bin2Probs = Nd4j.create(new double[][]{{0.59, 0.41}, {0.5, 0.5}, {0.45, 0.55}}).castTo(lpDtype); INDArray bin2Labels = Nd4j.create(new double[][]{{1.0, 0.0}, {0.0, 1.0}, {0.0, 1.0}}).castTo(lpDtype); //[0.6, 0.8) //Empty //[0.8, 1.0] INDArray bin4Probs = Nd4j.create(new double[][]{{0.0, 1.0}, {0.1, 0.9}}).castTo(lpDtype); INDArray bin4Labels = Nd4j.create(new double[][]{{0.0, 1.0}, {0.0, 1.0}}).castTo(lpDtype); INDArray probs = Nd4j.vstack(bin0Probs, bin1Probs, bin2Probs, bin4Probs); INDArray labels = Nd4j.vstack(bin0Labels, bin1Labels, bin2Labels, bin4Labels); EvaluationCalibration ec = new EvaluationCalibration(5, 5); ec.eval(labels, probs); for (int i = 0; i < 1; i++) { double[] avgBinProbsClass; double[] fracPos; if (i == 0) { //Class 0: needs to be handled a little differently, due to threshold/edge cases (0.8, etc) avgBinProbsClass = new double[]{0.05, (0.59 + 0.5 + 0.45) / 3, (0.65 + 0.7) / 2.0, (0.8 + 0.85 + 0.9 + 1.0) / 4}; fracPos = new double[]{0.0 / 2.0, 1.0 
/ 3, 1.0 / 2, 3.0 / 4}; } else { avgBinProbsClass = new double[]{bin0Probs.getColumn(i).meanNumber().doubleValue(), bin1Probs.getColumn(i).meanNumber().doubleValue(), bin2Probs.getColumn(i).meanNumber().doubleValue(), bin4Probs.getColumn(i).meanNumber().doubleValue()}; fracPos = new double[]{bin0Labels.getColumn(i).sumNumber().doubleValue() / bin0Labels.size(0), bin1Labels.getColumn(i).sumNumber().doubleValue() / bin1Labels.size(0), bin2Labels.getColumn(i).sumNumber().doubleValue() / bin2Labels.size(0), bin4Labels.getColumn(i).sumNumber().doubleValue() / bin4Labels.size(0)}; } org.nd4j.evaluation.curves.ReliabilityDiagram rd = ec.getReliabilityDiagram(i); double[] x = rd.getMeanPredictedValueX(); double[] y = rd.getFractionPositivesY(); assertArrayEquals(avgBinProbsClass, x, 1e-3); assertArrayEquals(fracPos, y, 1e-3); String s = ec.stats(); if(first == null) { first = ec; sFirst = s; } else { // assertEquals(first, ec); assertEquals(sFirst, s); assertTrue(first.getRDiagBinPosCount().equalsWithEps(ec.getRDiagBinPosCount(), lpDtype == DataType.HALF ? 1e-3 : 1e-5)); //Lower precision due to fload assertTrue(first.getRDiagBinTotalCount().equalsWithEps(ec.getRDiagBinTotalCount(), lpDtype == DataType.HALF ? 1e-3 : 1e-5)); assertTrue(first.getRDiagBinSumPredictions().equalsWithEps(ec.getRDiagBinSumPredictions(), lpDtype == DataType.HALF ? 1e-3 : 1e-5)); assertArrayEquals(first.getLabelCountsEachClass(), ec.getLabelCountsEachClass()); assertArrayEquals(first.getPredictionCountsEachClass(), ec.getPredictionCountsEachClass()); assertTrue(first.getProbHistogramOverall().equalsWithEps(ec.getProbHistogramOverall(), lpDtype == DataType.HALF ? 1e-3 : 1e-5)); assertTrue(first.getProbHistogramByLabelClass().equalsWithEps(ec.getProbHistogramByLabelClass(), lpDtype == DataType.HALF ? 1e-3 : 1e-5)); } } } } } finally { Nd4j.setDefaultDataTypes(dtypeBefore, dtypeBefore); } }
Example 16
Source File: EvalTest.java From deeplearning4j with Apache License 2.0 | 4 votes |
@Test public void testEval2() { DataType dtypeBefore = Nd4j.defaultFloatingPointType(); Evaluation first = null; String sFirst = null; try { for (DataType globalDtype : new DataType[]{DataType.DOUBLE, DataType.FLOAT, DataType.HALF, DataType.INT}) { Nd4j.setDefaultDataTypes(globalDtype, globalDtype.isFPType() ? globalDtype : DataType.DOUBLE); for (DataType lpDtype : new DataType[]{DataType.DOUBLE, DataType.FLOAT, DataType.HALF}) { //Confusion matrix: //actual 0 20 3 //actual 1 10 5 Evaluation evaluation = new Evaluation(Arrays.asList("class0", "class1")); INDArray predicted0 = Nd4j.create(new double[]{1, 0}, new long[]{1, 2}).castTo(lpDtype); INDArray predicted1 = Nd4j.create(new double[]{0, 1}, new long[]{1, 2}).castTo(lpDtype); INDArray actual0 = Nd4j.create(new double[]{1, 0}, new long[]{1, 2}).castTo(lpDtype); INDArray actual1 = Nd4j.create(new double[]{0, 1}, new long[]{1, 2}).castTo(lpDtype); for (int i = 0; i < 20; i++) { evaluation.eval(actual0, predicted0); } for (int i = 0; i < 3; i++) { evaluation.eval(actual0, predicted1); } for (int i = 0; i < 10; i++) { evaluation.eval(actual1, predicted0); } for (int i = 0; i < 5; i++) { evaluation.eval(actual1, predicted1); } assertEquals(20, evaluation.truePositives().get(0), 0); assertEquals(3, evaluation.falseNegatives().get(0), 0); assertEquals(10, evaluation.falsePositives().get(0), 0); assertEquals(5, evaluation.trueNegatives().get(0), 0); assertEquals((20.0 + 5) / (20 + 3 + 10 + 5), evaluation.accuracy(), 1e-6); String s = evaluation.stats(); if(first == null) { first = evaluation; sFirst = s; } else { assertEquals(first, evaluation); assertEquals(sFirst, s); } } } } finally { Nd4j.setDefaultDataTypes(dtypeBefore, dtypeBefore); } }
Example 17
Source File: CudaIntDataBuffer.java From deeplearning4j with Apache License 2.0 | 4 votes |
/**
 * Sets this buffer's data type to {@link DataType#INT} and its element size to 4.
 */
@Override
protected void initTypeAndSize() {
    type = DataType.INT;
    elementSize = 4;
}
Example 18
Source File: EvaluationBinaryTest.java From deeplearning4j with Apache License 2.0 | 4 votes |
@Test public void testEvaluationBinary() { //Compare EvaluationBinary to Evaluation class DataType dtypeBefore = Nd4j.defaultFloatingPointType(); EvaluationBinary first = null; String sFirst = null; try { for (DataType globalDtype : new DataType[]{DataType.DOUBLE, DataType.FLOAT, DataType.HALF, DataType.INT}) { Nd4j.setDefaultDataTypes(globalDtype, globalDtype.isFPType() ? globalDtype : DataType.DOUBLE); for (DataType lpDtype : new DataType[]{DataType.DOUBLE, DataType.FLOAT, DataType.HALF}) { Nd4j.getRandom().setSeed(12345); int nExamples = 50; int nOut = 4; long[] shape = {nExamples, nOut}; INDArray labels = Nd4j.getExecutioner().exec(new BernoulliDistribution(Nd4j.createUninitialized(lpDtype, shape), 0.5)); INDArray predicted = Nd4j.rand(lpDtype, shape); INDArray binaryPredicted = predicted.gt(0.5); EvaluationBinary eb = new EvaluationBinary(); eb.eval(labels, predicted); //System.out.println(eb.stats()); double eps = 1e-6; for (int i = 0; i < nOut; i++) { INDArray lCol = labels.getColumn(i,true); INDArray pCol = predicted.getColumn(i,true); INDArray bpCol = binaryPredicted.getColumn(i,true); int countCorrect = 0; int tpCount = 0; int tnCount = 0; for (int j = 0; j < lCol.length(); j++) { if (lCol.getDouble(j) == bpCol.getDouble(j)) { countCorrect++; if (lCol.getDouble(j) == 1) { tpCount++; } else { tnCount++; } } } double acc = countCorrect / (double) lCol.length(); Evaluation e = new Evaluation(); e.eval(lCol, pCol); assertEquals(acc, eb.accuracy(i), eps); assertEquals(e.accuracy(), eb.scoreForMetric(ACCURACY, i), eps); assertEquals(e.precision(1), eb.scoreForMetric(PRECISION, i), eps); assertEquals(e.recall(1), eb.scoreForMetric(RECALL, i), eps); assertEquals(e.f1(1), eb.scoreForMetric(F1, i), eps); assertEquals(e.falseAlarmRate(), eb.scoreForMetric(FAR, i), eps); assertEquals(e.falsePositiveRate(1), eb.falsePositiveRate(i), eps); assertEquals(tpCount, eb.truePositives(i)); assertEquals(tnCount, eb.trueNegatives(i)); assertEquals((int) 
e.truePositives().get(1), eb.truePositives(i)); assertEquals((int) e.trueNegatives().get(1), eb.trueNegatives(i)); assertEquals((int) e.falsePositives().get(1), eb.falsePositives(i)); assertEquals((int) e.falseNegatives().get(1), eb.falseNegatives(i)); assertEquals(nExamples, eb.totalCount(i)); String s = eb.stats(); if(first == null) { first = eb; sFirst = s; } else { assertEquals(first, eb); assertEquals(sFirst, s); } } } } } finally { Nd4j.setDefaultDataTypes(dtypeBefore, dtypeBefore); } }
Example 19
Source File: ROCBinaryTest.java From deeplearning4j with Apache License 2.0 | 4 votes |
@Test public void testROCBinary() { //Compare ROCBinary to ROC class DataType dtypeBefore = Nd4j.defaultFloatingPointType(); ROCBinary first30 = null; ROCBinary first0 = null; String sFirst30 = null; String sFirst0 = null; try { for (DataType globalDtype : new DataType[]{DataType.DOUBLE, DataType.FLOAT, DataType.HALF, DataType.INT}) { // for (DataType globalDtype : new DataType[]{DataType.HALF}) { Nd4j.setDefaultDataTypes(globalDtype, globalDtype.isFPType() ? globalDtype : DataType.DOUBLE); for (DataType lpDtype : new DataType[]{DataType.DOUBLE, DataType.FLOAT, DataType.HALF}) { String msg = "globalDtype=" + globalDtype + ", labelPredictionsDtype=" + lpDtype; int nExamples = 50; int nOut = 4; long[] shape = {nExamples, nOut}; for (int thresholdSteps : new int[]{30, 0}) { //0 == exact Nd4j.getRandom().setSeed(12345); INDArray labels = Nd4j.getExecutioner().exec(new BernoulliDistribution(Nd4j.createUninitialized(DataType.DOUBLE, shape), 0.5)).castTo(lpDtype); Nd4j.getRandom().setSeed(12345); INDArray predicted = Nd4j.rand(DataType.DOUBLE, shape).castTo(lpDtype); ROCBinary rb = new ROCBinary(thresholdSteps); for (int xe = 0; xe < 2; xe++) { rb.eval(labels, predicted); //System.out.println(rb.stats()); double eps = lpDtype == DataType.HALF ? 
1e-2 : 1e-6; for (int i = 0; i < nOut; i++) { INDArray lCol = labels.getColumn(i, true); INDArray pCol = predicted.getColumn(i, true); ROC r = new ROC(thresholdSteps); r.eval(lCol, pCol); double aucExp = r.calculateAUC(); double auc = rb.calculateAUC(i); assertEquals(msg, aucExp, auc, eps); long apExp = r.getCountActualPositive(); long ap = rb.getCountActualPositive(i); assertEquals(msg, ap, apExp); long anExp = r.getCountActualNegative(); long an = rb.getCountActualNegative(i); assertEquals(anExp, an); PrecisionRecallCurve pExp = r.getPrecisionRecallCurve(); PrecisionRecallCurve p = rb.getPrecisionRecallCurve(i); assertEquals(msg, pExp, p); } String s = rb.stats(); if(thresholdSteps == 0){ if(first0 == null) { first0 = rb; sFirst0 = s; } else if(lpDtype != DataType.HALF) { //Precision issues with FP16 assertEquals(msg, sFirst0, s); assertEquals(first0, rb); } } else { if(first30 == null) { first30 = rb; sFirst30 = s; } else if(lpDtype != DataType.HALF) { //Precision issues with FP16 assertEquals(msg, sFirst30, s); assertEquals(first30, rb); } } // rb.reset(); rb = new ROCBinary(thresholdSteps); } } } } } finally { Nd4j.setDefaultDataTypes(dtypeBefore, dtypeBefore); } }
Example 20
Source File: BaseCudaDataBuffer.java From deeplearning4j with Apache License 2.0 | 4 votes |
/**
 * Re-creates {@code pointer} and the indexer from the current primary buffer of
 * {@code ptrDataBuffer}. The pointer view ({@code asBoolPointer}, {@code asBytePointer}, ...)
 * and the indexer class are chosen by this buffer's data type.
 * <p>
 * Returns without changes when the primary buffer and the existing {@code pointer} are both
 * non-null and have the same address.
 *
 * @throws IllegalArgumentException if the data type is not one of the handled types
 */
public void actualizePointerAndIndexer() {
    val cptr = ptrDataBuffer.primaryBuffer();

    // skip update if pointers are equal
    if (cptr != null && pointer != null && cptr.address() == pointer.address())
        return;

    val t = dataType();
    if (t == DataType.BOOL) {
        pointer = new PagedPointer(cptr, length).asBoolPointer();
        setIndexer(BooleanIndexer.create((BooleanPointer) pointer));
    } else if (t == DataType.UBYTE) {
        pointer = new PagedPointer(cptr, length).asBytePointer();
        setIndexer(UByteIndexer.create((BytePointer) pointer));
    } else if (t == DataType.BYTE) {
        pointer = new PagedPointer(cptr, length).asBytePointer();
        setIndexer(ByteIndexer.create((BytePointer) pointer));
    } else if (t == DataType.UINT16) {
        pointer = new PagedPointer(cptr, length).asShortPointer();
        setIndexer(UShortIndexer.create((ShortPointer) pointer));
    } else if (t == DataType.SHORT) {
        pointer = new PagedPointer(cptr, length).asShortPointer();
        setIndexer(ShortIndexer.create((ShortPointer) pointer));
    } else if (t == DataType.UINT32) {
        pointer = new PagedPointer(cptr, length).asIntPointer();
        setIndexer(UIntIndexer.create((IntPointer) pointer));
    } else if (t == DataType.INT) {
        pointer = new PagedPointer(cptr, length).asIntPointer();
        setIndexer(IntIndexer.create((IntPointer) pointer));
    } else if (t == DataType.UINT64) {
        // NOTE(review): UINT16/UINT32 use unsigned indexers, but UINT64 uses the signed
        // LongIndexer, so values above Long.MAX_VALUE would read back negative. Confirm intent.
        pointer = new PagedPointer(cptr, length).asLongPointer();
        setIndexer(LongIndexer.create((LongPointer) pointer));
    } else if (t == DataType.LONG) {
        pointer = new PagedPointer(cptr, length).asLongPointer();
        setIndexer(LongIndexer.create((LongPointer) pointer));
    } else if (t == DataType.BFLOAT16) {
        // BFLOAT16 and HALF are both 16-bit: a short view plus a type-specific indexer.
        pointer = new PagedPointer(cptr, length).asShortPointer();
        setIndexer(Bfloat16Indexer.create((ShortPointer) pointer));
    } else if (t == DataType.HALF) {
        pointer = new PagedPointer(cptr, length).asShortPointer();
        setIndexer(HalfIndexer.create((ShortPointer) pointer));
    } else if (t == DataType.FLOAT) {
        pointer = new PagedPointer(cptr, length).asFloatPointer();
        setIndexer(FloatIndexer.create((FloatPointer) pointer));
    } else if (t == DataType.DOUBLE) {
        pointer = new PagedPointer(cptr, length).asDoublePointer();
        setIndexer(DoubleIndexer.create((DoublePointer) pointer));
    } else if (t == DataType.UTF8) {
        // NOTE(review): unlike every other branch this uses length() rather than the length
        // field. Confirm the difference is intentional.
        pointer = new PagedPointer(cptr, length()).asBytePointer();
        setIndexer(ByteIndexer.create((BytePointer) pointer));
    } else
        throw new IllegalArgumentException("Unknown datatype: " + dataType());
}