Java Code Examples for org.nd4j.autodiff.samediff.SDVariable#setArray()
The following examples show how to use
org.nd4j.autodiff.samediff.SDVariable#setArray().
You can vote up the ones you like or vote down the ones you don't like,
and go to the original project or source file by following the links above each example. You may check out the related API usage in the sidebar.
Example 1
Source File: LayerOpValidation.java From deeplearning4j with Apache License 2.0 | 5 votes |
@Test public void testLrn2d() { Nd4j.getRandom().setSeed(12345); int[][] inputSizes = new int[][]{{1, 3, 8, 8}, {3, 6, 12, 12}}; List<String> failed = new ArrayList<>(); for (int[] inSizeNCHW : inputSizes) { SameDiff sd = SameDiff.create(); SDVariable in = null; int[] inSize; //LRN String msg = "LRN with NCHW - input" + Arrays.toString(inSizeNCHW); inSize = inSizeNCHW; in = sd.var("in", inSize); SDVariable out = sd.cnn().localResponseNormalization(in, LocalResponseNormalizationConfig.builder() .depth(3) .bias(1) .alpha(1) .beta(0.5) .build()); INDArray inArr = Nd4j.rand(inSize).muli(10); in.setArray(inArr); SDVariable loss = sd.mean("loss", out); log.info("Starting test: " + msg); TestCase tc = new TestCase(sd).gradientCheck(true); String error = OpValidation.validate(tc); if (error != null) { failed.add(msg); } } assertEquals(failed.toString(), 0, failed.size()); }
Example 2
Source File: LayerOpValidation.java From deeplearning4j with Apache License 2.0 | 4 votes |
@Test public void testConv3d() { //Pooling3d, Conv3D, batch norm Nd4j.getRandom().setSeed(12345); //NCDHW format int[][] inputSizes = new int[][]{{2, 3, 4, 5, 5}}; List<String> failed = new ArrayList<>(); for (int[] inSizeNCDHW : inputSizes) { for (boolean ncdhw : new boolean[]{true, false}) { int nIn = inSizeNCDHW[1]; int[] shape = (ncdhw ? inSizeNCDHW : ncdhwToNdhwc(inSizeNCDHW)); for (int i = 0; i < 5; i++) { SameDiff sd = SameDiff.create(); SDVariable in = sd.var("in", shape); SDVariable out; String msg; switch (i) { case 0: //Conv3d, with bias, same msg = "0 - conv3d+bias+same, ncdhw=" + ncdhw + " - input " + Arrays.toString(shape); SDVariable w0 = sd.var("w0", Nd4j.rand(new int[]{2, 2, 2, nIn, 3}).muli(10)); //[kD, kH, kW, iC, oC] SDVariable b0 = sd.var("b0", Nd4j.rand(new long[]{3}).muli(10)); out = sd.cnn().conv3d(in, w0, b0, Conv3DConfig.builder() .dataFormat(ncdhw ? Conv3DConfig.NCDHW : Conv3DConfig.NDHWC) .isSameMode(true) .kH(2).kW(2).kD(2) .sD(1).sH(1).sW(1) .build()); break; case 1: //Conv3d, no bias, no same msg = "1 - conv3d+no bias+no same, ncdhw=" + ncdhw + " - input " + Arrays.toString(shape); SDVariable w1 = sd.var("w1", Nd4j.rand(new int[]{2, 2, 2, nIn, 3}).muli(10)); //[kD, kH, kW, iC, oC] out = sd.cnn().conv3d(in, w1, Conv3DConfig.builder() .dataFormat(ncdhw ? 
Conv3DConfig.NCDHW : Conv3DConfig.NDHWC) .isSameMode(false) .kH(2).kW(2).kD(2) .sD(1).sH(1).sW(1) .build()); break; case 2: //pooling3d - average, no same msg = "2 - pooling 3d, average, same"; out = sd.cnn().avgPooling3d(in, Pooling3DConfig.builder() .kH(2).kW(2).kD(2) .sH(1).sW(1).sD(1) .isSameMode(false) .isNCDHW(ncdhw) .build()); break; case 3: //pooling 3d - max, no same msg = "3 - pooling 3d, max, same"; out = sd.cnn().maxPooling3d(in, Pooling3DConfig.builder() .kH(2).kW(2).kD(2) .sH(1).sW(1).sD(1) .isSameMode(true) .isNCDHW(ncdhw) .build()); break; case 4: //Deconv3d msg = "4 - deconv3d, ncdhw=" + ncdhw; SDVariable wDeconv = sd.var(Nd4j.rand(new int[]{2, 2, 2, 3, nIn})); //[kD, kH, kW, oC, iC] SDVariable bDeconv = sd.var(Nd4j.rand(new int[]{3})); out = sd.cnn().deconv3d("Deconv3d", in, wDeconv, bDeconv, DeConv3DConfig.builder() .kD(2).kH(2).kW(2) .isSameMode(true) .dataFormat(ncdhw ? DeConv3DConfig.NCDHW : DeConv3DConfig.NDHWC) .build()); break; case 5: //Batch norm - 3d input throw new RuntimeException("Batch norm test not yet implemented"); default: throw new RuntimeException(); } INDArray inArr = Nd4j.rand(shape).muli(10); in.setArray(inArr); SDVariable loss = sd.standardDeviation("loss", out, true); log.info("Starting test: " + msg); TestCase tc = new TestCase(sd).gradientCheck(true); tc.testName(msg); String error = OpValidation.validate(tc); if (error != null) { failed.add(name); } } } } assertEquals(failed.toString(), 0, failed.size()); }
Example 3
Source File: MiscOpValidation.java From deeplearning4j with Apache License 2.0 | 4 votes |
@Test public void testScatterOpGradients() { List<String> failed = new ArrayList<>(); for (int i = 0; i < 7; i++) { Nd4j.getRandom().setSeed(12345); SameDiff sd = SameDiff.create(); SDVariable in = sd.var("in", DataType.DOUBLE, 20, 10); SDVariable indices = sd.var("indices", DataType.INT, new long[]{5}); SDVariable updates = sd.var("updates", DataType.DOUBLE, 5, 10); in.setArray(Nd4j.rand(DataType.DOUBLE, 20, 10)); indices.setArray(Nd4j.create(new double[]{3, 4, 5, 10, 18}).castTo(DataType.INT)); updates.setArray(Nd4j.rand(DataType.DOUBLE, 5, 10).muli(2).subi(1)); SDVariable scatter; String name; switch (i) { case 0: scatter = sd.scatterAdd("s", in, indices, updates); name = "scatterAdd"; break; case 1: scatter = sd.scatterSub("s", in, indices, updates); name = "scatterSub"; break; case 2: scatter = sd.scatterMul("s", in, indices, updates); name = "scatterMul"; break; case 3: scatter = sd.scatterDiv("s", in, indices, updates); name = "scatterDiv"; break; case 4: scatter = sd.scatterUpdate("s", in, indices, updates); name = "scatterUpdate"; break; case 5: scatter = sd.scatterMax("s", in, indices, updates); name = "scatterMax"; break; case 6: scatter = sd.scatterMin("s", in, indices, updates); name = "scatterMin"; break; default: throw new RuntimeException(); } INDArray exp = in.getArr().dup(); int[] indicesInt = indices.getArr().dup().data().asInt(); for( int j=0; j<indicesInt.length; j++ ){ INDArray updateRow = updates.getArr().getRow(j); INDArray destinationRow = exp.getRow(indicesInt[j]); switch (i){ case 0: destinationRow.addi(updateRow); break; case 1: destinationRow.subi(updateRow); break; case 2: destinationRow.muli(updateRow); break; case 3: destinationRow.divi(updateRow); break; case 4: destinationRow.assign(updateRow); break; case 5: destinationRow.assign(Transforms.max(destinationRow, updateRow, true)); break; case 6: destinationRow.assign(Transforms.min(destinationRow, updateRow, true)); break; default: throw new RuntimeException(); } } SDVariable loss = 
sd.sum(scatter); //.standardDeviation(scatter, true); //.sum(scatter); //TODO stdev might be better here as gradients are non-symmetrical... TestCase tc = new TestCase(sd) .expected(scatter, exp) .gradCheckSkipVariables(indices.name()); String error = OpValidation.validate(tc); if(error != null){ failed.add(name); } } assertEquals(failed.toString(), 0, failed.size()); }
Example 4
Source File: TransformOpValidation.java From deeplearning4j with Apache License 2.0 | 4 votes |
@Test public void testIsX() { List<String> failed = new ArrayList<>(); for (int i = 0; i < 4; i++) { SameDiff sd = SameDiff.create(); SDVariable in = sd.var("in", 4); SDVariable out; INDArray exp; INDArray inArr; switch (i) { case 0: inArr = Nd4j.create(new double[]{10, Double.POSITIVE_INFINITY, 0, Double.NEGATIVE_INFINITY}); exp = Nd4j.create(new boolean[]{true, false, true, false}); out = sd.math().isFinite(in); break; case 1: inArr = Nd4j.create(new double[]{10, Double.POSITIVE_INFINITY, 0, Double.NEGATIVE_INFINITY}); exp = Nd4j.create(new boolean[]{false, true, false, true}); out = sd.math().isInfinite(in); break; case 2: //TODO: IsMax supports both bool and float out: https://github.com/deeplearning4j/deeplearning4j/issues/6872 inArr = Nd4j.create(new double[]{-3, 5, 0, 2}); exp = Nd4j.create(new boolean[]{false, true, false, false}); out = sd.math().isMax(in); break; case 3: inArr = Nd4j.create(new double[]{0, Double.NaN, 10, Double.NaN}); exp = Nd4j.create(new boolean[]{false, true, false, true}); out = sd.math().isNaN(in); break; default: throw new RuntimeException(); } SDVariable other = sd.var("other", Nd4j.rand(DataType.DOUBLE, 4)); SDVariable loss = out.castTo(DataType.DOUBLE).add(other).mean(); TestCase tc = new TestCase(sd) .gradientCheck(false) //Can't gradient check - in -> boolean -> cast(double) .expected(out, exp); in.setArray(inArr); String err = OpValidation.validate(tc, true); if (err != null) { failed.add(err); } } assertEquals(failed.toString(), 0, failed.size()); }