Java Code Examples for org.nd4j.linalg.api.shape.Shape#isColumnVectorShape()
The following examples show how to use org.nd4j.linalg.api.shape.Shape#isColumnVectorShape().
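If you don't have the ND4J sources at hand, the predicate can be approximated on a plain `long[]` shape. The sketch below is a standalone reimplementation under stated assumptions, not the library code. It assumes `isColumnVectorShape` accepts exactly the rank-2 shapes of the form `[n, 1]`, and that the companion `isRowVectorShape` accepts `[1, n]` as well as any rank-1 shape. The class name `ShapePredicateSketch` is hypothetical.

```java
import java.util.Arrays;

// Hypothetical standalone approximation of two ND4J Shape predicates.
// Assumption: the rules below mirror Shape.isColumnVectorShape and
// Shape.isRowVectorShape; check the ND4J sources for your version.
final class ShapePredicateSketch {

    // Assumed rule: a column vector is rank 2 with exactly one column, i.e. [n, 1].
    static boolean isColumnVectorShape(long[] shape) {
        return shape.length == 2 && shape[1] == 1;
    }

    // Assumed rule: a row vector is rank 2 with one row ([1, n]), or any rank-1 shape.
    static boolean isRowVectorShape(long[] shape) {
        return (shape.length == 2 && shape[0] == 1) || shape.length == 1;
    }

    public static void main(String[] args) {
        long[][] shapes = { {3, 1}, {1, 3}, {3}, {1, 1}, {2, 3}, {3, 1, 1} };
        for (long[] s : shapes) {
            // Print each shape with the result of both predicates.
            System.out.println(Arrays.toString(s)
                    + " column=" + isColumnVectorShape(s)
                    + " row=" + isRowVectorShape(s));
        }
    }
}
```

Under these assumed rules a `[1, 1]` shape satisfies both predicates, and a rank-3 shape such as `[3, 1, 1]` satisfies neither.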
Example 1
Source File: BaseNDArray.java From deeplearning4j with Apache License 2.0
protected void assertSlice(INDArray put, long slice) {
    Preconditions.checkArgument(slice < slices(), "Invalid slice specified: slice %s must be in range 0 (inclusive) to numSlices=%s (exclusive)", slice, slices());
    long[] sliceShape = put.shape();
    if (Shape.isRowVectorShape(sliceShape)) {
        return;
    } else {
        long[] requiredShape = ArrayUtil.removeIndex(shape(), 0);

        //no need to compare for scalar; primarily due to shapes either being [1] or length 0
        if (put.isScalar())
            return;

        if (isVector() && put.isVector() && put.length() < length())
            return;

        //edge case for column vectors
        if (Shape.isColumnVectorShape(sliceShape))
            return;

        if (!Shape.shapeEquals(sliceShape, requiredShape) && !Shape.isRowVectorShape(requiredShape)
                && !Shape.isRowVectorShape(sliceShape))
            throw new IllegalStateException(String.format("Invalid shape size of %s . Should have been %s ",
                    Arrays.toString(sliceShape), Arrays.toString(requiredShape)));
    }
}
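The order of the checks in `assertSlice` matters: a column-vector slice returns early, before the strict shape comparison runs. The following standalone sketch models only the shape-based checks (row vector, column vector, then comparison with the parent shape minus its first dimension) on plain arrays. It is a simplified illustration, not the ND4J code: it omits the scalar and vector-length checks, uses `Arrays.equals` in place of `Shape.shapeEquals`, and assumes a column vector is `[n, 1]` and a row vector is `[1, n]` or rank 1. The names `SliceCheckSketch` and `isAcceptedSlice` are hypothetical.

```java
import java.util.Arrays;

// Hypothetical, simplified model of the shape-only checks in assertSlice.
// Not the ND4J implementation: scalar and vector-length checks are omitted,
// and Arrays.equals stands in for Shape.shapeEquals (an assumption).
final class SliceCheckSketch {

    // Assumed rule: column vector = rank 2 with shape [n, 1].
    static boolean isColumnVectorShape(long[] shape) {
        return shape.length == 2 && shape[1] == 1;
    }

    // Assumed rule: row vector = [1, n] or any rank-1 shape.
    static boolean isRowVectorShape(long[] shape) {
        return (shape.length == 2 && shape[0] == 1) || shape.length == 1;
    }

    // Returns true when assertSlice (as modeled here) would NOT throw.
    // parentShape is expected to have rank >= 1.
    static boolean isAcceptedSlice(long[] sliceShape, long[] parentShape) {
        if (isRowVectorShape(sliceShape)) {
            return true;
        }
        // Edge case for column vectors: accepted before any shape comparison.
        if (isColumnVectorShape(sliceShape)) {
            return true;
        }
        // Required shape = parent shape with dimension 0 removed.
        long[] required = Arrays.copyOfRange(parentShape, 1, parentShape.length);
        return Arrays.equals(sliceShape, required) || isRowVectorShape(required);
    }

    public static void main(String[] args) {
        long[] parent = {2, 4, 3};
        System.out.println(isAcceptedSlice(new long[]{4, 1}, parent)); // column vector, early return
        System.out.println(isAcceptedSlice(new long[]{4, 3}, parent)); // matches required [4, 3]
        System.out.println(isAcceptedSlice(new long[]{3, 4}, parent)); // mismatch, would throw
    }
}
```

Note that in this model a `[4, 1]` slice is accepted for a `[2, 4, 3]` parent even though the required slice shape is `[4, 3]`. That is the consequence of the early return for column vectors.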
Example 2
Source File: BaseNDArray.java From nd4j with Apache License 2.0
/**
 * Number of columns (shape[1]), throws an exception when
 * called when not 2d
 *
 * @return the number of columns in the array (only 2d)
 */
@Override
public int columns() {
    // FIXME: int cast
    if (isMatrix())
        return (int) size(1);
    else if (Shape.isColumnVectorShape(shape())) {
        return 1;
    } else if (Shape.isRowVectorShape(shape())) {
        return (int) length();
    }
    throw new IllegalStateException("Rank is [" + rank() + "]; columns() call is not valid");
}
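The branch order in `columns()` can be tried out without ND4J by applying it to a raw shape array. The sketch below is a hypothetical standalone model (`ColumnsSketch` is not a library class). It assumes a column vector is `[n, 1]`, a row vector is `[1, n]` or rank 1, and that `isMatrix()` means rank 2 with no dimension equal to 1. That last assumption is what lets the vector branches be reached here; ND4J's own `isMatrix()` may be defined differently in your version.

```java
// Hypothetical standalone model of columns() operating on a raw shape array.
final class ColumnsSketch {

    // Assumed rule: column vector = [n, 1].
    static boolean isColumnVectorShape(long[] shape) {
        return shape.length == 2 && shape[1] == 1;
    }

    // Assumed rule: row vector = [1, n] or any rank-1 shape.
    static boolean isRowVectorShape(long[] shape) {
        return (shape.length == 2 && shape[0] == 1) || shape.length == 1;
    }

    // Assumption: "matrix" = rank 2 with neither dimension equal to 1.
    static boolean isMatrix(long[] shape) {
        return shape.length == 2 && shape[0] != 1 && shape[1] != 1;
    }

    // Total number of elements (product of all dimensions).
    static long length(long[] shape) {
        long n = 1;
        for (long d : shape) {
            n *= d;
        }
        return n;
    }

    // Same branch order as columns(): matrix, then column vector, then row vector.
    static int columns(long[] shape) {
        if (isMatrix(shape)) {
            return (int) shape[1];
        } else if (isColumnVectorShape(shape)) {
            return 1;
        } else if (isRowVectorShape(shape)) {
            return (int) length(shape);
        }
        throw new IllegalStateException("Rank is [" + shape.length + "]; columns() call is not valid");
    }

    public static void main(String[] args) {
        System.out.println(columns(new long[]{3, 1})); // column vector branch
        System.out.println(columns(new long[]{1, 3})); // row vector branch
        System.out.println(columns(new long[]{2, 5})); // matrix branch
    }
}
```

A rank-3 shape such as `[2, 3, 4]` matches none of the branches and throws `IllegalStateException`, as in the ND4J method.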
Example 3
Source File: BaseNDArray.java From nd4j with Apache License 2.0
/**
 * Returns the number of rows
 * in the array (only 2d) throws an exception when
 * called when not 2d
 *
 * @return the number of rows in the matrix
 */
@Override
public int rows() {
    // FIXME:
    if (isMatrix())
        return (int) size(0);
    else if (Shape.isRowVectorShape(shape())) {
        return 1;
    } else if (Shape.isColumnVectorShape(shape())) {
        return (int) length();
    }
    throw new IllegalStateException("Rank is " + rank() + " rows() call is not valid");
}
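`rows()` is the mirror image of `columns()`: the row-vector test comes first and a column vector reports its full length as the row count. The hypothetical standalone model below (`RowsSketch` is not a library class) applies that branch order to a raw shape array. It assumes a column vector is `[n, 1]`, a row vector is `[1, n]` or rank 1, and that `isMatrix()` means rank 2 with no dimension equal to 1. Verify these rules against your ND4J version.

```java
// Hypothetical standalone model of rows() operating on a raw shape array.
final class RowsSketch {

    // Assumed rule: column vector = [n, 1].
    static boolean isColumnVectorShape(long[] shape) {
        return shape.length == 2 && shape[1] == 1;
    }

    // Assumed rule: row vector = [1, n] or any rank-1 shape.
    static boolean isRowVectorShape(long[] shape) {
        return (shape.length == 2 && shape[0] == 1) || shape.length == 1;
    }

    // Assumption: "matrix" = rank 2 with neither dimension equal to 1.
    static boolean isMatrix(long[] shape) {
        return shape.length == 2 && shape[0] != 1 && shape[1] != 1;
    }

    // Total number of elements (product of all dimensions).
    static long length(long[] shape) {
        long n = 1;
        for (long d : shape) {
            n *= d;
        }
        return n;
    }

    // Same branch order as rows(): matrix, then row vector, then column vector.
    static int rows(long[] shape) {
        if (isMatrix(shape)) {
            return (int) shape[0];
        } else if (isRowVectorShape(shape)) {
            return 1;
        } else if (isColumnVectorShape(shape)) {
            return (int) length(shape);
        }
        throw new IllegalStateException("Rank is " + shape.length + " rows() call is not valid");
    }

    public static void main(String[] args) {
        System.out.println(rows(new long[]{3, 1})); // column vector branch
        System.out.println(rows(new long[]{1, 3})); // row vector branch
        System.out.println(rows(new long[]{2, 5})); // matrix branch
    }
}
```

Because the row-vector test runs first, a `[1, 1]` shape, which satisfies both vector predicates under these assumed rules, reports one row.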
Example 4
Source File: BaseNDArray.java From deeplearning4j with Apache License 2.0
@Override
public int columns() {
    if (isMatrix())
        return (int) size(1);
    else if (Shape.isColumnVectorShape(shape())) {
        return 1;
    } else if (Shape.isRowVectorShape(shape())) {
        return (int) length();
    }
    throw new IllegalStateException("Rank is [" + rank() + "]; columns() call is not valid");
}
Example 5
Source File: BaseNDArray.java From deeplearning4j with Apache License 2.0
@Override
public int rows() {
    if (isMatrix())
        return (int) size(0);
    else if (Shape.isRowVectorShape(shape())) {
        return 1;
    } else if (Shape.isColumnVectorShape(shape())) {
        return (int) length();
    }
    throw new IllegalStateException("Rank is " + rank() + " rows() call is not valid");
}