Java Code Examples for org.nd4j.linalg.api.ndarray.INDArray#get()
The following examples show how to use org.nd4j.linalg.api.ndarray.INDArray#get().
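Before the project examples, a minimal standalone sketch may help orient readers. It is not taken from any of the projects below and assumes only a recent ND4J version on the classpath; it shows the three index types used throughout: NDArrayIndex.all(), point(), and interval() (which, like a Python slice, excludes its end index).

    import org.nd4j.linalg.api.ndarray.INDArray;
    import org.nd4j.linalg.factory.Nd4j;
    import org.nd4j.linalg.indexing.NDArrayIndex;

    public class GetBasics {
        public static void main(String[] args) {
            // 3x4 matrix holding 1..12 in row-major order
            INDArray arr = Nd4j.linspace(1, 12, 12).reshape(3, 4);

            // point(1) fixes the row dimension: row 1 -> [5, 6, 7, 8]
            // (exact result shape may be [4] or [1, 4] depending on ND4J version)
            INDArray row = arr.get(NDArrayIndex.point(1), NDArrayIndex.all());

            // interval(0, 2) selects indices 0 and 1: top-left 2x2 block -> [[1, 2], [5, 6]]
            INDArray block = arr.get(NDArrayIndex.interval(0, 2), NDArrayIndex.interval(0, 2));

            System.out.println(row);
            System.out.println(block);
        }
    }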
Example 1
Source File: IndexingTestsC.java From nd4j with Apache License 2.0
@Test
public void testPointIndexes() {
    INDArray arr = Nd4j.create(4, 3, 2);
    INDArray get = arr.get(NDArrayIndex.all(), NDArrayIndex.point(1), NDArrayIndex.all());
    assertArrayEquals(new int[] {4, 2}, get.shape());
    INDArray linspaced = Nd4j.linspace(1, 24, 24).reshape(4, 3, 2);
    INDArray assertion = Nd4j.create(new double[][] {{3, 4}, {9, 10}, {15, 16}, {21, 22}});
    INDArray linspacedGet = linspaced.get(NDArrayIndex.all(), NDArrayIndex.point(1), NDArrayIndex.all());
    for (int i = 0; i < linspacedGet.slices(); i++) {
        INDArray sliceI = linspacedGet.slice(i);
        assertEquals(assertion.slice(i), sliceI);
    }
    assertArrayEquals(new int[] {6, 1}, linspacedGet.stride());
    assertEquals(assertion, linspacedGet);
}
Example 2
Source File: BatchedInferenceObservable.java From deeplearning4j with Apache License 2.0
private INDArray[] splitExamples(INDArray netOutput, int firstInputComponent, int lastInputComponent) {
    int numSplits = lastInputComponent - firstInputComponent + 1;
    if (numSplits == 1) {
        return new INDArray[]{netOutput};
    } else {
        INDArray[] out = new INDArray[numSplits];
        INDArrayIndex[] indices = new INDArrayIndex[netOutput.rank()];
        for (int i = 1; i < indices.length; i++) {
            indices[i] = NDArrayIndex.all();
        }
        int examplesSoFar = 0;
        for (int inNum = 0; inNum < numSplits; inNum++) {
            val inSizeEx = inputs.get(firstInputComponent + inNum)[0].size(0);
            indices[0] = NDArrayIndex.interval(examplesSoFar, examplesSoFar + inSizeEx);
            out[inNum] = netOutput.get(indices);
            examplesSoFar += inSizeEx;
        }
        return out;
    }
}
Example 3
Source File: NDArrayCreationUtil.java From deeplearning4j with Apache License 2.0
public static List<Pair<INDArray, String>> getSubMatricesWithShape(char ordering, long rows, long cols, long seed,
                DataType dataType) {
    //Create 3 identical matrices. Could do get() on single original array, but in-place modifications on one
    //might mess up tests for another
    Nd4j.getRandom().setSeed(seed);
    long[] shape = new long[] {2 * rows + 4, 2 * cols + 4};
    int len = ArrayUtil.prod(shape);
    INDArray orig = Nd4j.linspace(1, len, len, dataType).reshape(ordering, shape);
    INDArray first = orig.get(NDArrayIndex.interval(0, rows), NDArrayIndex.interval(0, cols));

    Nd4j.getRandom().setSeed(seed);
    orig = Nd4j.linspace(1, len, len, dataType).reshape(shape);
    INDArray second = orig.get(NDArrayIndex.interval(3, rows + 3), NDArrayIndex.interval(3, cols + 3));

    Nd4j.getRandom().setSeed(seed);
    orig = Nd4j.linspace(1, len, len, dataType).reshape(ordering, shape);
    INDArray third = orig.get(NDArrayIndex.interval(rows, 2 * rows), NDArrayIndex.interval(cols, 2 * cols));

    String baseMsg = "getSubMatricesWithShape(" + rows + "," + cols + "," + seed + ")";

    List<Pair<INDArray, String>> list = new ArrayList<>(3);
    list.add(new Pair<>(first, baseMsg + ".get(0)"));
    list.add(new Pair<>(second, baseMsg + ".get(1)"));
    list.add(new Pair<>(third, baseMsg + ".get(2)"));
    return list;
}
Example 4
Source File: AdaDeltaUpdater.java From deeplearning4j with Apache License 2.0
@Override
public void setStateViewArray(INDArray viewArray, long[] gradientShape, char gradientOrder, boolean initialize) {
    if (!viewArray.isRowVector())
        throw new IllegalArgumentException("Invalid input: expect row vector input");
    if (initialize)
        viewArray.assign(0);

    long length = viewArray.length();
    this.msg = viewArray.get(NDArrayIndex.point(0), NDArrayIndex.interval(0, length / 2));
    this.msdx = viewArray.get(NDArrayIndex.point(0), NDArrayIndex.interval(length / 2, length));

    //Reshape to match the expected shape of the input gradient arrays
    this.msg = Shape.newShapeNoCopy(this.msg, gradientShape, gradientOrder == 'f');
    this.msdx = Shape.newShapeNoCopy(this.msdx, gradientShape, gradientOrder == 'f');
    if (msg == null || msdx == null)
        throw new IllegalStateException("Could not correctly reshape gradient view arrays");
}
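A detail this updater relies on: in recent ND4J releases, get() with point/interval/all indices returns a view backed by the original buffer rather than a copy (SpecifiedIndex, by contrast, generally produces a copy), which is why msg and msdx can act as live windows into the network's state array. A minimal sketch of that view behavior, using the same imports as the sketch near the top of this page:

    INDArray base = Nd4j.zeros(4, 4);
    INDArray firstTwoRows = base.get(NDArrayIndex.interval(0, 2), NDArrayIndex.all());
    firstTwoRows.assign(1.0);                 // writes through the view...
    System.out.println(base.getDouble(0, 0)); // ...so this prints 1.0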
Example 5
Source File: GaussianReconstructionDistribution.java From deeplearning4j with Apache License 2.0
@Override
public INDArray generateRandom(INDArray preOutDistributionParams) {
    INDArray output = preOutDistributionParams.dup();
    activationFn.getActivation(output, true);

    val size = output.size(1) / 2;
    INDArray mean = output.get(NDArrayIndex.all(), NDArrayIndex.interval(0, size));
    INDArray logStdevSquared = output.get(NDArrayIndex.all(), NDArrayIndex.interval(size, 2 * size));

    INDArray sigma = Transforms.exp(logStdevSquared, true);
    Transforms.sqrt(sigma, false);

    INDArray e = Nd4j.randn(sigma.shape());
    return e.muli(sigma).addi(mean); //mu + sigma * N(0,1) ~ N(mu,sigma^2)
}
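Reading the body directly: the activated output is split into a mean half \mu and a log-variance half \log\sigma^2, and the last two lines implement the identity noted in the trailing comment,

    \sigma = \sqrt{\exp(\log\sigma^2)}, \qquad x = \mu + \sigma \odot \epsilon, \quad \epsilon \sim \mathcal{N}(0, I)

so that x \sim \mathcal{N}(\mu, \sigma^2).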
Example 6
Source File: UpdaterBlock.java From deeplearning4j with Apache License 2.0
private void update(int iteration, int epoch, boolean externalGradient, INDArray fullNetworkGradientView,
                INDArray fullNetworkParamsArray) {
    //Initialize the updater, if necessary
    if (gradientUpdater == null) {
        init();
    }

    INDArray blockGradViewArray;
    if (externalGradient) {
        blockGradViewArray = fullNetworkGradientView.get(NDArrayIndex.interval(0, 0, true),
                        NDArrayIndex.interval(paramOffsetStart, paramOffsetEnd));
    } else {
        blockGradViewArray = gradientView;
    }

    //First: Pre-apply gradient clipping etc: some are done on a per-layer basis
    //Therefore: it's already done by this point, in MultiLayerUpdater or ComputationGraphUpdater

    //Second: apply learning rate policy. Note that by definition we have the same LR policy for every single
    // variable in the block
    Trainable l0 = layersAndVariablesInBlock.get(0).getLayer();
    if (l0.numParams() == 0) {
        //No params for this layer
        return;
    }

    //Pre-updater regularization: l1 and l2
    applyRegularizationAllVariables(Regularization.ApplyStep.BEFORE_UPDATER, iteration, epoch, externalGradient,
                    fullNetworkGradientView, fullNetworkParamsArray);

    //Apply the updater itself
    gradientUpdater.applyUpdater(blockGradViewArray, iteration, epoch);

    //Post updater regularization: weight decay
    applyRegularizationAllVariables(Regularization.ApplyStep.POST_UPDATER, iteration, epoch, externalGradient,
                    fullNetworkGradientView, fullNetworkParamsArray);
}
Example 7
Source File: CompositeReconstructionDistribution.java From deeplearning4j with Apache License 2.0
private INDArray randomSample(INDArray preOutDistributionParams, boolean isMean) {
    int inputSoFar = 0;
    int paramsSoFar = 0;
    INDArray out = Nd4j.createUninitialized(preOutDistributionParams.dataType(),
                    new long[] {preOutDistributionParams.size(0), totalSize});
    for (int i = 0; i < distributionSizes.length; i++) {
        int thisDataSize = distributionSizes[i];
        int thisParamsSize = reconstructionDistributions[i].distributionInputSize(thisDataSize);

        INDArray paramsSubset = preOutDistributionParams.get(NDArrayIndex.all(),
                        NDArrayIndex.interval(paramsSoFar, paramsSoFar + thisParamsSize));

        INDArray thisRandomSample;
        if (isMean) {
            thisRandomSample = reconstructionDistributions[i].generateAtMean(paramsSubset);
        } else {
            thisRandomSample = reconstructionDistributions[i].generateRandom(paramsSubset);
        }

        out.put(new INDArrayIndex[] {NDArrayIndex.all(),
                        NDArrayIndex.interval(inputSoFar, inputSoFar + thisDataSize)}, thisRandomSample);
        inputSoFar += thisDataSize;
        paramsSoFar += thisParamsSize;
    }
    return out;
}
Example 8
Source File: PLNetLoss.java From AILibs with GNU Affero General Public License v3.0
/**
 * Computes the NLL for PL networks according to equation (27) in [1].
 *
 * @param plNetOutputs The outputs for M_n dyads generated by a PLNet's output layer in order of their ranking (from best to worst).
 * @return The NLL loss for the given PLNet outputs.
 */
public static INDArray computeLoss(INDArray plNetOutputs) {
    if (!(plNetOutputs.isRowVector()) || plNetOutputs.size(1) < 2) {
        throw new IllegalArgumentException("Input has to be a row vector of 2 or more elements.");
    }
    long dyadRankingLength = plNetOutputs.size(1);
    double loss = 0;
    for (int m = 0; m <= dyadRankingLength - 2; m++) {
        INDArray innerSumSlice = plNetOutputs.get(NDArrayIndex.interval(m, dyadRankingLength));
        innerSumSlice = Transforms.exp(innerSumSlice);
        loss += Transforms.log(innerSumSlice.sum(1)).getDouble(0);
    }
    loss -= plNetOutputs.get(NDArrayIndex.interval(0, dyadRankingLength - 1)).sum(1).getDouble(0);
    return Nd4j.create(new double[]{loss});
}
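Transcribing the loop above into a formula (with o_1, ..., o_M the PLNet outputs in ranking order, M = dyadRankingLength, and 1-based indexing for readability), the returned scalar is

    L = \sum_{m=1}^{M-1} \log\Big(\sum_{k=m}^{M} \exp(o_k)\Big) \;-\; \sum_{m=1}^{M-1} o_m

which is the negative log-likelihood of the ranking under a Plackett-Luce model; the last item is omitted from the outer sums because its choice probability is 1.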
Example 9
Source File: IndexingTestsC.java From nd4j with Apache License 2.0
@Test
public void testSmallInterval() {
    INDArray arr = Nd4j.arange(8).reshape(2, 2, 2);
    INDArray assertion = Nd4j.create(new double[][] {{4, 5}, {6, 7}}).reshape(1, 2, 2);
    INDArray rest = arr.get(interval(1, 2), all(), all());
    assertEquals(assertion, rest);
}
Example 10
Source File: IndexingTests.java From nd4j with Apache License 2.0
@Test
public void testGetRows() {
    INDArray arr = Nd4j.linspace(1, 9, 9).reshape(3, 3);
    INDArray testAssertion = Nd4j.create(new double[][] {{5, 8}, {6, 9}});
    INDArray test = arr.get(new SpecifiedIndex(1, 2), new SpecifiedIndex(1, 2));
    assertEquals(testAssertion, test);
}
Example 11
Source File: IndexingTestsC.java From deeplearning4j with Apache License 2.0
@Test
public void testGetScalar() {
    INDArray arr = Nd4j.linspace(1, 5, 5, DataType.DOUBLE);
    INDArray d = arr.get(point(1));
    assertTrue(d.isScalar());
    assertEquals(2.0, d.getDouble(0), 1e-1);
}
Example 12
Source File: IndexingTestsC.java From nd4j with Apache License 2.0
@Test
public void testMultiRow() {
    INDArray matrix = Nd4j.linspace(1, 9, 9).reshape(3, 3);
    INDArray assertion = Nd4j.create(new double[][] {{4, 7}});
    INDArray test = matrix.get(new SpecifiedIndex(1, 2), NDArrayIndex.interval(0, 1));
    assertEquals(assertion, test);
}
Example 13
Source File: IndexingTestsC.java From nd4j with Apache License 2.0
@Test
public void testVectorIndexing() {
    INDArray arr = Nd4j.linspace(1, 10, 10);
    INDArray assertion = Nd4j.create(new double[] {2, 3, 4, 5});
    INDArray viewTest = arr.get(point(0), interval(1, 5));
    assertEquals(assertion, viewTest);
}
Example 14
Source File: SameDiffRNNTestCases.java From deeplearning4j with Apache License 2.0
@Override
public MultiDataSetIterator getEvaluationTestData() throws Exception {
    int miniBatchSize = 10;
    int numLabelClasses = 6;
    // File featuresDirTest = new ClassPathResource("/RnnCsvSequenceClassification/uci_seq/test/features/").getFile();
    // File labelsDirTest = new ClassPathResource("/RnnCsvSequenceClassification/uci_seq/test/labels/").getFile();
    File featuresDirTest = Files.createTempDir();
    File labelsDirTest = Files.createTempDir();
    Resources.copyDirectory("dl4j-integration-tests/data/uci_seq/test/features/", featuresDirTest);
    Resources.copyDirectory("dl4j-integration-tests/data/uci_seq/test/labels/", labelsDirTest);

    SequenceRecordReader trainFeatures = new CSVSequenceRecordReader();
    trainFeatures.initialize(new NumberedFileInputSplit(featuresDirTest.getAbsolutePath() + "/%d.csv", 0, 149));
    SequenceRecordReader trainLabels = new CSVSequenceRecordReader();
    trainLabels.initialize(new NumberedFileInputSplit(labelsDirTest.getAbsolutePath() + "/%d.csv", 0, 149));

    DataSetIterator testData = new SequenceRecordReaderDataSetIterator(trainFeatures, trainLabels, miniBatchSize,
                    numLabelClasses, false, SequenceRecordReaderDataSetIterator.AlignmentMode.ALIGN_END);

    MultiDataSetIterator iter = new MultiDataSetIteratorAdapter(testData);

    MultiDataSetPreProcessor pp = multiDataSet -> {
        INDArray l = multiDataSet.getLabels(0);
        l = l.get(NDArrayIndex.all(), NDArrayIndex.all(), NDArrayIndex.point(l.size(2) - 1));
        multiDataSet.setLabels(0, l);
        multiDataSet.setLabelsMaskArray(0, null);
    };

    iter.setPreProcessor(new CompositeMultiDataSetPreProcessor(getNormalizer(), pp));

    return iter;
}
Example 15
Source File: SameDiffParamInitializer.java From deeplearning4j with Apache License 2.0
public Map<String, INDArray> subsetAndReshape(List<String> params, Map<String, long[]> paramShapes, INDArray view,
                AbstractSameDiffLayer sdl, SameDiffVertex sdv) {
    Class<?> clazz = (sdl != null ? sdl.getClass() : sdv.getClass());
    String layerName = (sdl != null ? sdl.getLayerName() : ""); //TODO

    Map<String, INDArray> out = new LinkedHashMap<>();
    int soFar = 0;
    for (String s : params) {
        val sh = paramShapes.get(s);
        val length = ArrayUtil.prodLong(sh);
        if (length <= 0) {
            throw new IllegalStateException("Invalid array state for parameter \"" + s + "\" in layer " + layerName
                            + " of type " + clazz.getSimpleName() + ": parameter length (" + length
                            + ") must be > 0 - parameter array shape: " + Arrays.toString(sh));
        }
        //interval(0, 0, true) is the inclusive interval [0, 0]: row 0 of the view, keeping the result 2d
        INDArray sub = view.get(interval(0, 0, true), interval(soFar, soFar + length));
        if (!Arrays.equals(sub.shape(), sh)) {
            char order = (sdl != null ? sdl.paramReshapeOrder(s) : sdv.paramReshapeOrder(s));
            sub = sub.reshape(order, sh);
        }
        out.put(s, sub);
        soFar += length;
    }
    return out;
}
Example 16
Source File: ComputationGraphTestRNN.java From deeplearning4j with Apache License 2.0
@Test
public void testRnnTimeStepGravesLSTM() {
    Nd4j.getRandom().setSeed(12345);
    int timeSeriesLength = 12;

    //4 layer network: 2 GravesLSTM + DenseLayer + RnnOutputLayer. Hence also tests preprocessors.
    ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder().seed(12345).graphBuilder()
                    .addInputs("in")
                    .addLayer("0", new org.deeplearning4j.nn.conf.layers.GravesLSTM.Builder().nIn(5).nOut(7)
                                    .activation(Activation.TANH)
                                    .dist(new NormalDistribution(0, 0.5)).build(), "in")
                    .addLayer("1", new org.deeplearning4j.nn.conf.layers.GravesLSTM.Builder().nIn(7).nOut(8)
                                    .activation(Activation.TANH)
                                    .dist(new NormalDistribution(0, 0.5)).build(), "0")
                    .addLayer("2", new DenseLayer.Builder().nIn(8).nOut(9).activation(Activation.TANH)
                                    .dist(new NormalDistribution(0, 0.5))
                                    .build(), "1")
                    .addLayer("3", new RnnOutputLayer.Builder(LossFunctions.LossFunction.MCXENT)
                                    .nIn(9).nOut(4)
                                    .activation(Activation.SOFTMAX)
                                    .dist(new NormalDistribution(0, 0.5)).build(), "2")
                    .setOutputs("3").inputPreProcessor("2", new RnnToFeedForwardPreProcessor())
                    .inputPreProcessor("3", new FeedForwardToRnnPreProcessor())
                    .build();
    ComputationGraph graph = new ComputationGraph(conf);
    graph.init();

    INDArray input = Nd4j.rand(new int[] {3, 5, timeSeriesLength});

    Map<String, INDArray> allOutputActivations = graph.feedForward(input, true);
    INDArray fullOutL0 = allOutputActivations.get("0");
    INDArray fullOutL1 = allOutputActivations.get("1");
    INDArray fullOutL3 = allOutputActivations.get("3");
    assertArrayEquals(new long[] {3, 7, timeSeriesLength}, fullOutL0.shape());
    assertArrayEquals(new long[] {3, 8, timeSeriesLength}, fullOutL1.shape());
    assertArrayEquals(new long[] {3, 4, timeSeriesLength}, fullOutL3.shape());

    int[] inputLengths = {1, 2, 3, 4, 6, 12};

    //Do steps of length 1, then of length 2, ..., 12
    //Should get the same result regardless of step size; should be identical to standard forward pass
    for (int i = 0; i < inputLengths.length; i++) {
        int inLength = inputLengths[i];
        int nSteps = timeSeriesLength / inLength; //each of length inLength
        graph.rnnClearPreviousState();

        for (int j = 0; j < nSteps; j++) {
            int startTimeRange = j * inLength;
            int endTimeRange = startTimeRange + inLength;

            INDArray inputSubset = input.get(NDArrayIndex.all(), NDArrayIndex.all(),
                            NDArrayIndex.interval(startTimeRange, endTimeRange));
            if (inLength > 1)
                assertTrue(inputSubset.size(2) == inLength);

            INDArray[] outArr = graph.rnnTimeStep(inputSubset);
            assertEquals(1, outArr.length);
            INDArray out = outArr[0];

            INDArray expOutSubset;
            if (inLength == 1) {
                val sizes = new long[] {fullOutL3.size(0), fullOutL3.size(1), 1};
                expOutSubset = Nd4j.create(DataType.FLOAT, sizes);
                expOutSubset.tensorAlongDimension(0, 1, 0).assign(fullOutL3.get(NDArrayIndex.all(),
                                NDArrayIndex.all(), NDArrayIndex.point(startTimeRange)));
            } else {
                expOutSubset = fullOutL3.get(NDArrayIndex.all(), NDArrayIndex.all(),
                                NDArrayIndex.interval(startTimeRange, endTimeRange));
            }

            assertEquals(expOutSubset, out);

            Map<String, INDArray> currL0State = graph.rnnGetPreviousState("0");
            Map<String, INDArray> currL1State = graph.rnnGetPreviousState("1");

            INDArray lastActL0 = currL0State.get(GravesLSTM.STATE_KEY_PREV_ACTIVATION);
            INDArray lastActL1 = currL1State.get(GravesLSTM.STATE_KEY_PREV_ACTIVATION);

            INDArray expLastActL0 = fullOutL0.tensorAlongDimension(endTimeRange - 1, 1, 0);
            INDArray expLastActL1 = fullOutL1.tensorAlongDimension(endTimeRange - 1, 1, 0);

            assertEquals(expLastActL0, lastActL0);
            assertEquals(expLastActL1, lastActL1);
        }
    }
}
Example 17
Source File: TestSimpleRnn.java From deeplearning4j with Apache License 2.0
@Test
public void testSimpleRnn() {
    Nd4j.getRandom().setSeed(12345);

    int m = 3;
    int nIn = 5;
    int layerSize = 6;
    int tsLength = 7;
    INDArray in;
    if (rnnDataFormat == RNNFormat.NCW) {
        in = Nd4j.rand(DataType.FLOAT, m, nIn, tsLength);
    } else {
        in = Nd4j.rand(DataType.FLOAT, m, tsLength, nIn);
    }

    MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
                    .updater(new NoOp())
                    .weightInit(WeightInit.XAVIER)
                    .activation(Activation.TANH)
                    .list()
                    .layer(new SimpleRnn.Builder().nIn(nIn).nOut(layerSize).dataFormat(rnnDataFormat).build())
                    .build();

    MultiLayerNetwork net = new MultiLayerNetwork(conf);
    net.init();

    INDArray out = net.output(in);

    INDArray w = net.getParam("0_W");
    INDArray rw = net.getParam("0_RW");
    INDArray b = net.getParam("0_b");

    INDArray outLast = null;
    for (int i = 0; i < tsLength; i++) {
        INDArray inCurrent;
        if (rnnDataFormat == RNNFormat.NCW) {
            inCurrent = in.get(all(), all(), point(i));
        } else {
            inCurrent = in.get(all(), point(i), all());
        }

        INDArray outExpCurrent = inCurrent.mmul(w);
        if (outLast != null) {
            outExpCurrent.addi(outLast.mmul(rw));
        }
        outExpCurrent.addiRowVector(b);
        Transforms.tanh(outExpCurrent, false);

        INDArray outActCurrent;
        if (rnnDataFormat == RNNFormat.NCW) {
            outActCurrent = out.get(all(), all(), point(i));
        } else {
            outActCurrent = out.get(all(), point(i), all());
        }

        assertEquals(String.valueOf(i), outExpCurrent, outActCurrent);

        outLast = outExpCurrent;
    }

    TestUtils.testModelSerialization(net);
}
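The reference computation inside the loop is the plain Elman-style recurrence that SimpleRnn is expected to implement, with the per-step input x_t sliced out of the full sequence via point(i):

    h_t = \tanh(x_t W + h_{t-1} R + b), \qquad h_0 = 0

The test then checks each h_t against the corresponding time-slice of the network output, extracted with the same point()/all() indexing.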
Example 18
Source File: NDArrayCreationUtil.java From deeplearning4j with Apache License 2.0
public static List<Pair<INDArray, String>> get5dSubArraysWithShape(int seed, int[] shape, DataType dataType) {
    List<Pair<INDArray, String>> list = new ArrayList<>();
    String baseMsg = "get5dSubArraysWithShape(" + seed + "," + Arrays.toString(shape) + ")";

    //Create and return various sub arrays:
    Nd4j.getRandom().setSeed(seed);
    int[] newShape1 = Arrays.copyOf(shape, shape.length);
    newShape1[0] += 5;
    INDArray temp1 = Nd4j.rand(dataType, newShape1);
    INDArray subset1 = temp1.get(NDArrayIndex.interval(2, shape[0] + 2), NDArrayIndex.all(), NDArrayIndex.all(),
                    NDArrayIndex.all(), NDArrayIndex.all());
    list.add(new Pair<>(subset1, baseMsg + ".get(0)"));

    int[] newShape2 = Arrays.copyOf(shape, shape.length);
    newShape2[1] += 5;
    INDArray temp2 = Nd4j.rand(dataType, newShape2);
    INDArray subset2 = temp2.get(NDArrayIndex.all(), NDArrayIndex.interval(3, shape[1] + 3), NDArrayIndex.all(),
                    NDArrayIndex.all(), NDArrayIndex.all());
    list.add(new Pair<>(subset2, baseMsg + ".get(1)"));

    int[] newShape3 = Arrays.copyOf(shape, shape.length);
    newShape3[2] += 5;
    INDArray temp3 = Nd4j.rand(dataType, newShape3);
    INDArray subset3 = temp3.get(NDArrayIndex.all(), NDArrayIndex.all(), NDArrayIndex.interval(4, shape[2] + 4),
                    NDArrayIndex.all(), NDArrayIndex.all());
    list.add(new Pair<>(subset3, baseMsg + ".get(2)"));

    int[] newShape4 = Arrays.copyOf(shape, shape.length);
    newShape4[3] += 5;
    INDArray temp4 = Nd4j.rand(dataType, newShape4);
    INDArray subset4 = temp4.get(NDArrayIndex.all(), NDArrayIndex.all(), NDArrayIndex.all(),
                    NDArrayIndex.interval(3, shape[3] + 3), NDArrayIndex.all());
    list.add(new Pair<>(subset4, baseMsg + ".get(3)"));

    int[] newShape5 = Arrays.copyOf(shape, shape.length);
    newShape5[4] += 5;
    INDArray temp5 = Nd4j.rand(dataType, newShape5);
    INDArray subset5 = temp5.get(NDArrayIndex.all(), NDArrayIndex.all(), NDArrayIndex.all(), NDArrayIndex.all(),
                    NDArrayIndex.interval(3, shape[4] + 3));
    list.add(new Pair<>(subset5, baseMsg + ".get(4)"));

    int[] newShape6 = Arrays.copyOf(shape, shape.length);
    newShape6[0] += 5;
    newShape6[1] += 5;
    newShape6[2] += 5;
    newShape6[3] += 5;
    newShape6[4] += 5;
    INDArray temp6 = Nd4j.rand(dataType, newShape6);
    INDArray subset6 = temp6.get(NDArrayIndex.interval(4, shape[0] + 4), NDArrayIndex.interval(3, shape[1] + 3),
                    NDArrayIndex.interval(2, shape[2] + 2), NDArrayIndex.interval(1, shape[3] + 1),
                    NDArrayIndex.interval(2, shape[4] + 2));
    list.add(new Pair<>(subset6, baseMsg + ".get(5)"));

    return list;
}
Example 19
Source File: IndexingTests.java From nd4j with Apache License 2.0
@Test
public void testGet() {
    System.out.println("Testing sub-array put and get with a 3D array ...");

    INDArray arr = Nd4j.linspace(0, 124, 125).reshape(5, 5, 5);

    /*
     * Extract elements with the following indices:
     *
     * (2,1,1) (2,1,2) (2,1,3)
     * (2,2,1) (2,2,2) (2,2,3)
     * (2,3,1) (2,3,2) (2,3,3)
     */
    int slice = 2;

    int iStart = 1;
    int jStart = 1;

    int iEnd = 4;
    int jEnd = 4;

    // Method A: Element-wise.
    INDArray subArr_A = Nd4j.create(new int[] {3, 3});
    for (int i = iStart; i < iEnd; i++) {
        for (int j = jStart; j < jEnd; j++) {
            double val = arr.getDouble(slice, i, j);
            int[] sub = new int[] {i - iStart, j - jStart};
            subArr_A.putScalar(sub, val);
        }
    }

    // Method B: Using NDArray get and put with index classes.
    INDArray subArr_B = Nd4j.create(new int[] {3, 3});

    INDArrayIndex ndi_Slice = NDArrayIndex.point(slice);
    INDArrayIndex ndi_J = NDArrayIndex.interval(jStart, jEnd);
    INDArrayIndex ndi_I = NDArrayIndex.interval(iStart, iEnd);

    INDArrayIndex[] whereToGet = new INDArrayIndex[] {ndi_Slice, ndi_I, ndi_J};

    INDArray whatToPut = arr.get(whereToGet);
    assertEquals(subArr_A, whatToPut);
    System.out.println(whatToPut);

    INDArrayIndex[] whereToPut = new INDArrayIndex[] {NDArrayIndex.all(), NDArrayIndex.all()};

    subArr_B.put(whereToPut, whatToPut);
    assertEquals(subArr_A, subArr_B);

    System.out.println("... done");
}
Example 20
Source File: NDArrayCreationUtil.java From nd4j with Apache License 2.0
public static List<Pair<INDArray, String>> get4dSubArraysWithShape(int seed, int... shape) {
    List<Pair<INDArray, String>> list = new ArrayList<>();
    String baseMsg = "get4dSubArraysWithShape(" + seed + "," + Arrays.toString(shape) + ")";

    //Create and return various sub arrays:
    Nd4j.getRandom().setSeed(seed);
    int[] newShape1 = Arrays.copyOf(shape, shape.length);
    newShape1[0] += 5;
    int len = ArrayUtil.prod(newShape1);
    INDArray temp1 = Nd4j.linspace(1, len, len).reshape(ArrayUtil.toLongArray(newShape1));
    INDArray subset1 = temp1.get(NDArrayIndex.interval(2, shape[0] + 2), NDArrayIndex.all(), NDArrayIndex.all(),
                    NDArrayIndex.all());
    list.add(new Pair<>(subset1, baseMsg + ".get(0)"));

    int[] newShape2 = Arrays.copyOf(shape, shape.length);
    newShape2[1] += 5;
    int len2 = ArrayUtil.prod(newShape2);
    INDArray temp2 = Nd4j.linspace(1, len2, len2).reshape(ArrayUtil.toLongArray(newShape2));
    INDArray subset2 = temp2.get(NDArrayIndex.all(), NDArrayIndex.interval(3, shape[1] + 3), NDArrayIndex.all(),
                    NDArrayIndex.all());
    list.add(new Pair<>(subset2, baseMsg + ".get(1)"));

    int[] newShape3 = Arrays.copyOf(shape, shape.length);
    newShape3[2] += 5;
    int len3 = ArrayUtil.prod(newShape3);
    INDArray temp3 = Nd4j.linspace(1, len3, len3).reshape(ArrayUtil.toLongArray(newShape3));
    INDArray subset3 = temp3.get(NDArrayIndex.all(), NDArrayIndex.all(), NDArrayIndex.interval(4, shape[2] + 4),
                    NDArrayIndex.all());
    list.add(new Pair<>(subset3, baseMsg + ".get(2)"));

    int[] newShape4 = Arrays.copyOf(shape, shape.length);
    newShape4[3] += 5;
    int len4 = ArrayUtil.prod(newShape4);
    INDArray temp4 = Nd4j.linspace(1, len4, len4).reshape(ArrayUtil.toLongArray(newShape4));
    INDArray subset4 = temp4.get(NDArrayIndex.all(), NDArrayIndex.all(), NDArrayIndex.all(),
                    NDArrayIndex.interval(3, shape[3] + 3));
    list.add(new Pair<>(subset4, baseMsg + ".get(3)"));

    int[] newShape5 = Arrays.copyOf(shape, shape.length);
    newShape5[0] += 5;
    newShape5[1] += 5;
    newShape5[2] += 5;
    newShape5[3] += 5;
    int len5 = ArrayUtil.prod(newShape5);
    INDArray temp5 = Nd4j.linspace(1, len5, len5).reshape(ArrayUtil.toLongArray(newShape5));
    INDArray subset5 = temp5.get(NDArrayIndex.interval(4, shape[0] + 4), NDArrayIndex.interval(3, shape[1] + 3),
                    NDArrayIndex.interval(2, shape[2] + 2), NDArrayIndex.interval(1, shape[3] + 1));
    list.add(new Pair<>(subset5, baseMsg + ".get(4)"));

    return list;
}