Java Code Examples for org.nd4j.linalg.util.ArrayUtil#removeIndex()
The following examples show how to use
org.nd4j.linalg.util.ArrayUtil#removeIndex().
You can vote up the examples you like or vote down the ones you don't,
and you can follow the link above each example to its original project or source file. Related API usage is listed in the sidebar.
Example 1
Source File: Shape.java From nd4j with Apache License 2.0 | 6 votes |
/**
 * Computes the shape that remains after reducing {@code wholeShape} along {@code dimensions}.
 *
 * <p>A whole-array reduction (as decided by {@code isWholeArray}) yields an empty shape. Reducing
 * a rank-2 shape along exactly one axis keeps rank 2, with the reduced axis set to 1. Every other
 * case delegates to {@code ArrayUtil.removeIndex} to drop the reduced axes.
 *
 * @param wholeShape shape of the array being reduced
 * @param dimensions axes being reduced
 * @return shape of the reduction result
 */
public static long[] getReducedShape(long[] wholeShape, int[] dimensions) {
    if (isWholeArray(wholeShape, dimensions)) {
        return new long[] {};
    }

    boolean singleAxisOfMatrix = dimensions.length == 1 && wholeShape.length == 2;
    if (!singleAxisOfMatrix) {
        return ArrayUtil.removeIndex(wholeShape, dimensions);
    }

    long[] reduced = new long[2];
    switch (dimensions[0]) {
        case 0:
            // Reducing rows: one row remains, columns are kept.
            reduced[0] = 1;
            reduced[1] = wholeShape[1];
            break;
        case 1:
            // Reducing columns: rows are kept, one column remains.
            reduced[0] = wholeShape[0];
            reduced[1] = 1;
            break;
        default:
            // NOTE(review): any other axis value (e.g. a negative index) leaves the result as
            // {0, 0}; confirm callers always normalize dimensions before calling.
            break;
    }
    return reduced;
}
Example 2
Source File: CudaGridExecutioner.java From nd4j with Apache License 2.0 | 5 votes |
protected void buildZ(IndexAccumulation op, int... dimension) { Arrays.sort(dimension); for (int i = 0; i < dimension.length; i++) { if (dimension[i] < 0) dimension[i] += op.x().rank(); } //do op along all dimensions if (dimension.length == op.x().rank()) dimension = new int[] {Integer.MAX_VALUE}; long[] retShape = Shape.wholeArrayDimension(dimension) ? new long[] {1, 1} : ArrayUtil.removeIndex(op.x().shape(), dimension); //ensure vector is proper shape if (retShape.length == 1) { if (dimension[0] == 0) retShape = new long[] {1, retShape[0]}; else retShape = new long[] {retShape[0], 1}; } else if (retShape.length == 0) { retShape = new long[] {1, 1}; } if(op.z() == null || op.z() == op.x()){ INDArray ret = null; if (Math.abs(op.zeroDouble()) < Nd4j.EPS_THRESHOLD) { ret = Nd4j.zeros(retShape); } else { ret = Nd4j.valueArrayOf(retShape, op.zeroDouble()); } op.setZ(ret); } else if(!Arrays.equals(retShape, op.z().shape())){ throw new IllegalStateException("Z array shape does not match expected return type for op " + op + ": expected shape " + Arrays.toString(retShape) + ", z.shape()=" + Arrays.toString(op.z().shape())); } }
Example 3
Source File: ArrayUtilsTests.java From nd4j with Apache License 2.0 | 5 votes |
@Test public void testArrayRemoveIndex1() throws Exception { //INDArray arraySource = Nd4j.create(new float[]{1,2,3,4,5,6,7,8}); int[] arraySource = new int[] {1,2,3,4,5,6,7,8}; int[] dst = ArrayUtil.removeIndex(arraySource, new int[]{0,1}); assertEquals(6, dst.length); assertEquals(3, dst[0]); }
Example 4
Source File: ArrayUtilsTests.java From nd4j with Apache License 2.0 | 5 votes |
@Test public void testArrayRemoveIndex2() throws Exception { //INDArray arraySource = Nd4j.create(new float[]{1,2,3,4,5,6,7,8}); int[] arraySource = new int[] {1,2,3,4,5,6,7,8}; int[] dst = ArrayUtil.removeIndex(arraySource, new int[]{0,7}); assertEquals(6, dst.length); assertEquals(2, dst[0]); assertEquals(7, dst[5]); }
Example 5
Source File: ArrayUtilsTests.java From nd4j with Apache License 2.0 | 5 votes |
@Test public void testArrayRemoveIndex4() throws Exception { //INDArray arraySource = Nd4j.create(new float[]{1,2,3,4,5,6,7,8}); int[] arraySource = new int[] {1,2,3,4,5,6,7,8}; int[] dst = ArrayUtil.removeIndex(arraySource, new int[]{0}); assertEquals(7, dst.length); assertEquals(2, dst[0]); assertEquals(8, dst[6]); }
Example 6
Source File: ArrayUtilsTests.java From nd4j with Apache License 2.0 | 5 votes |
@Test @Ignore public void testArrayRemoveIndexX() throws Exception { //INDArray arraySource = Nd4j.create(new float[]{1,2,3,4,5,6,7,8}); int[] arraySource = new int[] {1,2,3,4,5,6,7,8}; int[] dst = ArrayUtil.removeIndex(arraySource, new int[]{11}); assertEquals(8, dst.length); assertEquals(1, dst[0]); assertEquals(8, dst[7]); }
Example 7
Source File: ArrayUtilsTests.java From nd4j with Apache License 2.0 | 5 votes |
@Test @Ignore public void testArrayRemoveIndex5() throws Exception { //INDArray arraySource = Nd4j.create(new float[]{1,2,3,4,5,6,7,8}); int[] arraySource = new int[] {1,2,3,4,5,6,7,8}; int[] dst = ArrayUtil.removeIndex(arraySource, new int[]{Integer.MAX_VALUE}); assertEquals(8, dst.length); assertEquals(1, dst[0]); assertEquals(8, dst[7]); }
Example 8
Source File: CudaGridExecutioner.java From nd4j with Apache License 2.0 | 4 votes |
protected void buildZ(Accumulation op, int... dimension) { Arrays.sort(dimension); for (int i = 0; i < dimension.length; i++) { if (dimension[i] < 0) dimension[i] += op.x().rank(); } //do op along all dimensions if (dimension.length == op.x().rank()) dimension = new int[] {Integer.MAX_VALUE}; long[] retShape = Shape.wholeArrayDimension(dimension) ? new long[] {1, 1} : ArrayUtil.removeIndex(op.x().shape(), dimension); //ensure vector is proper shape if (retShape.length == 1) { if (dimension[0] == 0) retShape = new long[] {1, retShape[0]}; else retShape = new long[] {retShape[0], 1}; } else if (retShape.length == 0) { retShape = new long[] {1, 1}; } /* if(op.x().isVector() && op.x().length() == ArrayUtil.prod(retShape)) return op.noOp(); */ INDArray ret = null; if (op.z() == null || op.z() == op.x()) { if (op.isComplexAccumulation()) { val xT = op.x().tensorssAlongDimension(dimension); val yT = op.y().tensorssAlongDimension(dimension); ret = Nd4j.create(xT, yT); } else { if (Math.abs(op.zeroDouble()) < Nd4j.EPS_THRESHOLD) { ret = Nd4j.zeros(retShape); } else { ret = Nd4j.valueArrayOf(retShape, op.zeroDouble()); } } op.setZ(ret); } else { // compare length if (op.z().lengthLong() != ArrayUtil.prodLong(retShape)) throw new ND4JIllegalStateException("Shape of target array for reduction [" + Arrays.toString(op.z().shape()) + "] doesn't match expected [" + Arrays.toString(retShape) + "]"); if (op.x().data().dataType() == DataBuffer.Type.DOUBLE) { op.z().assign(op.zeroDouble()); } else if (op.x().data().dataType() == DataBuffer.Type.FLOAT) { op.z().assign(op.zeroFloat()); } else if (op.x().data().dataType() == DataBuffer.Type.HALF) { op.z().assign(op.zeroHalf()); } ret = op.z(); } }