Java Code Examples for org.nd4j.linalg.api.shape.Shape#getReducedShape()
The following examples show how to use
org.nd4j.linalg.api.shape.Shape#getReducedShape().
You can vote up the examples you like or vote down the ones you don't like,
and you can follow the links above each example to the original project or source file. Related API usage is listed in the sidebar.
Example 1
Source File: Variance.java From deeplearning4j with Apache License 2.0 | 6 votes |
/**
 * Computes the output shape of this variance reduction.
 *
 * <p>The input array comes from {@code oc} when an op context is supplied, otherwise from
 * {@code x()}. The argument's shape is preferred when it is known and not a placeholder;
 * otherwise the input array's shape is used. If neither is available the result is empty.
 *
 * @param oc op context supplying the input array, or {@code null} to use this op's own input
 * @return a single-element list holding the reduced shape descriptor, or an empty list when no
 *     input shape can be determined yet
 * @throws ND4JIllegalStateException if {@code oc} is null and the op has no arguments
 */
@Override
public List<LongShapeDescriptor> calculateOutputShape(OpContext oc) {
    INDArray x = oc != null ? oc.getInputArray(0) : x();
    if (oc == null && args().length < 1) {
        throw new ND4JIllegalStateException("Unable to compute input shape. No arguments found.");
    }

    // With an op context the op may legitimately have no arguments (the guard above only
    // applies when oc == null), so do not dereference arg() unless one exists.
    long[] argShape = args().length > 0 ? arg().getShape() : null;

    long[] inputShape;
    if (argShape != null && !Shape.isPlaceholderShape(argShape)) {
        inputShape = argShape;
    } else if (x != null) {
        inputShape = x.shape();
    } else {
        // Unknown or placeholder argument shape and no concrete input array: nothing to
        // compute yet. (Previously a placeholder shape with a null x threw an NPE.)
        return Collections.emptyList();
    }

    List<LongShapeDescriptor> ret = new ArrayList<>(1);
    long[] reducedShape = Shape.getReducedShape(inputShape, dimensions, isKeepDims());
    ret.add(LongShapeDescriptor.fromShape(reducedShape, resultType()));
    return ret;
}
Example 2
Source File: StandardDeviation.java From deeplearning4j with Apache License 2.0 | 6 votes |
/**
 * Computes the output shape of this standard-deviation reduction.
 *
 * <p>The argument's shape is preferred when it is known and not a placeholder; otherwise the
 * shape of {@code x()} is used. If neither is available the result is empty.
 *
 * @return a single-element list holding the reduced shape descriptor, or an empty list when no
 *     input shape can be determined yet
 * @throws ND4JIllegalStateException if the op has no arguments
 */
@Override
public List<LongShapeDescriptor> calculateOutputShape() {
    if (args().length < 1) {
        throw new ND4JIllegalStateException("Unable to compute input shape. No arguments found.");
    }
    long[] argShape = arg().getShape();
    val input = x();

    long[] inputShape;
    if (argShape != null && !Shape.isPlaceholderShape(argShape)) {
        inputShape = argShape;
    } else if (input != null) {
        inputShape = input.shape();
    } else {
        // Unknown or placeholder argument shape and no input array: nothing to compute yet.
        // (Previously a placeholder shape with a null x() threw an NPE.)
        return Collections.emptyList();
    }

    List<LongShapeDescriptor> ret = new ArrayList<>(1);
    long[] reducedShape = Shape.getReducedShape(inputShape, dimensions, isKeepDims());
    ret.add(LongShapeDescriptor.fromShape(reducedShape, resultType()));
    return ret;
}
Example 3
Source File: ShapeTestC.java From deeplearning4j with Apache License 2.0 | 5 votes |
/**
 * Reducing a {5, 5, 5} shape over axes {1, 0, 1} (axis 1 listed twice) with keepDims=true
 * is expected to yield {1, 1, 5}.
 */
@Test
public void testKeepDimsShape_2_T() {
    int[] inputShape = {5, 5, 5};
    int[] reductionAxes = {1, 0, 1};
    long[] expected = {1, 1, 5};
    assertArrayEquals(expected, Shape.getReducedShape(inputShape, reductionAxes, true, true));
}
Example 4
Source File: BaseReduceBoolOp.java From deeplearning4j with Apache License 2.0 | 5 votes |
@Override public List<LongShapeDescriptor> calculateOutputShape(OpContext oc) { INDArray x = oc != null ? oc.getInputArray(0) : x(); if(x == null) return Collections.emptyList(); //Calculate reduction shape. Note that reduction on scalar - returns a scalar long[] reducedShape = x.rank() == 0 ? x.shape() : Shape.getReducedShape(x.shape(),dimensions, isKeepDims()); return Collections.singletonList(LongShapeDescriptor.fromShape(reducedShape, DataType.BOOL)); }
Example 5
Source File: BaseReduceFloatOp.java From deeplearning4j with Apache License 2.0 | 5 votes |
@Override public List<LongShapeDescriptor> calculateOutputShape(OpContext oc) { INDArray x = oc != null ? oc.getInputArray(0) : x(); if(x == null) return Collections.emptyList(); //Calculate reduction shape. Note that reduction on scalar - returns a scalar long[] reducedShape = x.rank() == 0 ? x.shape() : Shape.getReducedShape(x.shape(),dimensions, isKeepDims()); DataType retType = arg().dataType(); if(!retType.isFPType()) retType = Nd4j.defaultFloatingPointType(); return Collections.singletonList(LongShapeDescriptor.fromShape(reducedShape, retType)); }
Example 6
Source File: BaseReduceSameOp.java From deeplearning4j with Apache License 2.0 | 5 votes |
@Override public List<LongShapeDescriptor> calculateOutputShape(OpContext oc) { INDArray x = oc != null ? oc.getInputArray(0) : x(); if(x == null) return Collections.emptyList(); //Calculate reduction shape. Note that reduction on scalar - returns a scalar long[] reducedShape = x.rank() == 0 ? x.shape() : Shape.getReducedShape(x.shape(),dimensions, isKeepDims()); DataType rt = oc != null ? resultType(oc) : resultType(); return Collections.singletonList(LongShapeDescriptor.fromShape(reducedShape, rt)); }
Example 7
Source File: BaseIndexAccumulation.java From deeplearning4j with Apache License 2.0 | 5 votes |
/**
 * Computes the output shape of this index accumulation; the result type is always LONG.
 *
 * <p>The input shape is reduced over {@code dimensions}, honouring the {@code keepDims} field.
 *
 * @param oc op context supplying the input array, or {@code null} to use this op's own input
 * @return a single shape descriptor, or an empty list when no input array is available
 */
@Override
public List<LongShapeDescriptor> calculateOutputShape(OpContext oc) {
    INDArray input = (oc == null) ? x() : oc.getInputArray(0);
    if (input == null) {
        return Collections.emptyList();
    }
    long[] outShape = Shape.getReducedShape(input.shape(), dimensions, keepDims);
    LongShapeDescriptor descriptor = LongShapeDescriptor.fromShape(outShape, DataType.LONG);
    return Collections.singletonList(descriptor);
}
Example 8
Source File: ShapeTestC.java From deeplearning4j with Apache License 2.0 | 5 votes |
/** Reducing {4, 4} over axis 0 (listed twice) with keepDims=false is expected to give {4}. */
@Test
public void testKeepDimsShape_4_F() {
    int[] inputShape = {4, 4};
    int[] axes = {0, 0};
    long[] reduced = Shape.getReducedShape(inputShape, axes, false, true);
    log.info("Result: {}", reduced);
    long[] expected = {4};
    assertArrayEquals(expected, reduced);
}
Example 9
Source File: ShapeTestC.java From deeplearning4j with Apache License 2.0 | 5 votes |
/** Reducing {1, 1} over axis 0 (listed twice) with keepDims=false is expected to give {1}. */
@Test
public void testKeepDimsShape_3_F() {
    long[] reduced = Shape.getReducedShape(new int[] {1, 1}, new int[] {0, 0}, false, true);
    log.info("Result: {}", reduced);
    assertArrayEquals(new long[] {1}, reduced);
}
Example 10
Source File: ShapeTestC.java From deeplearning4j with Apache License 2.0 | 5 votes |
/**
 * Reducing {1, 1} over axes {1, 0, 1} (duplicate axis) with keepDims=true is expected to keep
 * both dimensions as size 1.
 */
@Test
public void testKeepDimsShape_3_T() {
    int[] inputShape = {1, 1};
    int[] axes = {1, 0, 1};
    long[] expected = {1, 1};
    long[] actual = Shape.getReducedShape(inputShape, axes, true, true);
    assertArrayEquals(expected, actual);
}
Example 11
Source File: ShapeTestC.java From deeplearning4j with Apache License 2.0 | 5 votes |
/**
 * Reducing {5, 5, 5} over axes {0, 0, 1} with keepDims=false is expected to leave only the last
 * dimension.
 */
@Test
public void testKeepDimsShape_2_F() {
    int[] dims = {0, 0, 1};
    long[] actual = Shape.getReducedShape(new int[] {5, 5, 5}, dims, false, true);
    assertArrayEquals(new long[] {5}, actual);
}
Example 12
Source File: ShapeTestC.java From nd4j with Apache License 2.0 | 5 votes |
/**
 * Reducing {5, 5} over axes {1, 0, 1} (duplicate axis) with keepDims=true is expected to give
 * {1, 1}.
 */
@Test
public void testKeepDimsShape_1_T() throws Exception {
    int[] inputShape = {5, 5};
    int[] axes = {1, 0, 1};
    assertArrayEquals(new long[] {1, 1}, Shape.getReducedShape(inputShape, axes, true, true));
}
Example 13
Source File: ShapeTestC.java From deeplearning4j with Apache License 2.0 | 5 votes |
/**
 * Reducing {5, 5} over axes {0, 0, 1} (every dimension) with keepDims=false is expected to
 * produce an empty (scalar) shape.
 */
@Test
public void testKeepDimsShape_1_F() {
    int[] inputShape = {5, 5};
    int[] axes = {0, 0, 1};
    long[] scalarShape = {};
    long[] actual = Shape.getReducedShape(inputShape, axes, false, true);
    assertArrayEquals(scalarShape, actual);
}
Example 14
Source File: ShapeTestC.java From deeplearning4j with Apache License 2.0 | 5 votes |
/**
 * Reducing {5, 5} over axes {1, 0, 1} (duplicate axis) with keepDims=true is expected to give
 * {1, 1}.
 */
@Test
public void testKeepDimsShape_1_T() {
    long[] actual = Shape.getReducedShape(new int[] {5, 5}, new int[] {1, 0, 1}, true, true);
    long[] expected = {1, 1};
    assertArrayEquals(expected, actual);
}
Example 15
Source File: BaseAccumulation.java From nd4j with Apache License 2.0 | 5 votes |
/**
 * Computes the output shape of this accumulation from the argument's shape.
 *
 * @return a single-element list holding the reduced shape, or an empty list when the argument
 *     shape is not yet known
 * @throws ND4JIllegalStateException if the op has no arguments
 */
@Override
public List<long[]> calculateOutputShape() {
    if (args().length < 1) {
        throw new ND4JIllegalStateException("Unable to compute input shape. No arguments found.");
    }
    val inputShape = arg().getShape();
    if (inputShape == null) {
        return Collections.emptyList();
    }
    List<long[]> shapes = new ArrayList<>(1);
    shapes.add(Shape.getReducedShape(inputShape, dimensions, isKeepDims(), newFormat));
    return shapes;
}
Example 16
Source File: ShapeTestC.java From nd4j with Apache License 2.0 | 5 votes |
/** Reducing {1, 1} over axis 0 (listed twice) with keepDims=false is expected to give {1}. */
@Test
public void testKeepDimsShape_3_F() throws Exception {
    int[] inputShape = {1, 1};
    int[] axes = {0, 0};
    long[] reduced = Shape.getReducedShape(inputShape, axes, false, true);
    log.info("Result: {}", reduced);
    long[] expected = {1};
    assertArrayEquals(expected, reduced);
}
Example 17
Source File: ShapeTestC.java From nd4j with Apache License 2.0 | 5 votes |
/**
 * Reducing {1, 1} over axes {1, 0, 1} (duplicate axis) with keepDims=true is expected to give
 * {1, 1}.
 */
@Test
public void testKeepDimsShape_3_T() throws Exception {
    long[] actual = Shape.getReducedShape(new int[] {1, 1}, new int[] {1, 0, 1}, true, true);
    assertArrayEquals(new long[] {1, 1}, actual);
}
Example 18
Source File: ShapeTestC.java From nd4j with Apache License 2.0 | 5 votes |
/**
 * Reducing {5, 5, 5} over axes {0, 0, 1} with keepDims=false is expected to leave only the last
 * dimension.
 */
@Test
public void testKeepDimsShape_2_F() throws Exception {
    int[] inputShape = {5, 5, 5};
    int[] axes = {0, 0, 1};
    long[] expected = {5};
    assertArrayEquals(expected, Shape.getReducedShape(inputShape, axes, false, true));
}
Example 19
Source File: ShapeTestC.java From nd4j with Apache License 2.0 | 5 votes |
/**
 * Reducing {5, 5, 5} over axes {1, 0, 1} (axis 1 listed twice) with keepDims=true is expected
 * to yield {1, 1, 5}.
 */
@Test
public void testKeepDimsShape_2_T() throws Exception {
    long[] actual = Shape.getReducedShape(new int[] {5, 5, 5}, new int[] {1, 0, 1}, true, true);
    long[] expected = {1, 1, 5};
    assertArrayEquals(expected, actual);
}
Example 20
Source File: ShapeTestC.java From nd4j with Apache License 2.0 | 5 votes |
/**
 * Reducing {5, 5} over axes {0, 0, 1} (every dimension) with keepDims=false is expected to
 * produce an empty (scalar) shape.
 */
@Test
public void testKeepDimsShape_1_F() throws Exception {
    int[] axes = {0, 0, 1};
    long[] actual = Shape.getReducedShape(new int[] {5, 5}, axes, false, true);
    assertArrayEquals(new long[0], actual);
}