Java Code Examples for org.nd4j.linalg.api.shape.Shape#getOrder()
The following examples show how to use
org.nd4j.linalg.api.shape.Shape#getOrder().
You can vote each example up or down,
and follow the links above each example to its original project or source file. Related API usage is listed in the sidebar.
Example 1
Source File: BaseNDArray.java | Project: nd4j | License: Apache License 2.0 | 5 votes
@Override public INDArray subArray(ShapeOffsetResolution resolution) { Nd4j.getCompressor().autoDecompress(this); long[] offsets = resolution.getOffsets(); int[] shape = LongUtils.toInts(resolution.getShapes()); int[] stride = LongUtils.toInts(resolution.getStrides()); // if (offset() + resolution.getOffset() >= Integer.MAX_VALUE) // throw new IllegalArgumentException("Offset of array can not be >= Integer.MAX_VALUE"); long offset = (offset() + resolution.getOffset()); int n = shape.length; // FIXME: shapeInfo should be used here if (shape.length < 1) return create(Nd4j.createBufferDetached(shape)); if (offsets.length != n) throw new IllegalArgumentException("Invalid offset " + Arrays.toString(offsets)); if (stride.length != n) throw new IllegalArgumentException("Invalid stride " + Arrays.toString(stride)); if (shape.length == rank() && Shape.contentEquals(shape, shapeOf())) { if (ArrayUtil.isZero(offsets)) { return this; } else { throw new IllegalArgumentException("Invalid subArray offsets"); } } char newOrder = Shape.getOrder(shape, stride, 1); return create(data, Arrays.copyOf(shape, shape.length), stride, offset, newOrder); }
Example 2
Source File: ShapeTest.java | Project: deeplearning4j | License: Apache License 2.0 | 5 votes
@Test public void testShapeOrder(){ long[] shape = {2,2}; long[] stride = {1,8}; //Ascending strides -> F order char order = Shape.getOrder(shape, stride, 1); assertEquals('f', order); }
Example 3
Source File: TADTests.java | Project: deeplearning4j | License: Apache License 2.0 | 5 votes
@Test public void testTADEWSStride(){ INDArray orig = Nd4j.linspace(1, 600, 600).reshape('f', 10, 1, 60); for( int i=0; i<60; i++ ){ INDArray tad = orig.tensorAlongDimension(i, 0, 1); //TAD: should be equivalent to get(all, all, point(i)) INDArray get = orig.get(all(), all(), point(i)); String str = String.valueOf(i); assertEquals(str, get, tad); assertEquals(str, get.data().offset(), tad.data().offset()); assertEquals(str, get.elementWiseStride(), tad.elementWiseStride()); char orderTad = Shape.getOrder(tad.shape(), tad.stride(), 1); char orderGet = Shape.getOrder(get.shape(), get.stride(), 1); assertEquals('f', orderTad); assertEquals('f', orderGet); long ewsTad = Shape.elementWiseStride(tad.shape(), tad.stride(), tad.ordering() == 'f'); long ewsGet = Shape.elementWiseStride(get.shape(), get.stride(), get.ordering() == 'f'); assertEquals(1, ewsTad); assertEquals(1, ewsGet); } }
Example 4
Source File: BaseNDArray.java | Project: nd4j | License: Apache License 2.0 | 4 votes
@Override public INDArray tensorAlongDimension(int index, int... dimension) { if (dimension == null || dimension.length == 0) throw new IllegalArgumentException("Invalid input: dimensions not specified (null or length 0)"); if (dimension.length >= rank() || dimension.length == 1 && dimension[0] == Integer.MAX_VALUE) return this; for (int i = 0; i < dimension.length; i++) if (dimension[i] < 0) dimension[i] += rank(); //dedup if (dimension.length > 1) dimension = Ints.toArray(new ArrayList<>(new TreeSet<>(Ints.asList(dimension)))); if (dimension.length > 1) { Arrays.sort(dimension); } long tads = tensorssAlongDimension(dimension); if (index >= tads) throw new IllegalArgumentException("Illegal index " + index + " out of tads " + tads); if (dimension.length == 1) { if (dimension[0] == 0 && isColumnVector()) { return this.transpose(); } else if (dimension[0] == 1 && isRowVector()) { return this; } } Pair<DataBuffer, DataBuffer> tadInfo = Nd4j.getExecutioner().getTADManager().getTADOnlyShapeInfo(this, dimension); DataBuffer shapeInfo = tadInfo.getFirst(); val shape = Shape.shape(shapeInfo); val stride = Shape.stride(shapeInfo).asLong(); long offset = offset() + tadInfo.getSecond().getLong(index); INDArray toTad = Nd4j.create(data(), shape, stride, offset); BaseNDArray baseNDArray = (BaseNDArray) toTad; //preserve immutability char newOrder = Shape.getOrder(shape, stride, 1); int ews = baseNDArray.shapeInfoDataBuffer().getInt(baseNDArray.shapeInfoDataBuffer().length() - 2); //TAD always calls permute. Permute EWS is always -1. This is not true // for row vector shapes though. if (!Shape.isRowVectorShape(baseNDArray.shapeInfoDataBuffer())) ews = -1; // we create new shapeInfo with possibly new ews & order /** * NOTE HERE THAT ZERO IS PRESET FOR THE OFFSET AND SHOULD STAY LIKE THAT. * Zero is preset for caching purposes. * We don't actually use the offset defined in the * shape info data buffer. * We calculate and cache the offsets separately. 
* */ baseNDArray.setShapeInformation( Nd4j.getShapeInfoProvider().createShapeInformation(shape, stride, 0, ews, newOrder)); return toTad; }
Example 5
Source File: BaseNDArray.java | Project: nd4j | License: Apache License 2.0 | 4 votes
/** * An <b>in-place</b> version of permute. The array shape information (shape, strides) * is modified by this operation (but not the data itself) * See: http://www.mathworks.com/help/matlab/ref/permute.html * * @param rearrange the dimensions to swap to * @return the current array */ @Override public INDArray permutei(int... rearrange) { boolean alreadyInOrder = true; val shapeInfo = shapeInfo(); int rank = Shape.rank(javaShapeInformation); for (int i = 0; i < rank; i++) { if (rearrange[i] != i) { alreadyInOrder = false; break; } } if (alreadyInOrder) return this; checkArrangeArray(rearrange); val newShape = doPermuteSwap(Shape.shapeOf(shapeInfo), rearrange); val newStride = doPermuteSwap(Shape.stride(shapeInfo), rearrange); char newOrder = Shape.getOrder(newShape, newStride, elementStride()); //Set the shape information of this array: shape, stride, order. //Shape info buffer: [rank, [shape], [stride], offset, elementwiseStride, order] /*for( int i=0; i<rank; i++ ){ shapeInfo.put(1+i,newShape[i]); shapeInfo.put(1+i+rank,newStride[i]); } shapeInfo.put(3+2*rank,newOrder); */ val ews = shapeInfo.get(2 * rank + 2); /* if (ews < 1 && !attemptedToFindElementWiseStride) throw new RuntimeException("EWS is -1"); */ val si = Nd4j.getShapeInfoProvider().createShapeInformation(newShape, newStride, 0, ews, newOrder); setShapeInformation(si); if (shapeInfo.get(2 * rank + 2) > 0) { //for the backend to work - no ews for permutei //^^ not true anymore? Not sure here. Marking this for raver setShapeInformation(Nd4j.getShapeInfoProvider().createShapeInformation(newShape, newStride, this.offset(), -1, newOrder)); } //this.shape = null; //this.stride = null; return this; }