Java Code Examples for org.nd4j.linalg.api.buffer.DataType#BOOL
The following examples show how to use
org.nd4j.linalg.api.buffer.DataType#BOOL.
You can vote up the ones you like or vote down the ones you don't like,
and go to the original project or source file by following the links above each example. You may check out the related API usage on the sidebar.
Example 1
Source File: ArrayOptionsHelper.java From deeplearning4j with Apache License 2.0 | 6 votes |
public static DataType dataType(long opt) { if (hasBitSet(opt, DTYPE_COMPRESSED_BIT)) return DataType.COMPRESSED; else if (hasBitSet(opt, DTYPE_HALF_BIT)) return DataType.HALF; else if (hasBitSet(opt, DTYPE_BFLOAT16_BIT)) return DataType.BFLOAT16; else if (hasBitSet(opt, DTYPE_FLOAT_BIT)) return DataType.FLOAT; else if (hasBitSet(opt, DTYPE_DOUBLE_BIT)) return DataType.DOUBLE; else if (hasBitSet(opt, DTYPE_INT_BIT)) return hasBitSet(opt, DTYPE_UNSIGNED_BIT) ? DataType.UINT32 : DataType.INT; else if (hasBitSet(opt, DTYPE_LONG_BIT)) return hasBitSet(opt, DTYPE_UNSIGNED_BIT) ? DataType.UINT64 : DataType.LONG; else if (hasBitSet(opt, DTYPE_BOOL_BIT)) return DataType.BOOL; else if (hasBitSet(opt, DTYPE_BYTE_BIT)) { return hasBitSet(opt, DTYPE_UNSIGNED_BIT) ? DataType.UBYTE : DataType.BYTE; //Byte bit set for both UBYTE and BYTE } else if (hasBitSet(opt, DTYPE_SHORT_BIT)) return hasBitSet(opt, DTYPE_UNSIGNED_BIT) ? DataType.UINT16 : DataType.SHORT; else if (hasBitSet(opt, DTYPE_UTF8_BIT)) return DataType.UTF8; else throw new ND4JUnknownDataTypeException("Unknown extras set: [" + opt + "]"); }
Example 2
Source File: SpecialTests.java From deeplearning4j with Apache License 2.0 | 6 votes |
@Test public void reproduceWorkspaceCrash_4(){ val conf = WorkspaceConfiguration.builder().build(); val ws = Nd4j.getWorkspaceManager().getWorkspaceForCurrentThread(conf, "WS"); val dtypes = new DataType[]{DataType.LONG, DataType.DOUBLE, DataType.FLOAT, DataType.HALF, DataType.INT, DataType.SHORT, DataType.BYTE, DataType.UBYTE, DataType.BOOL}; for (val dX : dtypes) { for (val dZ: dtypes) { try(val ws2 = Nd4j.getWorkspaceManager().getAndActivateWorkspace("WS")) { val array = Nd4j.create(dX, 100, 100).assign(1); // log.info("Trying to cast {} to {}", dX, dZ); val casted = array.castTo(dZ); val exp = Nd4j.create(dZ, 100, 100).assign(1); assertEquals(exp, casted); } } } }
Example 3
Source File: SpecialTests.java From deeplearning4j with Apache License 2.0 | 6 votes |
@Test public void reproduceWorkspaceCrash_3(){ val conf = WorkspaceConfiguration.builder().build(); val ws = Nd4j.getWorkspaceManager().getWorkspaceForCurrentThread(conf, "WS"); val dtypes = new DataType[]{DataType.DOUBLE, DataType.FLOAT, DataType.HALF, DataType.LONG, DataType.INT, DataType.SHORT, DataType.BYTE, DataType.UBYTE, DataType.BOOL}; for (val dX : dtypes) { for (val dZ: dtypes) { try(val ws2 = ws.notifyScopeEntered()) { val array = Nd4j.create(dX, 2, 5).assign(1); // log.info("Trying to cast {} to {}", dX, dZ); val casted = array.castTo(dZ); val exp = Nd4j.create(dZ, 2, 5).assign(1); assertEquals(exp, casted); Nd4j.getExecutioner().commit(); } } } }
Example 4
Source File: PythonNumpyBasicTest.java From deeplearning4j with Apache License 2.0 | 5 votes |
@Test public void testInplaceExecution(){ if (dataType == DataType.BOOL || dataType == DataType.BFLOAT16)return; if (shape.length == 0) return; List<PythonVariable> inputs = new ArrayList<>(); INDArray x = Nd4j.ones(dataType, shape); INDArray y = Nd4j.zeros(dataType, shape); INDArray z = x.mul(y.add(2)); // Nd4j.getAffinityManager().ensureLocation(z, AffinityManager.Location.HOST); PythonType<INDArray> arrType = PythonTypes.get("numpy.ndarray"); inputs.add(new PythonVariable<>("x", arrType, x)); inputs.add(new PythonVariable<>("y", arrType, y)); List<PythonVariable> outputs = new ArrayList<>(); PythonVariable<INDArray> output = new PythonVariable<>("x", arrType); outputs.add(output); String code = "x *= y + 2"; PythonExecutioner.exec(code, inputs, outputs); INDArray z2 = output.getValue(); Assert.assertEquals(x.dataType(), z2.dataType()); Assert.assertEquals(z.dataType(), z2.dataType()); Assert.assertEquals(x, z2); Assert.assertEquals(z, z2); Assert.assertEquals(x.data().pointer().address(), z2.data().pointer().address()); if("CUDA".equalsIgnoreCase(Nd4j.getExecutioner().getEnvironmentInformation().getProperty("backend"))){ Assert.assertEquals(getDeviceAddress(x), getDeviceAddress(z2)); } }
Example 5
Source File: PythonNumpyJobTest.java From deeplearning4j with Apache License 2.0 | 5 votes |
/**
 * Runs a one-line PythonJob ("z = x * (y + 2)", or "z = x" for BOOL) and compares the returned
 * "z" with the same computation in ND4J. BFLOAT16 values are compared after casting to FLOAT.
 */
@Test
public void testNumpyJobBasic() {
    PythonContextManager.deleteNonMainContexts();
    final boolean isBool = dataType == DataType.BOOL;
    final boolean isBf16 = dataType == DataType.BFLOAT16;

    INDArray x = Nd4j.ones(dataType, 2, 3);
    INDArray y = Nd4j.zeros(dataType, 2, 3);
    INDArray expected = isBool ? x : x.mul(y.add(2));
    if (isBf16) {
        expected = expected.castTo(DataType.FLOAT);
    }

    PythonType<INDArray> arrType = PythonTypes.get("numpy.ndarray");
    List<PythonVariable> inputs = new ArrayList<>();
    inputs.add(new PythonVariable<>("x", arrType, x));
    inputs.add(new PythonVariable<>("y", arrType, y));

    PythonVariable<INDArray> output = new PythonVariable<>("z", arrType);
    List<PythonVariable> outputs = new ArrayList<>();
    outputs.add(output);

    String code = isBool ? "z = x" : "z = x * (y + 2)";
    PythonJob job = new PythonJob("job1", code, false);
    job.exec(inputs, outputs);

    INDArray actual = output.getValue();
    if (isBf16) {
        actual = actual.castTo(DataType.FLOAT);
    }
    Assert.assertEquals(expected, actual);
}
Example 6
Source File: NDValidation.java From deeplearning4j with Apache License 2.0 | 5 votes |
/**
 * Validate that the operation is being applied on a numerical INDArray (not boolean or utf8).
 * Some operations (such as sum, norm2, add(Number) etc) don't make sense when applied to boolean/utf8 arrays.
 *
 * @param opName    Operation name to print in the exception
 * @param inputName Name of the input being validated, printed in the exception
 * @param v         Array to validate datatype for (input to operation); {@code null} is ignored
 * @throws IllegalStateException if {@code v} has BOOL or UTF8 data type
 */
public static void validateNumerical(String opName, String inputName, INDArray v) {
    if (v == null)
        return;
    if (v.dataType() == DataType.BOOL || v.dataType() == DataType.UTF8)
        // BOOL and UTF8 are rejected because they are non-numerical, not because they are "non-integer"
        throw new IllegalStateException("Input \"" + inputName + "\" for operation \"" + opName + "\" must be a numerical type;" +
                " got array with non-numerical data type " + v.dataType());
}
Example 7
Source File: PythonNumpyJobTest.java From deeplearning4j with Apache License 2.0 | 5 votes |
/**
 * Runs a PythonJob with execAndReturnAllVariables and checks the returned x, y and z against
 * their ND4J counterparts. BFLOAT16 expectations are cast to FLOAT before comparison.
 */
@Test
public void testNumpyJobReturnAllVariables() {
    PythonContextManager.deleteNonMainContexts();
    final boolean isBool = dataType == DataType.BOOL;

    INDArray x = Nd4j.ones(dataType, 2, 3);
    INDArray y = Nd4j.zeros(dataType, 2, 3);
    INDArray z = isBool ? x : x.mul(y.add(2));

    PythonType<INDArray> arrType = PythonTypes.get("numpy.ndarray");
    List<PythonVariable> inputs = new ArrayList<>();
    inputs.add(new PythonVariable<>("x", arrType, x));
    inputs.add(new PythonVariable<>("y", arrType, y));

    PythonJob job = new PythonJob("job1", isBool ? "z = x" : "z = x * (y + 2)", false);
    List<PythonVariable> returned = job.execAndReturnAllVariables(inputs);

    // This test assumes variables come back in the order x, y, z
    INDArray actualX = (INDArray) returned.get(0).getValue();
    INDArray actualY = (INDArray) returned.get(1).getValue();
    INDArray actualZ = (INDArray) returned.get(2).getValue();

    if (dataType == DataType.BFLOAT16) {
        x = x.castTo(DataType.FLOAT);
        y = y.castTo(DataType.FLOAT);
        z = z.castTo(DataType.FLOAT);
    }
    Assert.assertEquals(x, actualX);
    Assert.assertEquals(y, actualY);
    Assert.assertEquals(z, actualZ);
}
Example 8
Source File: NDValidation.java From deeplearning4j with Apache License 2.0 | 5 votes |
/**
 * Validate that the operation is being applied on numerical INDArrays (not boolean or utf8).
 * Some operations (such as sum, norm2, add(Number) etc) don't make sense when applied to boolean/utf8 arrays.
 * <p>
 * A {@code null} array, and any {@code null} elements within it, are skipped rather than
 * causing a NullPointerException.
 *
 * @param opName Operation name to print in the exception
 * @param v      Arrays to perform the operation on (may be {@code null} or contain {@code null}s)
 * @throws IllegalStateException if any non-null element has BOOL or UTF8 data type
 */
public static void validateNumerical(String opName, INDArray[] v) {
    if (v == null)
        return;
    for (int i = 0; i < v.length; i++) {
        INDArray arr = v[i];
        if (arr == null)
            continue; // nothing to validate for an absent input
        DataType dt = arr.dataType();
        if (dt == DataType.BOOL || dt == DataType.UTF8)
            throw new IllegalStateException("Cannot apply operation \"" + opName + "\" to input array " + i + " with non-numerical data type " + dt);
    }
}
Example 9
Source File: PythonNumpyBasicTest.java From deeplearning4j with Apache License 2.0 | 5 votes |
@Parameterized.Parameters(name = "{index}: Testing with DataType={0}, shape={2}") public static Collection params() { DataType[] types = new DataType[] { DataType.BOOL, DataType.FLOAT16, DataType.BFLOAT16, DataType.FLOAT, DataType.DOUBLE, DataType.INT8, DataType.INT16, DataType.INT32, DataType.INT64, DataType.UINT8, DataType.UINT16, DataType.UINT32, DataType.UINT64 }; long[][] shapes = new long[][]{ new long[]{2, 3}, new long[]{3}, new long[]{1}, new long[]{} // scalar }; List<Object[]> ret = new ArrayList<>(); for (DataType type: types){ for (long[] shape: shapes){ ret.add(new Object[]{type, shape, Arrays.toString(shape)}); } } return ret; }
Example 10
Source File: NDValidation.java From deeplearning4j with Apache License 2.0 | 5 votes |
/**
 * Validate that the operation is being applied on a boolean type INDArray.
 *
 * @param opName Operation name to print in the exception
 * @param v      Variable to validate datatype for (input to operation); {@code null} is ignored
 * @throws IllegalStateException if {@code v} is not of BOOL data type
 */
public static void validateBool(String opName, INDArray v) {
    if (v == null)
        return;
    if (v.dataType() != DataType.BOOL)
        throw new IllegalStateException("Cannot apply operation \"" + opName + "\" to array with non-boolean data type " + v.dataType());
}
Example 11
Source File: ND4JUtil.java From konduit-serving with Apache License 2.0 | 5 votes |
/**
 * Maps a serving-layer {@link NDArrayType} to the equivalent ND4J {@link DataType}.
 *
 * @param type the array type to convert (must not be null)
 * @return the ND4J data type with the same name
 * @throws UnsupportedOperationException if the type has no ND4J counterpart listed here
 */
public static DataType typeNDArrayTypeToNd4j(@NonNull NDArrayType type) {
    final DataType converted;
    switch (type) {
        case DOUBLE:   converted = DataType.DOUBLE;   break;
        case FLOAT:    converted = DataType.FLOAT;    break;
        case FLOAT16:  converted = DataType.FLOAT16;  break;
        case BFLOAT16: converted = DataType.BFLOAT16; break;
        case INT64:    converted = DataType.INT64;    break;
        case INT32:    converted = DataType.INT32;    break;
        case INT16:    converted = DataType.INT16;    break;
        case INT8:     converted = DataType.INT8;     break;
        case UINT64:   converted = DataType.UINT64;   break;
        case UINT32:   converted = DataType.UINT32;   break;
        case UINT16:   converted = DataType.UINT16;   break;
        case UINT8:    converted = DataType.UINT8;    break;
        case BOOL:     converted = DataType.BOOL;     break;
        case UTF8:     converted = DataType.UTF8;     break;
        default:
            throw new UnsupportedOperationException("Unable to convert datatype to ND4J datatype: " + type);
    }
    return converted;
}
Example 12
Source File: MaskedReductionUtil.java From deeplearning4j with Apache License 2.0 | 4 votes |
public static INDArray maskedPoolingEpsilonCnn(PoolingType poolingType, INDArray input, INDArray mask, INDArray epsilon2d, int pnorm, DataType dataType) { // [minibatch, channels, h=1, w=X] or [minibatch, channels, h=X, w=1] data // with a mask array of shape [minibatch, X] //If masking along height: broadcast dimensions are [0,2] //If masking along width: broadcast dimensions are [0,3] mask = mask.castTo(dataType); //No-op if correct type //General case: must be equal or 1 on each dimension int[] dimensions = new int[4]; int count = 0; for(int i=0; i<4; i++ ){ if(input.size(i) == mask.size(i)){ dimensions[count++] = i; } } if(count < 4){ dimensions = Arrays.copyOfRange(dimensions, 0, count); } switch (poolingType) { case MAX: //TODO This is ugly - replace it with something better... Need something like a Broadcast CAS op INDArray negInfMask; if(mask.dataType() == DataType.BOOL){ negInfMask = Transforms.not(mask).castTo(dataType); } else { negInfMask = mask.rsub(1.0); } BooleanIndexing.replaceWhere(negInfMask, Double.NEGATIVE_INFINITY, Conditions.equals(1.0)); INDArray withInf = Nd4j.createUninitialized(dataType, input.shape()); Nd4j.getExecutioner().exec(new BroadcastAddOp(input, negInfMask, withInf, dimensions)); //At this point: all the masked out steps have value -inf, hence can't be the output of the MAX op INDArray isMax = Nd4j.exec(new IsMax(withInf, withInf.ulike(), 2, 3))[0]; return Nd4j.getExecutioner().exec(new BroadcastMulOp(isMax, epsilon2d, isMax, 0, 1)); case AVG: case SUM: //if out = sum(in,dims) then dL/dIn = dL/dOut -> duplicate to each step and mask //if out = avg(in,dims) then dL/dIn = 1/N * dL/dOut //With masking: N differs for different time series INDArray out = Nd4j.createUninitialized(dataType, input.shape(), 'f'); //Broadcast copy op, then divide and mask to 0 as appropriate Nd4j.getExecutioner().exec(new BroadcastCopyOp(out, epsilon2d, out, 0, 1)); Nd4j.getExecutioner().exec(new BroadcastMulOp(out, mask, out, dimensions)); if (poolingType 
== PoolingType.SUM) { return out; } //Note that with CNNs, current design is restricted to [minibatch, channels, 1, W] ot [minibatch, channels, H, 1] INDArray nEachTimeSeries = mask.sum(1,2,3); //[minibatchSize,tsLength] -> [minibatchSize,1] Nd4j.getExecutioner().exec(new BroadcastDivOp(out, nEachTimeSeries, out, 0)); return out; case PNORM: //Similar to average and sum pooling: there's no N term here, so we can just set the masked values to 0 INDArray masked2 = Nd4j.createUninitialized(dataType, input.shape()); Nd4j.getExecutioner().exec(new BroadcastMulOp(input, mask, masked2, dimensions)); INDArray abs = Transforms.abs(masked2, true); Transforms.pow(abs, pnorm, false); INDArray pNorm = Transforms.pow(abs.sum(2, 3), 1.0 / pnorm); INDArray numerator; if (pnorm == 2) { numerator = input.dup(); } else { INDArray absp2 = Transforms.pow(Transforms.abs(input, true), pnorm - 2, false); numerator = input.mul(absp2); } INDArray denom = Transforms.pow(pNorm, pnorm - 1, false); denom.rdivi(epsilon2d); Nd4j.getExecutioner().execAndReturn(new BroadcastMulOp(numerator, denom, numerator, 0, 1)); Nd4j.getExecutioner().exec(new BroadcastMulOp(numerator, mask, numerator, dimensions)); //Apply mask return numerator; default: throw new UnsupportedOperationException("Unknown or not supported pooling type: " + poolingType); } }
Example 13
Source File: BaseReduceBoolOp.java From deeplearning4j with Apache License 2.0 | 4 votes |
/**
 * Output type of this boolean reduction, which is always {@link DataType#BOOL}.
 * The op context is not consulted.
 *
 * @param oc op context (unused)
 * @return {@link DataType#BOOL}
 */
@Override
public DataType resultType(OpContext oc) {
    return DataType.BOOL;
}
Example 14
Source File: BaseReduceBoolOp.java From deeplearning4j with Apache License 2.0 | 4 votes |
/**
 * Output type of this boolean reduction.
 *
 * @return always {@link DataType#BOOL}
 */
@Override
public DataType resultType() {
    return DataType.BOOL;
}
Example 15
Source File: BaseCudaDataBuffer.java From deeplearning4j with Apache License 2.0 | 4 votes |
public void actualizePointerAndIndexer() { val cptr = ptrDataBuffer.primaryBuffer(); // skip update if pointers are equal if (cptr != null && pointer != null && cptr.address() == pointer.address()) return; val t = dataType(); if (t == DataType.BOOL) { pointer = new PagedPointer(cptr, length).asBoolPointer(); setIndexer(BooleanIndexer.create((BooleanPointer) pointer)); } else if (t == DataType.UBYTE) { pointer = new PagedPointer(cptr, length).asBytePointer(); setIndexer(UByteIndexer.create((BytePointer) pointer)); } else if (t == DataType.BYTE) { pointer = new PagedPointer(cptr, length).asBytePointer(); setIndexer(ByteIndexer.create((BytePointer) pointer)); } else if (t == DataType.UINT16) { pointer = new PagedPointer(cptr, length).asShortPointer(); setIndexer(UShortIndexer.create((ShortPointer) pointer)); } else if (t == DataType.SHORT) { pointer = new PagedPointer(cptr, length).asShortPointer(); setIndexer(ShortIndexer.create((ShortPointer) pointer)); } else if (t == DataType.UINT32) { pointer = new PagedPointer(cptr, length).asIntPointer(); setIndexer(UIntIndexer.create((IntPointer) pointer)); } else if (t == DataType.INT) { pointer = new PagedPointer(cptr, length).asIntPointer(); setIndexer(IntIndexer.create((IntPointer) pointer)); } else if (t == DataType.UINT64) { pointer = new PagedPointer(cptr, length).asLongPointer(); setIndexer(LongIndexer.create((LongPointer) pointer)); } else if (t == DataType.LONG) { pointer = new PagedPointer(cptr, length).asLongPointer(); setIndexer(LongIndexer.create((LongPointer) pointer)); } else if (t == DataType.BFLOAT16) { pointer = new PagedPointer(cptr, length).asShortPointer(); setIndexer(Bfloat16Indexer.create((ShortPointer) pointer)); } else if (t == DataType.HALF) { pointer = new PagedPointer(cptr, length).asShortPointer(); setIndexer(HalfIndexer.create((ShortPointer) pointer)); } else if (t == DataType.FLOAT) { pointer = new PagedPointer(cptr, length).asFloatPointer(); setIndexer(FloatIndexer.create((FloatPointer) 
pointer)); } else if (t == DataType.DOUBLE) { pointer = new PagedPointer(cptr, length).asDoublePointer(); setIndexer(DoubleIndexer.create((DoublePointer) pointer)); } else if (t == DataType.UTF8) { pointer = new PagedPointer(cptr, length()).asBytePointer(); setIndexer(ByteIndexer.create((BytePointer) pointer)); } else throw new IllegalArgumentException("Unknown datatype: " + dataType()); }
Example 16
Source File: ExecDebuggingListener.java From deeplearning4j with Apache License 2.0 | 4 votes |
/**
 * Renders an array as a Java source snippet that recreates it with Nd4j, used for debugging output.
 * <p>
 * Empty arrays become {@code Nd4j.empty(DataType.X);}. Other arrays become
 * {@code Nd4j.createFromArray(...).reshape(...)}. Types without a direct Java literal form
 * get a trailing {@code .cast(DataType.X)}. UTF8, COMPRESSED and UNKNOWN arrays emit no values.
 *
 * @param arr array to render
 * @return the generated snippet
 */
private static String createString(INDArray arr) {
    if (arr.isEmpty()) {
        return "Nd4j.empty(DataType." + arr.dataType() + ");";
    }

    final DataType dt = arr.dataType();
    final StringBuilder out = new StringBuilder("Nd4j.createFromArray(");
    switch (dt) {
        case DOUBLE: {
            String values = Arrays.toString(arr.dup().data().asDouble());
            out.append(values.replaceAll("[\\[\\]]", ""));
            break;
        }
        case FLOAT:
        case HALF:
        case BFLOAT16: {
            // Add an "f" suffix to every element so the literals are floats
            String values = Arrays.toString(arr.dup().data().asFloat());
            out.append(values.replaceAll(",", "f,")
                    .replaceAll("]", "f")
                    .replaceAll("[\\[\\]]", ""));
            break;
        }
        case LONG:
        case UINT32:
        case UINT64: {
            // Add an "L" suffix to every element so the literals are longs
            String values = Arrays.toString(arr.dup().data().asLong());
            out.append(values.replaceAll(",", "L,")
                    .replaceAll("]", "L")
                    .replaceAll("[\\[\\]]", ""));
            break;
        }
        case INT:
        case SHORT:
        case UBYTE:
        case BYTE:
        case UINT16:
        case BOOL: {
            String values = Arrays.toString(arr.dup().data().asInt());
            out.append(values.replaceAll("[\\[\\]]", ""));
            break;
        }
        case UTF8:
        case COMPRESSED:
        case UNKNOWN:
            break;
    }

    String shape = Arrays.toString(arr.shape()).replaceAll("[\\[\\]]", "");
    out.append(").reshape(").append(shape).append(")");

    // These types were emitted as a wider literal type above, so cast back to the real type
    switch (dt) {
        case HALF:
        case BFLOAT16:
        case UINT32:
        case UINT64:
        case SHORT:
        case UBYTE:
        case BYTE:
        case UINT16:
        case BOOL:
            out.append(".cast(DataType.").append(arr.dataType()).append(")");
            break;
        default:
            break;
    }
    return out.toString();
}
Example 17
Source File: ArraySavingListener.java From deeplearning4j with Apache License 2.0 | 4 votes |
public static void compare(File dir1, File dir2, double eps) throws Exception { File[] files1 = dir1.listFiles(); File[] files2 = dir2.listFiles(); Preconditions.checkNotNull(files1, "No files in directory 1: %s", dir1); Preconditions.checkNotNull(files2, "No files in directory 2: %s", dir2); Preconditions.checkState(files1.length == files2.length, "Different number of files: %s vs %s", files1.length, files2.length); Map<String,File> m1 = toMap(files1); Map<String,File> m2 = toMap(files2); for(File f : files1){ String name = f.getName(); String varName = name.substring(name.indexOf('_') + 1, name.length()-4); //Strip "x_" and ".bin" File f2 = m2.get(varName); INDArray arr1 = Nd4j.readBinary(f); INDArray arr2 = Nd4j.readBinary(f2); //TODO String arrays won't work here! boolean eq = arr1.equalsWithEps(arr2, eps); if(eq){ System.out.println("Equals: " + varName.replaceAll("__", "/")); } else { if(arr1.dataType() == DataType.BOOL){ INDArray xor = Nd4j.exec(new Xor(arr1, arr2)); int count = xor.castTo(DataType.INT).sumNumber().intValue(); System.out.println("FAILS: " + varName.replaceAll("__", "/") + " - boolean, # differences = " + count); System.out.println("\t" + f.getAbsolutePath()); System.out.println("\t" + f2.getAbsolutePath()); xor.close(); } else { INDArray sub = arr1.sub(arr2); INDArray diff = Nd4j.math.abs(sub); double maxDiff = diff.maxNumber().doubleValue(); System.out.println("FAILS: " + varName.replaceAll("__", "/") + " - max difference = " + maxDiff); System.out.println("\t" + f.getAbsolutePath()); System.out.println("\t" + f2.getAbsolutePath()); sub.close(); diff.close(); } } arr1.close(); arr2.close(); } }
Example 18
Source File: NDValidation.java From deeplearning4j with Apache License 2.0 | 3 votes |
/**
 * Ensures an operation is applied only to a numerical INDArray, rejecting boolean and utf8 arrays.
 * Operations such as sum, norm2 or add(Number) are not meaningful on boolean/utf8 data.
 *
 * @param opName Operation name included in the exception message
 * @param v      Array the operation will be performed on; {@code null} is accepted and ignored
 * @throws IllegalStateException if {@code v} has BOOL or UTF8 data type
 */
public static void validateNumerical(String opName, INDArray v) {
    if (v == null) {
        return;
    }
    boolean nonNumerical = v.dataType() == DataType.BOOL || v.dataType() == DataType.UTF8;
    if (nonNumerical) {
        throw new IllegalStateException("Cannot apply operation \"" + opName + "\" to array with non-numerical data type " + v.dataType());
    }
}
Example 19
Source File: SDValidation.java From deeplearning4j with Apache License 2.0 | 2 votes |
/**
 * Ensures both operands of a binary operation are boolean SDVariables.
 *
 * @param opName Operation name included in the exception message
 * @param v1     First input variable whose data type is checked
 * @param v2     Second input variable whose data type is checked
 * @throws IllegalStateException unless both variables have BOOL data type
 */
protected static void validateBool(String opName, SDVariable v1, SDVariable v2) {
    boolean bothBool = v1.dataType() == DataType.BOOL && v2.dataType() == DataType.BOOL;
    if (!bothBool) {
        throw new IllegalStateException("Cannot perform operation \"" + opName + "\" on variables \"" + v1.name() + "\" and \"" +
                v2.name() + "\" if one or both variables are non-boolean: " + v1.dataType() + " and " + v2.dataType());
    }
}
Example 20
Source File: PythonNumpyJobTest.java From deeplearning4j with Apache License 2.0 | 2 votes |
/**
 * Creates two PythonJobs that each define setup() and run(a, b) and share a global "five".
 * Both jobs are run alternately (sequentially) on two sets of inputs, and the "c" result is
 * checked each time. BOOL is skipped. For BFLOAT16 the test expects FLOAT results.
 */
@Test
public void testMultipleNumpyJobsSetupRunParallel() {
    if (dataType == DataType.BOOL) {
        return;
    }
    PythonContextManager.deleteNonMainContexts();
    DataType expectedType = (dataType == DataType.BFLOAT16) ? DataType.FLOAT : dataType;

    String code1 = "five=None\n" +
            "def setup():\n" +
            " global five\n" +
            " five = 5\n\n" +
            "def run(a, b):\n" +
            " c = a + b + five\n" +
            " return {'c':c}\n\n";
    PythonJob job1 = new PythonJob("job1", code1, true);

    String code2 = "five=None\n" +
            "def setup():\n" +
            " global five\n" +
            " five = 5\n\n" +
            "def run(a, b):\n" +
            " c = a + b - five\n" +
            " return {'c':c}\n\n";
    PythonJob job2 = new PythonJob("job2", code2, true);

    // Round 1: a = 2, b = 3  ->  job1: 2 + 3 + 5 = 10, job2: 2 + 3 - 5 = 0
    List<PythonVariable> args = new ArrayList<>();
    args.add(new PythonVariable<>("a", NumpyArray.INSTANCE, Nd4j.ones(dataType, 2, 3).mul(2)));
    args.add(new PythonVariable<>("b", NumpyArray.INSTANCE, Nd4j.ones(dataType, 2, 3).mul(3)));
    List<PythonVariable> results = new ArrayList<>();
    results.add(new PythonVariable<>("c", NumpyArray.INSTANCE));

    job1.exec(args, results);
    assertEquals(Nd4j.ones(expectedType, 2, 3).mul(10), results.get(0).getValue());
    job2.exec(args, results);
    assertEquals(Nd4j.zeros(expectedType, 2, 3), results.get(0).getValue());

    // Round 2: a = 3, b = 4  ->  job1: 3 + 4 + 5 = 12, job2: 3 + 4 - 5 = 2
    args = new ArrayList<>();
    args.add(new PythonVariable<>("a", NumpyArray.INSTANCE, Nd4j.ones(dataType, 2, 3).mul(3)));
    args.add(new PythonVariable<>("b", NumpyArray.INSTANCE, Nd4j.ones(dataType, 2, 3).mul(4)));
    results = new ArrayList<>();
    results.add(new PythonVariable<>("c", NumpyArray.INSTANCE));

    job1.exec(args, results);
    assertEquals(Nd4j.ones(expectedType, 2, 3).mul(12), results.get(0).getValue());
    job2.exec(args, results);
    assertEquals(Nd4j.ones(expectedType, 2, 3).mul(2), results.get(0).getValue());
}