Java Code Examples for org.nd4j.linalg.util.ArrayUtil#removeIndex()

The following examples show how to use org.nd4j.linalg.util.ArrayUtil#removeIndex(). You can vote up the examples you like or vote down the ones you don't, and you can go to the original project or source file by following the link above each example. Related API usage is listed in the sidebar.
Example 1
Source File: Shape.java — from nd4j, Apache License 2.0 (6 votes)
/**
 * Computes the shape that results from reducing {@code wholeShape} along {@code dimensions}.
 *
 * <p>Negative dimensions are interpreted relative to the rank of {@code wholeShape}
 * (for example {@code -1} is the last axis). This matches how the executioners normalize
 * reduction dimensions. The caller's {@code dimensions} array is never modified.
 *
 * @param wholeShape the shape of the array being reduced
 * @param dimensions the dimensions to reduce along (may contain negative axes)
 * @return an empty shape when {@code isWholeArray} reports a whole-array reduction. For a
 *     rank-2 shape reduced along a single dimension, a rank-2 shape with 1 in the reduced
 *     position. Otherwise, {@code wholeShape} with the reduced dimensions removed via
 *     {@code ArrayUtil.removeIndex}.
 * @throws IllegalArgumentException if a single reduction dimension on a rank-2 shape is neither
 *     0 nor 1 after normalization
 */
public static long[] getReducedShape(long[] wholeShape, int[] dimensions) {
    // Normalize negative axes on a copy so the caller's array is left untouched.
    // A null array is passed through unchanged, exactly as before.
    int[] dims = dimensions == null ? null : dimensions.clone();
    if (dims != null) {
        for (int i = 0; i < dims.length; i++) {
            if (dims[i] < 0)
                dims[i] += wholeShape.length;
        }
    }

    if (isWholeArray(wholeShape, dims))
        return new long[] {};

    if (dims.length == 1 && wholeShape.length == 2) {
        // Keep the result at rank 2: the reduced axis collapses to size 1.
        if (dims[0] == 0)
            return new long[] {1, wholeShape[1]};
        if (dims[0] == 1)
            return new long[] {wholeShape[0], 1};
        // Previously this silently produced a {0, 0} shape.
        throw new IllegalArgumentException("Dimension " + dimensions[0]
                + " is out of range for shape " + java.util.Arrays.toString(wholeShape));
    }

    return ArrayUtil.removeIndex(wholeShape, dims);
}
 
Example 2
Source File: CudaGridExecutioner.java — from nd4j, Apache License 2.0 (5 votes)
/**
 * Ensures the index accumulation {@code op} has a result array ("z") shaped for reducing
 * {@code op.x()} along {@code dimension}.
 *
 * <p>Negative entries of {@code dimension} are first normalized against {@code op.x().rank()}.
 * The array is then sorted. Both steps modify the caller's {@code dimension} array in place.
 *
 * <p>If {@code op.z()} is missing or is the same array as {@code op.x()}, a new result array
 * is created and set on the op. It is filled with zeros when {@code op.zeroDouble()} is below
 * {@code Nd4j.EPS_THRESHOLD} in magnitude, and with {@code op.zeroDouble()} otherwise.
 *
 * @param op        the index accumulation whose result array is prepared
 * @param dimension the dimensions to reduce along (normalized and sorted in place)
 * @throws IllegalStateException if a separate, existing z does not have the expected shape
 */
protected void buildZ(IndexAccumulation op, int... dimension) {
    // Normalize negative axes BEFORE sorting. Sorting first would leave e.g. {-1, 0} as
    // {rank - 1, 0}, so dimension[0] would not be the smallest axis when it is inspected below.
    for (int i = 0; i < dimension.length; i++) {
        if (dimension[i] < 0)
            dimension[i] += op.x().rank();
    }
    Arrays.sort(dimension);

    // Reducing along every dimension: switch to the Integer.MAX_VALUE whole-array marker.
    if (dimension.length == op.x().rank())
        dimension = new int[] {Integer.MAX_VALUE};

    long[] retShape = Shape.wholeArrayDimension(dimension) ? new long[] {1, 1}
            : ArrayUtil.removeIndex(op.x().shape(), dimension);

    // Keep the result at rank 2. A rank-1 result becomes a row vector when axis 0 was reduced,
    // otherwise a column vector. An empty result becomes {1, 1}.
    if (retShape.length == 1) {
        retShape = dimension[0] == 0
                ? new long[] {1, retShape[0]}
                : new long[] {retShape[0], 1};
    } else if (retShape.length == 0) {
        retShape = new long[] {1, 1};
    }

    if (op.z() == null || op.z() == op.x()) {
        INDArray ret = Math.abs(op.zeroDouble()) < Nd4j.EPS_THRESHOLD
                ? Nd4j.zeros(retShape)
                : Nd4j.valueArrayOf(retShape, op.zeroDouble());
        op.setZ(ret);
    } else if (!Arrays.equals(retShape, op.z().shape())) {
        throw new IllegalStateException("Z array shape does not match expected return type for op " + op
                + ": expected shape " + Arrays.toString(retShape) + ", z.shape()=" + Arrays.toString(op.z().shape()));
    }
}
 
Example 3
Source File: ArrayUtilsTests.java — from nd4j, Apache License 2.0 (5 votes)
@Test
public void testArrayRemoveIndex1() throws Exception {
    int[] arraySource = new int[] {1,2,3,4,5,6,7,8};

    // Removing positions 0 and 1 should drop the values 1 and 2 and keep the rest in order.
    int[] dst = ArrayUtil.removeIndex(arraySource, new int[]{0,1});

    // Check every element, not just the first, so dropped or reordered values are caught.
    int[] expected = {3, 4, 5, 6, 7, 8};
    assertEquals(expected.length, dst.length);
    for (int i = 0; i < expected.length; i++)
        assertEquals(expected[i], dst[i]);
}
 
Example 4
Source File: ArrayUtilsTests.java — from nd4j, Apache License 2.0 (5 votes)
@Test
public void testArrayRemoveIndex2() throws Exception {
    int[] arraySource = new int[] {1,2,3,4,5,6,7,8};

    // Removing the first and last positions should leave the interior values in order.
    int[] dst = ArrayUtil.removeIndex(arraySource, new int[]{0,7});

    // Check every element, not just the endpoints, so interior mistakes are caught too.
    int[] expected = {2, 3, 4, 5, 6, 7};
    assertEquals(expected.length, dst.length);
    for (int i = 0; i < expected.length; i++)
        assertEquals(expected[i], dst[i]);
}
 
Example 5
Source File: ArrayUtilsTests.java — from nd4j, Apache License 2.0 (5 votes)
@Test
public void testArrayRemoveIndex4() throws Exception {
    int[] arraySource = new int[] {1,2,3,4,5,6,7,8};

    // Removing only position 0 should shift every remaining value down by one slot.
    int[] dst = ArrayUtil.removeIndex(arraySource, new int[]{0});

    // Check every element, not just the endpoints, so interior mistakes are caught too.
    int[] expected = {2, 3, 4, 5, 6, 7, 8};
    assertEquals(expected.length, dst.length);
    for (int i = 0; i < expected.length; i++)
        assertEquals(expected[i], dst[i]);
}
 
Example 6
Source File: ArrayUtilsTests.java — from nd4j, Apache License 2.0 (5 votes)
@Test
@Ignore
public void testArrayRemoveIndexX() throws Exception {
    // Index 11 is outside the 8-element source. The test expects the array to come back
    // unchanged (first and last values intact). The test is currently marked @Ignore.
    int[] source = {1, 2, 3, 4, 5, 6, 7, 8};
    int[] outOfRange = {11};

    int[] result = ArrayUtil.removeIndex(source, outOfRange);

    assertEquals(8, result.length);
    assertEquals(1, result[0]);
    assertEquals(8, result[7]);
}
 
Example 7
Source File: ArrayUtilsTests.java — from nd4j, Apache License 2.0 (5 votes)
@Test
@Ignore
public void testArrayRemoveIndex5() throws Exception {
    // Integer.MAX_VALUE as the index: the test expects nothing to be removed.
    // The test is currently marked @Ignore.
    int[] source = {1, 2, 3, 4, 5, 6, 7, 8};
    int[] indices = {Integer.MAX_VALUE};

    int[] result = ArrayUtil.removeIndex(source, indices);

    assertEquals(8, result.length);
    assertEquals(1, result[0]);
    assertEquals(8, result[7]);
}
 
Example 8
Source File: CudaGridExecutioner.java — from nd4j, Apache License 2.0 (4 votes)
/**
 * Ensures the accumulation {@code op} has a result array ("z") sized for reducing
 * {@code op.x()} along {@code dimension}.
 *
 * <p>Negative entries of {@code dimension} are first normalized against {@code op.x().rank()}.
 * The array is then sorted. Both steps modify the caller's {@code dimension} array in place.
 *
 * <p>If {@code op.z()} is missing or is the same array as {@code op.x()}, a new result array
 * is created and set on the op:
 * <ul>
 *   <li>for complex accumulations, it is built from the tensors of {@code x} and {@code y}
 *       along {@code dimension};</li>
 *   <li>otherwise, it is filled with zeros or with {@code op.zeroDouble()}.</li>
 * </ul>
 * If a separate z already exists, its length must match the expected result. Its contents are
 * then reset to the op's zero value when x's data type is DOUBLE, FLOAT or HALF.
 *
 * @param op        the accumulation whose result array is prepared
 * @param dimension the dimensions to reduce along (normalized and sorted in place)
 * @throws ND4JIllegalStateException if an existing z has a different length than expected
 */
protected void buildZ(Accumulation op, int... dimension) {
    // Normalize negative axes BEFORE sorting. Sorting first would leave e.g. {-1, 0} as
    // {rank - 1, 0}, so dimension[0] would not be the smallest axis when it is inspected below.
    for (int i = 0; i < dimension.length; i++) {
        if (dimension[i] < 0)
            dimension[i] += op.x().rank();
    }
    Arrays.sort(dimension);

    // Reducing along every dimension: switch to the Integer.MAX_VALUE whole-array marker.
    if (dimension.length == op.x().rank())
        dimension = new int[] {Integer.MAX_VALUE};

    long[] retShape = Shape.wholeArrayDimension(dimension) ? new long[] {1, 1}
            : ArrayUtil.removeIndex(op.x().shape(), dimension);

    // Keep the result at rank 2. A rank-1 result becomes a row vector when axis 0 was reduced,
    // otherwise a column vector. An empty result becomes {1, 1}.
    if (retShape.length == 1) {
        retShape = dimension[0] == 0
                ? new long[] {1, retShape[0]}
                : new long[] {retShape[0], 1};
    } else if (retShape.length == 0) {
        retShape = new long[] {1, 1};
    }

    if (op.z() == null || op.z() == op.x()) {
        INDArray ret;
        if (op.isComplexAccumulation()) {
            val xT = op.x().tensorssAlongDimension(dimension);
            val yT = op.y().tensorssAlongDimension(dimension);

            ret = Nd4j.create(xT, yT);
        } else if (Math.abs(op.zeroDouble()) < Nd4j.EPS_THRESHOLD) {
            ret = Nd4j.zeros(retShape);
        } else {
            ret = Nd4j.valueArrayOf(retShape, op.zeroDouble());
        }

        op.setZ(ret);
    } else {
        // An existing z only needs a matching element count, not an identical shape.
        if (op.z().lengthLong() != ArrayUtil.prodLong(retShape))
            throw new ND4JIllegalStateException("Shape of target array for reduction [" + Arrays.toString(op.z().shape()) + "] doesn't match expected [" + Arrays.toString(retShape) + "]");

        // Reset z to the op's zero value, choosing the variant by x's data type.
        // NOTE(review): data types other than DOUBLE/FLOAT/HALF leave z untouched — confirm intended.
        if (op.x().data().dataType() == DataBuffer.Type.DOUBLE) {
            op.z().assign(op.zeroDouble());
        } else if (op.x().data().dataType() == DataBuffer.Type.FLOAT) {
            op.z().assign(op.zeroFloat());
        } else if (op.x().data().dataType() == DataBuffer.Type.HALF) {
            op.z().assign(op.zeroHalf());
        }
    }
}