Java Code Examples for org.nd4j.autodiff.samediff.SameDiff#stridedSlice()
The following examples show how to use
org.nd4j.autodiff.samediff.SameDiff#stridedSlice().
You can vote up the ones you like or vote down the ones you don't like,
and go to the original project or source file by following the links above each example. You may check out the related API usage on the sidebar.
Example 1
Source File: ShapeOpValidation.java From deeplearning4j with Apache License 2.0 | 6 votes |
@Test public void testStridedSlice2dBasic() { INDArray inArr = Nd4j.linspace(1, 12, 12).reshape('c', 3, 4); SameDiff sd = SameDiff.create(); SDVariable in = sd.var("in", inArr); SDVariable slice_full = sd.stridedSlice(in,new long[]{0, 0},new long[]{3, 4},new long[]{1, 1}); SDVariable subPart = sd.stridedSlice(in,new long[]{1, 2},new long[]{3, 4},new long[]{1, 1}); // SDVariable subPart2 = sd.stridedSlice(in,new long[]{0, 0},new long[]{4, 5},new long[]{2, 2}); sd.outputAll(null); assertEquals(inArr, slice_full.getArr()); assertEquals(inArr.get(interval(1, 3), interval(2, 4)), subPart.getArr()); // assertEquals(inArr.get(interval(0, 2, 4), interval(0, 2, 5)), subPart2.getArr()); }
Example 2
Source File: ShapeOpValidation.java From deeplearning4j with Apache License 2.0 | 6 votes |
@Test public void testStridedSliceEllipsisMask() { INDArray inArr = Nd4j.linspace(1, 60, 60).reshape('c', 3, 4, 5); SameDiff sd = SameDiff.create(); SDVariable in = sd.var("in", inArr); //[1:3,...] -> [1:3,:,:] SDVariable slice = sd.stridedSlice(in,new long[]{1},new long[]{3},new long[]{1}, 0, 0, 1 << 1, 0, 0); //[1:3,...,1:4] -> [1:3,:,1:4] SDVariable slice2 = sd.stridedSlice(in,new long[]{1, 1},new long[]{3, 4},new long[]{1, 1}, 0, 0, 1 << 1, 0, 0); sd.outputAll(Collections.emptyMap()); assertEquals(inArr.get(interval(1, 3), all(), all()), slice.getArr()); assertEquals(inArr.get(interval(1, 3), all(), all()), slice2.getArr()); }
Example 3
Source File: ShapeOpValidation.java From deeplearning4j with Apache License 2.0 | 6 votes |
/**
 * Verifies stridedSlice with a non-zero final mask argument (the shrink-axis mask, going by
 * the test name) on a 3x4x5 array. The expected results index the masked dimensions with a
 * single point taken from {@code begin}, so those dimensions disappear from the output.
 */
@Test
public void testStridedSliceShrinkAxisMask() {
    final INDArray source = Nd4j.linspace(1, 60, 60).reshape('c', 3, 4, 5);
    final SameDiff graph = SameDiff.create();
    final SDVariable input = graph.var("in", source);

    // Bit i of the mask selects dimension i.
    final int firstDimOnly = 1;
    final int firstTwoDims = 1 | (1 << 1);

    // Dimension 0 pinned to index 0; the -999 end value is a placeholder for the masked dimension.
    final SDVariable atRowZero = graph.stridedSlice(
            input,
            new long[]{0, 0, 0},
            new long[]{-999, 4, 5},
            new long[]{1, 1, 1},
            0, 0, 0, 0, firstDimOnly);

    // Dimension 0 pinned to index 2.
    final SDVariable atRowTwo = graph.stridedSlice(
            input,
            new long[]{2, 0, 0},
            new long[]{-999, 4, 5},
            new long[]{1, 1, 1},
            0, 0, 0, 0, firstDimOnly);

    // Dimensions 0 and 1 pinned to (1, 2); the last dimension is sliced as 1:5.
    final SDVariable vectorSlice = graph.stridedSlice(
            input,
            new long[]{1, 2, 1},
            new long[]{-999, -999, 5},
            new long[]{1, 1, 1},
            0, 0, 0, 0, firstTwoDims);

    graph.outputAll(null);

    final INDArray expectedRowZero = source.get(point(0), all(), all());
    assertEquals(expectedRowZero, atRowZero.getArr());

    final INDArray expectedRowTwo = source.get(point(2), all(), all());
    assertEquals(expectedRowTwo, atRowTwo.getArr());

    final INDArray expectedVector = source.get(point(1), point(2), interval(1, 5)).reshape(4);
    assertEquals(expectedVector, vectorSlice.getArr());
}
Example 4
Source File: GradCheckMisc.java From nd4j with Apache License 2.0 | 5 votes |
@Test public void testStridedSliceGradient() { Nd4j.getRandom().setSeed(12345); //Order here: original shape, begin, size List<SSCase> testCases = new ArrayList<>(); testCases.add(SSCase.builder().shape(3, 4).begin(0, 0).end(3, 4).strides(1, 1).build()); testCases.add(SSCase.builder().shape(3, 4).begin(1, 1).end(2, 3).strides(1, 1).build()); testCases.add(SSCase.builder().shape(3, 4).begin(-999, 0).end(3, 4).strides(1, 1).beginMask(1).build()); testCases.add(SSCase.builder().shape(3, 4).begin(1, 1).end(3, -999).strides(1, 1).endMask(1 << 1).build()); testCases.add(SSCase.builder().shape(3, 4).begin(-999, 0).end(-999, 4).strides(1, 1).beginMask(1).endMask(1).build()); testCases.add(SSCase.builder().shape(3, 4).begin(-999, 0, 0).end(-999, 3, 4).strides(1, 1).newAxisMask(1).build()); testCases.add(SSCase.builder().shape(3, 4, 5).begin(0, 0, 0).end(3, 4, 5).strides(1, 1, 1).build()); testCases.add(SSCase.builder().shape(3, 4, 5).begin(1, 2, 3).end(3, 4, 5).strides(1, 1, 1).build()); testCases.add(SSCase.builder().shape(3, 4, 5).begin(0, 0, 0).end(3, 3, 5).strides(1, 2, 2).build()); testCases.add(SSCase.builder().shape(3, 4, 5).begin(1, -999, 1).end(3, 3, 4).strides(1, 1, 1).beginMask(1 << 1).build()); testCases.add(SSCase.builder().shape(3, 4, 5).begin(1, -999, 1).end(3, 3, -999).strides(1, 1, 1).beginMask(1 << 1).endMask(1 << 2).build()); testCases.add(SSCase.builder().shape(3, 4, 5).begin(1, 2).end(3, 4).strides(1, 1).ellipsisMask(1 << 1).build()); //[1:3,...,2:4] testCases.add(SSCase.builder().shape(3, 4, 5).begin(1, -999, 1, 2).end(3, -999, 3, 4).strides(1, -999, 1, 2).newAxisMask(1 << 1).build()); testCases.add(SSCase.builder().shape(3, 4, 5).begin(1, 0, 1).end(3, -999, 4).strides(1, 1, 1).shrinkAxisMask(1 << 1).build()); testCases.add(SSCase.builder().shape(3, 4, 5).begin(1, 1, 1).end(3, -999, 4).strides(1, 1, 1).shrinkAxisMask(1 << 1).build()); for (int i = 0; i < testCases.size(); i++) { SSCase t = testCases.get(i); INDArray arr = Nd4j.rand(t.getShape()); 
SameDiff sd = SameDiff.create(); SDVariable in = sd.var("in", arr); SDVariable slice = sd.stridedSlice(in, t.getBegin(), t.getEnd(), t.getStrides(), t.getBeginMask(), t.getEndMask(), t.getEllipsisMask(), t.getNewAxisMask(), t.getShrinkAxisMask()); SDVariable stdev = sd.standardDeviation(slice, true); String msg = "i=" + i + ": " + t; log.info("Starting test: " + msg); GradCheckUtil.checkGradients(sd); } }
Example 5
Source File: ShapeOpValidation.java From deeplearning4j with Apache License 2.0 | 5 votes |
/**
 * Checks stridedSlice with a begin mask and with an end mask on a 3x4 array. The -999 values
 * stand in for bounds that the expected results treat as the start (0) or end (3) of
 * dimension 0.
 */
@Test
public void testStridedSliceBeginEndMask() {
    final INDArray source = Nd4j.linspace(1, 12, 12).reshape('c', 3, 4);
    final SameDiff graph = SameDiff.create();
    final SDVariable input = graph.var("in", source);

    // First int argument (begin mask) set.
    // NOTE(review): the mask bit here is 1 << 1 while the -999 placeholder sits at position 0;
    // the end-mask case below uses bit 0 for a placeholder at position 0 -- confirm intent.
    final int beginMask = 1 << 1;
    final SDVariable maskedBegin = graph.stridedSlice(
            input,
            new long[]{-999, 0},
            new long[]{2, 4},
            new long[]{1, 1},
            beginMask, 0, 0, 0, 0);

    // Second int argument (end mask) set for position 0.
    final int endMask = 1;
    final SDVariable maskedEnd = graph.stridedSlice(
            input,
            new long[]{1, 0},
            new long[]{-999, 4},
            new long[]{1, 1},
            0, endMask, 0, 0, 0);

    graph.outputAll(null);

    final INDArray firstTwoRows = source.get(NDArrayIndex.interval(0, 2), NDArrayIndex.all());
    assertEquals(firstTwoRows, maskedBegin.getArr());

    final INDArray lastTwoRows = source.get(NDArrayIndex.interval(1, 3), NDArrayIndex.all());
    assertEquals(lastTwoRows, maskedEnd.getArr());
}
Example 6
Source File: ShapeOpValidation.java From deeplearning4j with Apache License 2.0 | 5 votes |
/**
 * Checks stridedSlice with the fourth mask argument (the new-axis mask, going by the test
 * name) set to bit 0 on a 3x4x5 array: the result is expected to have shape [1, 3, 4, 5] and
 * to equal the input once the leading dimension is indexed away.
 */
@Test
public void testStridedSliceNewAxisMask() {
    final INDArray source = Nd4j.linspace(1, 60, 60).reshape('c', 3, 4, 5);
    final SameDiff graph = SameDiff.create();
    final SDVariable input = graph.var("in", source);

    // Position 0 carries -999 placeholders in begin/end/strides; it is the masked position.
    final long[] begin = {-999, 0, 0, 0};
    final long[] end = {-999, 3, 4, 5};
    final long[] strides = {-999, 1, 1, 1};
    final SDVariable expanded = graph.stridedSlice(input, begin, end, strides, 0, 0, 0, 1, 0);

    final INDArray result = expanded.eval();

    final long[] expectedShape = {1, 3, 4, 5};
    assertArrayEquals(expectedShape, result.shape());
    assertEquals(source, result.get(point(0), all(), all(), all()));
}
Example 7
Source File: ShapeOpValidation.java From deeplearning4j with Apache License 2.0 | 5 votes |
/**
 * Checks stridedSlice with the fourth mask argument (the new-axis mask, going by the test
 * name) set to bit 2 on a 3x4x5 array; only the output shape, expected to be [2, 2, 1, 3],
 * is asserted.
 */
@Test
public void testStridedSliceNewAxisMask2() {
    final INDArray source = Nd4j.linspace(1, 60, 60).reshape('c', 3, 4, 5);
    final SameDiff graph = SameDiff.create();
    final SDVariable input = graph.var("in", source);

    // Position 2 carries -999 placeholders in begin/end/strides; it is the masked position.
    final long[] begin = {1, 1, -999, 1};
    final long[] end = {3, 3, -999, 4};
    final long[] strides = {1, 1, -999, 1};
    final int maskBitTwo = 1 << 2;
    final SDVariable expanded = graph.stridedSlice(input, begin, end, strides, 0, 0, 0, maskBitTwo, 0);

    // The returned array is not used; the shape is read back from the variable below.
    final INDArray out = expanded.eval();

    final long[] expectedShape = {2, 2, 1, 3};
    assertArrayEquals(expectedShape, expanded.getArr().shape());
}
Example 8
Source File: ShapeOpValidation.java From deeplearning4j with Apache License 2.0 | 4 votes |
@Test public void testStridedSliceGradient() { Nd4j.getRandom().setSeed(12345); //Order here: original shape, begin, size List<SSCase> testCases = new ArrayList<>(); testCases.add(SSCase.builder().shape(3, 4).begin(0, 0).end(3, 4).strides(1, 1).build()); testCases.add(SSCase.builder().shape(3, 4).begin(1, 1).end(2, 3).strides(1, 1).build()); testCases.add(SSCase.builder().shape(3, 4).begin(-999, 0).end(3, 4).strides(1, 1).beginMask(1).build()); testCases.add(SSCase.builder().shape(3, 4).begin(1, 1).end(3, -999).strides(1, 1).endMask(1 << 1).build()); testCases.add(SSCase.builder().shape(3, 4).begin(-999, 0).end(-999, 4).strides(1, 1).beginMask(1).endMask(1).build()); testCases.add(SSCase.builder().shape(3, 4, 5).begin(0, 0, 0).end(3, 4, 5).strides(1, 1, 1).build()); testCases.add(SSCase.builder().shape(3, 4, 5).begin(1, 2, 3).end(3, 4, 5).strides(1, 1, 1).build()); testCases.add(SSCase.builder().shape(3, 4, 5).begin(0, 0, 0).end(3, 3, 5).strides(1, 2, 2).build()); testCases.add(SSCase.builder().shape(3, 4, 5).begin(1, -999, 1).end(3, 3, 4).strides(1, 1, 1).beginMask(1 << 1).build()); testCases.add(SSCase.builder().shape(3, 4, 5).begin(1, -999, 1).end(3, 3, -999).strides(1, 1, 1).beginMask(1 << 1).endMask(1 << 2).build()); testCases.add(SSCase.builder().shape(3, 4, 5).begin(1, 2).end(3, 4).strides(1, 1).ellipsisMask(1 << 1).build()); //[1:3,...,2:4] testCases.add(SSCase.builder().shape(3, 4, 5).begin(1, -999, 1, 2).end(3, -999, 3, 4).strides(1, -999, 1, 2).newAxisMask(1 << 1).build()); testCases.add(SSCase.builder().shape(3, 4, 5).begin(1, 0, 1).end(3, -999, 4).strides(1, 1, 1).shrinkAxisMask(1 << 1).build()); testCases.add(SSCase.builder().shape(3, 4, 5).begin(1, 1, 1).end(3, -999, 4).strides(1, 1, 1).shrinkAxisMask(1 << 1).build()); Map<Integer,INDArrayIndex[]> indices = new HashMap<>(); indices.put(0, new INDArrayIndex[]{all(), all()}); indices.put(1, new INDArrayIndex[]{interval(1,2), interval(1,3)}); indices.put(2, new INDArrayIndex[]{interval(0,3), 
interval(0,4)}); indices.put(3, new INDArrayIndex[]{interval(1,3), interval(1,4)}); indices.put(5, new INDArrayIndex[]{all(), all(), all()}); indices.put(7, new INDArrayIndex[]{interval(0,1,3), interval(0,2,3), interval(0,2,5)}); List<String> failed = new ArrayList<>(); for (int i = 0; i < testCases.size(); i++) { SSCase t = testCases.get(i); INDArray arr = Nd4j.rand(t.getShape()); SameDiff sd = SameDiff.create(); SDVariable in = sd.var("in", arr); SDVariable slice = sd.stridedSlice(in, t.getBegin(), t.getEnd(), t.getStrides(), t.getBeginMask(), t.getEndMask(), t.getEllipsisMask(), t.getNewAxisMask(), t.getShrinkAxisMask()); SDVariable stdev = sd.standardDeviation(slice, true); String msg = "i=" + i + ": " + t; log.info("Starting test: " + msg); TestCase tc = new TestCase(sd); tc.testName(msg); if(indices.containsKey(i)){ tc.expected(slice, arr.get(indices.get(i)).dup()); } String error = OpValidation.validate(tc, true); if(error != null){ failed.add(error); } } assertEquals(failed.toString(), 0, failed.size()); }