Java Code Examples for org.nd4j.linalg.util.ArrayUtil#argMin()
The following examples show how to use org.nd4j.linalg.util.ArrayUtil#argMin().
You can vote up the ones you like or vote down the ones you don't like,
and go to the original project or source file by following the links above each example. You may check out the related API usage on the sidebar.
Example 1
Source File: OpExecutionerUtil.java From nd4j with Apache License 2.0 | 5 votes |
/** * * Choose tensor dimension for operations with one argument: x=Op(x) or similar<br> * When doing some operations in parallel, it is necessary to break up * operations along a dimension to * give a set of 1d tensors. The dimension that this is done on is important for performance reasons; * in summary we want to both minimize the number of tensors * , but also minimize the separation between * elements in the buffer (so the resulting operation is efficient - i.e., avoids cache thrashing). * However, achieving both minimal number * of tensors and are not always possible. * @param x NDArray that we want to split * @return The best dimension to split on */ public static int chooseElementWiseTensorDimension(INDArray x) { if (x.isVector()) return ArrayUtil.argMax(x.shape()); //Execute along the vector //doing argMin(max(x.stride(i),y.stride(i))) minimizes the maximum //separation between elements (helps CPU cache) BUT might result in a huge number //of tiny ops - i.e., addi on NDArrays with shape [5,10^6] int opAlongDimensionMinStride = ArrayUtil.argMin(x.stride()); //doing argMax on shape gives us smallest number of largest tensors //but may not be optimal in terms of element separation (for CPU cache etc) int opAlongDimensionMaxLength = ArrayUtil.argMax(x.shape()); //Edge cases: shapes with 1s in them can have stride of 1 on the dimensions of length 1 if (x.isVector() || x.size(opAlongDimensionMinStride) == 1) return opAlongDimensionMaxLength; //Using a heuristic approach here: basically if we get >= 10x as many tensors using the minimum stride //dimension vs. 
the maximum size dimension, use the maximum size dimension instead //The idea is to avoid choosing wrong dimension in cases like shape=[10,10^6] //Might be able to do better than this with some additional thought int nOpsAlongMinStride = ArrayUtil.prod(ArrayUtil.removeIndex(x.shape(), opAlongDimensionMinStride)); int nOpsAlongMaxLength = ArrayUtil.prod(ArrayUtil.removeIndex(x.shape(), opAlongDimensionMaxLength)); if (nOpsAlongMinStride <= 10 * nOpsAlongMaxLength) return opAlongDimensionMinStride; else return opAlongDimensionMaxLength; }
Example 2
Source File: ShufflesTests.java From nd4j with Apache License 2.0 | 4 votes |
/**
 * Builds a 10x10 array in which every element of row i is set to i, records it
 * with an {@code OrderScanner2D}, calls {@code Nd4j.shuffle(array, 1)} and then
 * asserts that {@code scanner.compareRow(array)} is true.
 */
@Test
public void testSimpleShuffle1() {
    final int rowCount = 10;
    INDArray array = Nd4j.zeros(rowCount, 10);
    int rowIdx = 0;
    while (rowIdx < rowCount) {
        // every element of row i is set to i
        array.getRow(rowIdx).assign(rowIdx);
        rowIdx++;
    }
    System.out.println(array);

    OrderScanner2D scanner = new OrderScanner2D(array);
    float[] expectedMap = new float[]{0f, 1f, 2f, 3f, 4f, 5f, 6f, 7f, 8f, 9f};
    assertArrayEquals(expectedMap, scanner.getMap(), 0.01f);

    System.out.println();

    Nd4j.shuffle(array, 1);
    System.out.println(array);

    // result intentionally unused; call kept exactly as in the original test
    ArrayUtil.argMin(new int[]{});

    boolean rowCheckPassed = scanner.compareRow(array);
    assertTrue(rowCheckPassed);
}
Example 3
Source File: ShufflesTests.java From nd4j with Apache License 2.0 | 4 votes |
/**
 * Builds a 10x10 features array and a 10x3 labels array whose row i is filled
 * with i, shuffles both together through {@code Nd4j.shuffle(list, 1)}, then
 * asserts that {@code scanner.compareRow(features)} is true and that every
 * element of labels row i equals the first element of features row i.
 */
@Test
public void testSymmetricShuffle1() {
    final int rowCount = 10;
    INDArray features = Nd4j.zeros(rowCount, 10);
    INDArray labels = Nd4j.zeros(rowCount, 3);
    for (int r = 0; r < rowCount; r++) {
        features.getRow(r).assign(r);
        labels.getRow(r).assign(r);
    }
    System.out.println(features);

    OrderScanner2D scanner = new OrderScanner2D(features);
    float[] expectedMap = new float[]{0f, 1f, 2f, 3f, 4f, 5f, 6f, 7f, 8f, 9f};
    assertArrayEquals(expectedMap, scanner.getMap(), 0.01f);

    System.out.println();

    // both arrays are handed to a single shuffle call
    List<INDArray> pair = new ArrayList<>();
    pair.add(features);
    pair.add(labels);
    Nd4j.shuffle(pair, 1);

    System.out.println(features);
    System.out.println();
    System.out.println(labels);

    // result intentionally unused; call kept exactly as in the original test
    ArrayUtil.argMin(new int[]{});

    boolean rowCheckPassed = scanner.compareRow(features);
    assertTrue(rowCheckPassed);

    // each labels row must still carry the same value as the matching features row
    for (int r = 0; r < rowCount; r++) {
        double expected = features.getRow(r).getDouble(0);
        INDArray labelRow = labels.getRow(r);
        for (int c = 0; c < labelRow.length(); c++) {
            assertEquals(expected, labelRow.getDouble(c), 0.001);
        }
    }
}
Example 4
Source File: ShufflesTests.java From nd4j with Apache License 2.0 | 4 votes |
/**
 * Shuffle smoke test on a 10x10 array: row i is filled with i, the array is
 * recorded by an {@code OrderScanner2D}, passed to {@code Nd4j.shuffle(array, 1)},
 * and {@code scanner.compareRow(array)} is asserted to be true afterwards.
 */
@Test
public void testSimpleShuffle1() {
    INDArray array = Nd4j.zeros(10, 10);
    for (int row = 0; row != 10; row++) {
        INDArray rowView = array.getRow(row);
        rowView.assign(row);
    }
    System.out.println(array);

    OrderScanner2D scanner = new OrderScanner2D(array);
    float[] expected = new float[] {0f, 1f, 2f, 3f, 4f, 5f, 6f, 7f, 8f, 9f};
    assertArrayEquals(expected, scanner.getMap(), 0.01f);

    System.out.println();

    Nd4j.shuffle(array, 1);
    System.out.println(array);

    // return value is discarded, exactly as in the original test
    ArrayUtil.argMin(new int[] {});

    assertTrue(scanner.compareRow(array));
}
Example 5
Source File: ShufflesTests.java From nd4j with Apache License 2.0 | 4 votes |
/**
 * Same scenario as the 10-row shuffle test but on an 11x10 array: row i is
 * filled with i, the scanner map is checked against 0..10, the array is passed
 * to {@code Nd4j.shuffle(array, 1)}, and {@code scanner.compareRow(array)} is
 * asserted to be true.
 */
@Test
public void testSimpleShuffle3() {
    final int rowCount = 11;
    INDArray array = Nd4j.zeros(rowCount, 10);
    for (int row = 0; row < rowCount; row++) {
        array.getRow(row).assign(row);
    }
    System.out.println(array);

    OrderScanner2D scanner = new OrderScanner2D(array);
    float[] expectedMap = new float[] {0f, 1f, 2f, 3f, 4f, 5f, 6f, 7f, 8f, 9f, 10f};
    assertArrayEquals(expectedMap, scanner.getMap(), 0.01f);

    System.out.println();

    Nd4j.shuffle(array, 1);
    System.out.println(array);

    // return value is discarded, exactly as in the original test
    ArrayUtil.argMin(new int[] {});

    boolean rowCheckPassed = scanner.compareRow(array);
    assertTrue(rowCheckPassed);
}
Example 6
Source File: ShufflesTests.java From nd4j with Apache License 2.0 | 4 votes |
/**
 * Shuffles a features array (10x10) and a labels array (10x3) together via
 * {@code Nd4j.shuffle(list, 1)}. Row i of both arrays is filled with i
 * beforehand; afterwards the test asserts {@code scanner.compareRow(features)}
 * and that each labels row holds the first value of the matching features row.
 */
@Test
public void testSymmetricShuffle1() {
    INDArray features = Nd4j.zeros(10, 10);
    INDArray labels = Nd4j.zeros(10, 3);
    int fillRow = 0;
    while (fillRow < 10) {
        features.getRow(fillRow).assign(fillRow);
        labels.getRow(fillRow).assign(fillRow);
        fillRow++;
    }
    System.out.println(features);

    OrderScanner2D scanner = new OrderScanner2D(features);
    float[] expected = new float[] {0f, 1f, 2f, 3f, 4f, 5f, 6f, 7f, 8f, 9f};
    assertArrayEquals(expected, scanner.getMap(), 0.01f);

    System.out.println();

    List<INDArray> arrays = new ArrayList<>();
    arrays.add(features);
    arrays.add(labels);
    Nd4j.shuffle(arrays, 1);

    System.out.println(features);
    System.out.println();
    System.out.println(labels);

    // return value is discarded, exactly as in the original test
    ArrayUtil.argMin(new int[] {});

    assertTrue(scanner.compareRow(features));

    // labels row i must match features row i element by element
    for (int rowIdx = 0; rowIdx < 10; rowIdx++) {
        double featureValue = features.getRow(rowIdx).getDouble(0);
        INDArray labelRow = labels.getRow(rowIdx);
        for (int colIdx = 0; colIdx < labelRow.length(); colIdx++) {
            double labelValue = labelRow.getDouble(colIdx);
            assertEquals(featureValue, labelValue, 0.001);
        }
    }
}