Java Code Examples for org.nd4j.linalg.api.ndarray.INDArray#tensorssAlongDimension()
The following examples show how to use
org.nd4j.linalg.api.ndarray.INDArray#tensorssAlongDimension() (the double "s" is part of the actual ND4J method name).
You can vote up the examples you like or vote down the ones you don't like,
and go to the original project or source file by following the links above each example. Related API usage is listed in the sidebar.
Example 1
Source File: ShapeTestsC.java From nd4j with Apache License 2.0 | 6 votes |
/**
 * Verifies the 2x2 tensors taken along dimensions 0 and 1 of a 2x2x2x2 array filled with 1..16.
 * Exactly four such tensors are expected; each one is compared against its expected matrix.
 */
@Test
public void testSixteenZeroOne() {
    INDArray source = Nd4j.linspace(1, 16, 16).reshape(2, 2, 2, 2);
    assertEquals(4, source.tensorssAlongDimension(0, 1));

    INDArray[] expected = {
            Nd4j.create(new double[][] {{1, 5}, {9, 13}}),
            Nd4j.create(new double[][] {{2, 6}, {10, 14}}),
            Nd4j.create(new double[][] {{3, 7}, {11, 15}}),
            Nd4j.create(new double[][] {{4, 8}, {12, 16}})
    };

    for (int idx = 0; idx < source.tensorssAlongDimension(0, 1); idx++) {
        INDArray actual = source.tensorAlongDimension(idx, 0, 1);
        assertEquals("Wrong at index " + idx, expected[idx], actual);
    }
}
Example 2
Source File: ShapeTestsC.java From nd4j with Apache License 2.0 | 6 votes |
/**
 * Verifies every 1-d tensor taken along dimension 2 of a 2x2x2x2 array filled with 1..16.
 * The number of tensors is asserted before iterating, so the test fails if the implementation
 * produces fewer tensors than expected. Without that check it would silently verify only a prefix
 * of {@code assertions}.
 */
@Test
public void testSixteenSecondDim() {
    INDArray baseArr = Nd4j.linspace(1, 16, 16).reshape(2, 2, 2, 2);
    INDArray[] assertions = new INDArray[] {
            Nd4j.create(new double[] {1, 3}),
            Nd4j.create(new double[] {2, 4}),
            Nd4j.create(new double[] {5, 7}),
            Nd4j.create(new double[] {6, 8}),
            Nd4j.create(new double[] {9, 11}),
            Nd4j.create(new double[] {10, 12}),
            Nd4j.create(new double[] {13, 15}),
            Nd4j.create(new double[] {14, 16})
    };

    // Hoisted: the TAD count is loop-invariant and must match the number of expectations
    long numTads = baseArr.tensorssAlongDimension(2);
    assertEquals("Unexpected number of tensors along dimension 2", assertions.length, numTads);

    for (int i = 0; i < numTads; i++) {
        INDArray arr = baseArr.tensorAlongDimension(i, 2);
        assertEquals("Failed at index " + i, assertions[i], arr);
    }
}
Example 3
Source File: ShapeTestsC.java From nd4j with Apache License 2.0 | 6 votes |
/**
 * Verifies every 1-d tensor taken along dimension 1 of a 2x2x2x2 array filled with 1..16.
 * The number of tensors is asserted before iterating, so the test fails if the implementation
 * produces fewer tensors than expected. Without that check it would silently verify only a prefix
 * of {@code assertions}.
 */
@Test
public void testSixteenFirstDim() {
    INDArray baseArr = Nd4j.linspace(1, 16, 16).reshape(2, 2, 2, 2);
    INDArray[] assertions = new INDArray[] {
            Nd4j.create(new double[] {1, 5}),
            Nd4j.create(new double[] {2, 6}),
            Nd4j.create(new double[] {3, 7}),
            Nd4j.create(new double[] {4, 8}),
            Nd4j.create(new double[] {9, 13}),
            Nd4j.create(new double[] {10, 14}),
            Nd4j.create(new double[] {11, 15}),
            Nd4j.create(new double[] {12, 16})
    };

    // Hoisted: the TAD count is loop-invariant and must match the number of expectations
    long numTads = baseArr.tensorssAlongDimension(1);
    assertEquals("Unexpected number of tensors along dimension 1", assertions.length, numTads);

    for (int i = 0; i < numTads; i++) {
        INDArray arr = baseArr.tensorAlongDimension(i, 1);
        assertEquals("Failed at index " + i, assertions[i], arr);
    }
}
Example 4
Source File: ShapeTests.java From nd4j with Apache License 2.0 | 6 votes |
/**
 * Verifies the 2x2 tensors taken along dimensions 0 and 1 of a 2x2x2x2 array filled with 1..16.
 * Exactly four such tensors are expected; each one is compared against its expected matrix.
 */
@Test
public void testSixteenZeroOne() {
    INDArray source = Nd4j.linspace(1, 16, 16).reshape(2, 2, 2, 2);
    assertEquals(4, source.tensorssAlongDimension(0, 1));

    INDArray[] expected = {
            Nd4j.create(new double[][] {{1, 3}, {2, 4}}),
            Nd4j.create(new double[][] {{9, 11}, {10, 12}}),
            Nd4j.create(new double[][] {{5, 7}, {6, 8}}),
            Nd4j.create(new double[][] {{13, 15}, {14, 16}})
    };

    for (int idx = 0; idx < source.tensorssAlongDimension(0, 1); idx++) {
        INDArray actual = source.tensorAlongDimension(idx, 0, 1);
        assertEquals("Wrong at index " + idx, expected[idx], actual);
    }
}
Example 5
Source File: ShapeTests.java From nd4j with Apache License 2.0 | 6 votes |
/**
 * Verifies every 1-d tensor taken along dimension 2 of a 2x2x2x2 array filled with 1..16.
 * The number of tensors is asserted before iterating, so the test fails if the implementation
 * produces fewer tensors than expected. Without that check it would silently verify only a prefix
 * of {@code assertions}.
 */
@Test
public void testSixteenSecondDim() {
    INDArray baseArr = Nd4j.linspace(1, 16, 16).reshape(2, 2, 2, 2);
    INDArray[] assertions = new INDArray[] {
            Nd4j.create(new double[] {1, 5}),
            Nd4j.create(new double[] {9, 13}),
            Nd4j.create(new double[] {3, 7}),
            Nd4j.create(new double[] {11, 15}),
            Nd4j.create(new double[] {2, 6}),
            Nd4j.create(new double[] {10, 14}),
            Nd4j.create(new double[] {4, 8}),
            Nd4j.create(new double[] {12, 16})
    };

    // Hoisted: the TAD count is loop-invariant and must match the number of expectations
    long numTads = baseArr.tensorssAlongDimension(2);
    assertEquals("Unexpected number of tensors along dimension 2", assertions.length, numTads);

    for (int i = 0; i < numTads; i++) {
        INDArray arr = baseArr.tensorAlongDimension(i, 2);
        assertEquals("Failed at index " + i, assertions[i], arr);
    }
}
Example 6
Source File: ShapeTests.java From nd4j with Apache License 2.0 | 6 votes |
/**
 * Verifies every 1-d tensor taken along dimension 1 of a 2x2x2x2 array filled with 1..16.
 * The number of tensors is asserted before iterating, so the test fails if the implementation
 * produces fewer tensors than expected. Without that check it would silently verify only a prefix
 * of {@code assertions}.
 */
@Test
public void testSixteenFirstDim() {
    INDArray baseArr = Nd4j.linspace(1, 16, 16).reshape(2, 2, 2, 2);
    INDArray[] assertions = new INDArray[] {
            Nd4j.create(new double[] {1, 3}),
            Nd4j.create(new double[] {9, 11}),
            Nd4j.create(new double[] {5, 7}),
            Nd4j.create(new double[] {13, 15}),
            Nd4j.create(new double[] {2, 4}),
            Nd4j.create(new double[] {10, 12}),
            Nd4j.create(new double[] {6, 8}),
            Nd4j.create(new double[] {14, 16})
    };

    // Hoisted: the TAD count is loop-invariant and must match the number of expectations
    long numTads = baseArr.tensorssAlongDimension(1);
    assertEquals("Unexpected number of tensors along dimension 1", assertions.length, numTads);

    for (int i = 0; i < numTads; i++) {
        INDArray arr = baseArr.tensorAlongDimension(i, 1);
        assertEquals("Failed at index " + i, assertions[i], arr);
    }
}
Example 7
Source File: OpExecutionerUtil.java From nd4j with Apache License 2.0 | 6 votes |
/** Tensor1DStats, used to efficiently iterate through tensors on a matrix (2d NDArray) for element-wise ops * For example, the offset of each 1d tensor can be calculated using only a single tensorAlongDimension method call, * hence is potentially faster than approaches requiring multiple tensorAlongDimension calls.<br> * Note that this can only (generally) be used for 2d NDArrays. For certain 3+d NDArrays, the tensor starts may not * be in increasing order */ public static Tensor1DStats get1DTensorStats(INDArray array, int... dimension) { long tensorLength = array.size(dimension[0]); //As per tensorssAlongDimension: long numTensors = array.tensorssAlongDimension(dimension); //First tensor always starts with the first element in the NDArray, regardless of dimension long firstTensorOffset = array.offset(); //Next: Need to work out the separation between the start (first element) of each 1d tensor long tensorStartSeparation; int elementWiseStride; //Separation in buffer between elements in the tensor if (numTensors == 1) { tensorStartSeparation = -1; //Not applicable elementWiseStride = array.elementWiseStride(); } else { INDArray secondTensor = array.tensorAlongDimension(1, dimension); tensorStartSeparation = secondTensor.offset() - firstTensorOffset; elementWiseStride = secondTensor.elementWiseStride(); } return new Tensor1DStats(firstTensorOffset, tensorStartSeparation, numTensors, tensorLength, elementWiseStride); }
Example 8
Source File: OpExecutionerTestsC.java From nd4j with Apache License 2.0 | 5 votes |
/**
 * Checks the tensors taken along dimension 0 of a 2x3x2 array filled with 1..12.
 * <p>
 * The original version only printed the tensors and asserted nothing. Assuming 'c' ordering (as the
 * test class name suggests), stepping along dimension 0 skips 3 * 2 = 6 elements. The i-th tensor is
 * therefore {i + 1, i + 7}, and there are 3 * 2 = 6 of them.
 */
@Test
public void testTad() {
    INDArray arr = Nd4j.linspace(1, 12, 12).reshape(2, 3, 2);

    long numTads = arr.tensorssAlongDimension(0);
    assertEquals("Unexpected number of tensors along dimension 0", 6, numTads);

    for (int i = 0; i < numTads; i++) {
        INDArray expected = Nd4j.create(new double[] {i + 1, i + 7});
        assertEquals("Failed at index " + i, expected, arr.tensorAlongDimension(i, 0));
    }
}
Example 9
Source File: LoneTest.java From nd4j with Apache License 2.0 | 5 votes |
/**
 * Checks tensor-along-dimension access and c/f flattening on a strided (non-contiguous) view.
 * <p>
 * The view selects rows 4..7 and every second column (0, 2, 4, 6) of an 8x8 array holding 1..64,
 * offset by 0.1. This gives a 4x4 view.
 * <p>
 * The previous version also built a second (f-ordered copy) and a third (3d, permuted) array. Their
 * results were never asserted, so that dead code was removed. The TAD loops previously only printed
 * their tensors; they now assert the tensor counts and lengths.
 */
@Test
public void testFlattenedView() {
    int rows = 8;
    int cols = 8;
    int length = rows * cols;

    INDArray first = Nd4j.linspace(1, length, length).reshape('c', rows, cols);
    first.addi(0.1);

    // interval(begin, stride, end): rows 4..7, columns 0, 2, 4, 6
    first = first.get(NDArrayIndex.interval(4, 8), NDArrayIndex.interval(0, 2, 8));

    // Tensors along dimension 0 of the view: one per column, each spanning all rows
    long numTadsDim0 = first.tensorssAlongDimension(0);
    assertEquals(first.columns(), numTadsDim0);
    for (int i = 0; i < numTadsDim0; i++) {
        assertEquals(first.rows(), first.tensorAlongDimension(i, 0).length());
    }

    // Tensors along dimension 1 of the view: one per row, each spanning all columns
    long numTadsDim1 = first.tensorssAlongDimension(1);
    assertEquals(first.rows(), numTadsDim1);
    for (int i = 0; i < numTadsDim1; i++) {
        assertEquals(first.columns(), first.tensorAlongDimension(i, 1).length());
    }

    INDArray cAssertion = Nd4j.create(new double[]{33.10, 35.10, 37.10, 39.10, 41.10, 43.10, 45.10, 47.10,
            49.10, 51.10, 53.10, 55.10, 57.10, 59.10, 61.10, 63.10});
    INDArray fAssertion = Nd4j.create(new double[]{33.10, 41.10, 49.10, 57.10, 35.10, 43.10, 51.10, 59.10,
            37.10, 45.10, 53.10, 61.10, 39.10, 47.10, 55.10, 63.10});
    assertEquals(cAssertion, Nd4j.toFlattened('c', first));
    assertEquals(fAssertion, Nd4j.toFlattened('f', first));
}
Example 10
Source File: DataSetUtil.java From nd4j with Apache License 2.0 | 5 votes |
/**
 * Rearranges a rank-4 array into a 2d array with one column per index of dimension 1.
 * <p>
 * The local variables name the four dimensions instances, channels, height and width (in that order).
 * For each tensor taken along dimensions 3, 2 and 0, the tensor is flattened into one row of an
 * intermediate {@code [channels, height * width * instances]} array. That array is then transposed
 * in place and returned.
 *
 * @param data rank-4 input array (must not be null)
 * @return the transposed 2d array of shape {@code [height * width * instances, channels]}
 * @throws IllegalArgumentException if {@code data} is not rank 4
 */
public static INDArray tailor4d2d(@NonNull INDArray data) {
    // Without this check, rank > 4 input would silently be flattened with a wrong layout
    if (data.rank() != 4) {
        throw new IllegalArgumentException("Expected rank 4 input, but got rank " + data.rank());
    }

    long instances = data.size(0);
    long channels = data.size(1);
    long height = data.size(2);
    long width = data.size(3);

    INDArray in2d = Nd4j.create(channels, height * width * instances);

    // One tensor per index of dimension 1, i.e. 'channels' tensors in total
    long tads = data.tensorssAlongDimension(3, 2, 0);
    for (int i = 0; i < tads; i++) {
        INDArray thisTAD = data.tensorAlongDimension(i, 3, 2, 0);
        in2d.putRow(i, Nd4j.toFlattened(thisTAD));
    }
    return in2d.transposei();
}
Example 11
Source File: ConcatTestsC.java From nd4j with Apache License 2.0 | 4 votes |
@Test public void testConcat3d() { INDArray first = Nd4j.linspace(1, 24, 24).reshape('c', 2, 3, 4); INDArray second = Nd4j.linspace(24, 36, 12).reshape('c', 1, 3, 4); INDArray third = Nd4j.linspace(36, 48, 12).reshape('c', 1, 3, 4); //ConcatV2, dim 0 INDArray exp = Nd4j.create(2 + 1 + 1, 3, 4); exp.put(new INDArrayIndex[] {NDArrayIndex.interval(0, 2), NDArrayIndex.all(), NDArrayIndex.all()}, first); exp.put(new INDArrayIndex[] {NDArrayIndex.point(2), NDArrayIndex.all(), NDArrayIndex.all()}, second); exp.put(new INDArrayIndex[] {NDArrayIndex.point(3), NDArrayIndex.all(), NDArrayIndex.all()}, third); INDArray concat0 = Nd4j.concat(0, first, second, third); assertEquals(exp, concat0); //ConcatV2, dim 1 second = Nd4j.linspace(24, 32, 8).reshape('c', 2, 1, 4); for (int i = 0; i < second.tensorssAlongDimension(1); i++) { INDArray secondTad = second.javaTensorAlongDimension(i, 1); System.out.println(second.tensorAlongDimension(i, 1)); } third = Nd4j.linspace(32, 48, 16).reshape('c', 2, 2, 4); exp = Nd4j.create(2, 3 + 1 + 2, 4); exp.put(new INDArrayIndex[] {NDArrayIndex.all(), NDArrayIndex.interval(0, 3), NDArrayIndex.all()}, first); exp.put(new INDArrayIndex[] {NDArrayIndex.all(), NDArrayIndex.point(3), NDArrayIndex.all()}, second); exp.put(new INDArrayIndex[] {NDArrayIndex.all(), NDArrayIndex.interval(4, 6), NDArrayIndex.all()}, third); INDArray concat1 = Nd4j.concat(1, first, second, third); assertEquals(exp, concat1); //ConcatV2, dim 2 second = Nd4j.linspace(24, 36, 12).reshape('c', 2, 3, 2); third = Nd4j.linspace(36, 42, 6).reshape('c', 2, 3, 1); exp = Nd4j.create(2, 3, 4 + 2 + 1); exp.put(new INDArrayIndex[] {NDArrayIndex.all(), NDArrayIndex.all(), NDArrayIndex.interval(0, 4)}, first); exp.put(new INDArrayIndex[] {NDArrayIndex.all(), NDArrayIndex.all(), NDArrayIndex.interval(4, 6)}, second); exp.put(new INDArrayIndex[] {NDArrayIndex.all(), NDArrayIndex.all(), NDArrayIndex.point(6)}, third); INDArray concat2 = Nd4j.concat(2, first, second, third); assertEquals(exp, 
concat2); }