Java Code Examples for org.nd4j.linalg.api.ndarray.INDArray#getColumns()
The following examples show how to use org.nd4j.linalg.api.ndarray.INDArray#getColumns().
You can vote up the ones you like or vote down the ones you don't like,
and go to the original project or source file by following the links above each example. You may check out the related API usage on the sidebar.
Example 1
Source File: PruningTest.java From ml-models with Apache License 2.0 | 6 votes |
/**
 * Smoke test for {@code getColumns}: copies a 7x3 embedding into a freshly created array of
 * the same shape, selects columns 2, 1, 0 (in that order) from the copy and prints the result.
 * Nothing is asserted; the output is for visual inspection only.
 */
@Test
public void testGetCols() {
    double[][] rawEmbedding = {
        {0.00, 1.00, 0.00},
        {0.00, 0.00, 1.00},
        {0.00, 1.00, 1.00},
        {0.00, 2.00, 2.00},
        {1.00, 0.00, 0.00},
        {1.00, 0.00, 0.00},
        {2.00, 0.00, 0.00},
    };
    INDArray source = Nd4j.create(rawEmbedding);
    int[] keptFeatureIds = new int[] {2, 1, 0};

    // Work on a copy so that the column selection never touches the source array.
    INDArray workingCopy = Nd4j.create(source.shape());
    Nd4j.copy(source, workingCopy);

    INDArray selected = workingCopy.getColumns(keptFeatureIds);
    System.out.println("columns = \n" + selected);
}
Example 2
Source File: ZeroShotUtil.java From AILibs with GNU Affero General Public License v3.0 | 5 votes |
/**
 * Extracts the last {@code numHyperPars} entries of {@code parameters} and maps each of them
 * back with {@code value * (max - min) + min}, where min/max are read from
 * {@code scaler.getStatsY()[i]} for the i-th extracted entry.
 *
 * NOTE(review): column indices are derived from {@code parameters.length()}, so this presumably
 * expects a row vector whose length equals its column count - confirm against callers.
 *
 * @param parameters   array whose trailing entries are the scaled hyper-parameters
 * @param scaler       scaler providing per-entry min/max statistics via {@code getStatsY()}
 * @param numHyperPars number of trailing entries to extract and unscale
 * @return the array returned by {@code getColumns}, with every entry unscaled in place
 */
public static INDArray unscaleParameters(INDArray parameters, DyadMinMaxScaler scaler, int numHyperPars) {
    int[] trailingColumns = new int[numHyperPars];
    int slot = 0;
    while (slot < numHyperPars) {
        // Index of the slot-th hyper-parameter, counted from the end of the array.
        trailingColumns[slot] = (int) parameters.length() - numHyperPars + slot;
        slot++;
    }

    INDArray unscaled = parameters.getColumns(trailingColumns);
    for (int pos = 0; pos < unscaled.length(); pos++) {
        unscaled.putScalar(
            pos,
            unscaled.getDouble(pos)
                * (scaler.getStatsY()[pos].getMax() - scaler.getStatsY()[pos].getMin())
                + scaler.getStatsY()[pos].getMin());
    }
    return unscaled;
}
Example 3
Source File: IndexingTests.java From nd4j with Apache License 2.0 | 5 votes |
@Test public void testVectorIndexing() { INDArray x = Nd4j.linspace(0, 10, 11); int[] index = new int[] {5, 8, 9}; INDArray columnsTest = x.getColumns(index); assertEquals(Nd4j.create(new double[] {5, 8, 9}), columnsTest); int[] index2 = new int[] {2, 2, 4}; //retrieve the same columns twice INDArray columnsTest2 = x.getColumns(index2); assertEquals(Nd4j.create(new double[] {2, 2, 4}), columnsTest2); }
Example 4
Source File: IndexingTests.java From nd4j with Apache License 2.0 | 5 votes |
/**
 * Verifies {@code getColumns} and {@code getRows} on a 4x6 matrix built from the values 1..24.
 * Columns 0 and 1 are compared against a 4x2 expectation; rows 0, 0, 1 (row 0 requested twice)
 * are compared against a 3x6 expectation. The matrix is also printed to standard output.
 *
 * NOTE(review): the expected values match a column-major fill of the matrix - presumably the
 * enclosing test class runs with 'f' ordering; confirm.
 */
@Test
public void testGetRowsColumnsMatrix() {
    INDArray matrix = Nd4j.linspace(1, 24, 24).reshape(4, 6);

    double[][] expectedColumnValues = {{1, 5}, {2, 6}, {3, 7}, {4, 8}};
    INDArray expectedColumns = Nd4j.create(expectedColumnValues);
    System.out.println(matrix);
    INDArray actualColumns = matrix.getColumns(new int[] {0, 1});
    assertEquals(expectedColumns, actualColumns);

    double[][] expectedRowValues = {
        {1.00, 5.00, 9.00, 13.00, 17.00, 21.00},
        {1.00, 5.00, 9.00, 13.00, 17.00, 21.00},
        {2.00, 6.00, 10.00, 14.00, 18.00, 22.00}
    };
    INDArray expectedRows = Nd4j.create(expectedRowValues);
    int[] requestedRows = {0, 0, 1};
    INDArray actualRows = matrix.getRows(requestedRows);
    assertEquals(expectedRows, actualRows);
}
Example 5
Source File: NDArrayTestsFortran.java From nd4j with Apache License 2.0 | 5 votes |
/**
 * Verifies that selecting columns 1 and 2 of a 2x3 matrix built from the values 1..6 equals
 * the expected 2x2 matrix {{3, 5}, {4, 6}}. The shape-info buffers and raw data of both the
 * expected and the actual array are logged before the assertion.
 */
@Test
public void testGetColumns() {
    INDArray source = Nd4j.linspace(1, 6, 6).reshape(2, 3);
    log.info("Original: {}", source);

    int[] wantedColumns = {1, 2};
    INDArray actual = source.getColumns(wantedColumns);
    INDArray expected = Nd4j.create(new double[][] {{3, 5}, {4, 6}});

    // "A" is the expectation, "B" is what getColumns returned.
    String expectedShapeInfo = Arrays.toString(expected.shapeInfoDataBuffer().asInt());
    log.info("order A: {}", expectedShapeInfo);
    String actualShapeInfo = Arrays.toString(actual.shapeInfoDataBuffer().asInt());
    log.info("order B: {}", actualShapeInfo);
    String expectedData = Arrays.toString(expected.data().asFloat());
    log.info("data A: {}", expectedData);
    String actualData = Arrays.toString(actual.data().asFloat());
    log.info("data B: {}", actualData);

    assertEquals(expected, actual);
}
Example 6
Source File: IndexingTests.java From deeplearning4j with Apache License 2.0 | 5 votes |
@Test public void testVectorIndexing() { INDArray x = Nd4j.linspace(0, 10, 11, DataType.DOUBLE).reshape(1, 11).castTo(DataType.DOUBLE); int[] index = new int[] {5, 8, 9}; INDArray columnsTest = x.getColumns(index); assertEquals(Nd4j.create(new double[] {5, 8, 9}, new int[]{1,3}), columnsTest); int[] index2 = new int[] {2, 2, 4}; //retrieve the same columns twice INDArray columnsTest2 = x.getColumns(index2); assertEquals(Nd4j.create(new double[] {2, 2, 4}, new int[]{1,3}), columnsTest2); }
Example 7
Source File: IndexingTests.java From deeplearning4j with Apache License 2.0 | 5 votes |
@Test public void testGetRowsColumnsMatrix() { INDArray arr = Nd4j.linspace(1, 24, 24, DataType.DOUBLE).reshape(4, 6); INDArray firstAndSecondColumnsAssertion = Nd4j.create(new double[][] {{1, 5}, {2, 6}, {3, 7}, {4, 8}}); // System.out.println(arr); INDArray firstAndSecondColumns = arr.getColumns(0, 1); assertEquals(firstAndSecondColumnsAssertion, firstAndSecondColumns); INDArray firstAndSecondRows = Nd4j.create(new double[][] {{1.00, 5.00, 9.00, 13.00, 17.00, 21.00}, {1.00, 5.00, 9.00, 13.00, 17.00, 21.00}, {2.00, 6.00, 10.00, 14.00, 18.00, 22.00}}); INDArray rows = arr.getRows(0, 0, 1); assertEquals(firstAndSecondRows, rows); }
Example 8
Source File: NDArrayTestsFortran.java From deeplearning4j with Apache License 2.0 | 5 votes |
@Test public void testGetColumns() { INDArray matrix = Nd4j.linspace(1, 6, 6, DataType.DOUBLE).reshape(2, 3).castTo(DataType.DOUBLE); // log.info("Original: {}", matrix); INDArray matrixGet = matrix.getColumns(1, 2); INDArray matrixAssertion = Nd4j.create(new double[][] {{3, 5}, {4, 6}}); // log.info("order A: {}", Arrays.toString(matrixAssertion.shapeInfoDataBuffer().asInt())); // log.info("order B: {}", Arrays.toString(matrixGet.shapeInfoDataBuffer().asInt())); // log.info("data A: {}", Arrays.toString(matrixAssertion.data().asFloat())); // log.info("data B: {}", Arrays.toString(matrixGet.data().asFloat())); assertEquals(matrixAssertion, matrixGet); }
Example 9
Source File: Pruning.java From ml-models with Apache License 2.0 | 4 votes |
/**
 * Returns the columns of an embedding selected by {@code featIdsToKeep}.
 * The embedding is first copied into a newly created array of the same shape, and the
 * columns are taken from that copy rather than from {@code origEmbedding} itself.
 *
 * @param origEmbedding the embedding to select columns from
 * @param featIdsToKeep column indices to keep, in the order they should appear in the result
 * @return the result of {@code getColumns(featIdsToKeep)} on the copy
 */
private INDArray pruneEmbedding(INDArray origEmbedding, int... featIdsToKeep) {
    INDArray workingCopy = Nd4j.create(origEmbedding.shape());
    Nd4j.copy(origEmbedding, workingCopy);

    INDArray keptColumns = workingCopy.getColumns(featIdsToKeep);
    return keptColumns;
}
Example 10
Source File: InputOptListener.java From AILibs with GNU Affero General Public License v3.0 | 4 votes |
/**
 * Records one optimization step: appends the columns of {@code plNetInput} selected by
 * {@code indicesToWatch} to {@code inputList}, then appends {@code plNetOutput} to
 * {@code outputList}.
 *
 * @param plNetInput  input array of the current step
 * @param plNetOutput output value of the current step
 */
public void reportOptimizationStep(INDArray plNetInput, double plNetOutput) {
    // Only the watched columns of the input are stored, not the whole input array.
    inputList.add(plNetInput.getColumns(indicesToWatch));
    outputList.add(plNetOutput);
}