Java Code Examples for org.nd4j.linalg.api.ndarray.INDArray#like()
The following examples show how to use
`org.nd4j.linalg.api.ndarray.INDArray#like()`, which creates a new array with the same shape and data type as the array it is called on.
You can vote up the ones you like or vote down the ones you don't like,
and go to the original project or source file by following the links above each example. You may check out the related API usage on the sidebar.
Example 1
Source file: MiscOpValidation.java, from deeplearning4j (Apache License 2.0) | 6 votes
/**
 * Runs {@code unsorted_segment_prod_bp} on a small 1-D input and checks the gradient with
 * respect to {@code data}.
 *
 * <p>Segment ids {0,0,0,1,2,2,3,3} split data {5,1,7,2,3,4,1,3} into the segments
 * {[5,1,7], [2], [3,4], [1,3]}. The upstream gradients per segment are {1,2,3,4}.
 * For a product, d(prod)/dx_i = prod(segment) / x_i, scaled by the segment's upstream
 * gradient. That gives {35/5, 35/1, 35/7} * 1, {2/2} * 2, {12/3, 12/4} * 3 and
 * {3/1, 3/3} * 4, i.e. {7, 35, 5, 2, 12, 9, 12, 4}.
 */
@Test
public void testSegmentProdBpSimple() {
    INDArray segmentIdxs = Nd4j.create(new double[]{0, 0, 0, 1, 2, 2, 3, 3}, new long[]{8}).castTo(DataType.INT);
    INDArray data = Nd4j.create(new double[]{5, 1, 7, 2, 3, 4, 1, 3}, new long[]{8});
    INDArray grad = Nd4j.createFromArray(1.0, 2.0, 3.0, 4.0);
    int numSegments = 4;

    // Output buffers: same shape/type as the corresponding inputs
    INDArray gradData = data.like();
    INDArray gradIdxs = segmentIdxs.like();

    DynamicCustomOp op = DynamicCustomOp.builder("unsorted_segment_prod_bp")
            .addInputs(data, segmentIdxs, grad)
            .addIntegerArguments(numSegments)
            .addOutputs(gradData, gradIdxs)
            .build();
    Nd4j.getExecutioner().exec(op);

    // Previously the test only checked that the op executed; pin the actual gradient.
    // Built with the same factory call as `data` so the data types match.
    INDArray expGradData = Nd4j.create(new double[]{7, 35, 5, 2, 12, 9, 12, 4}, new long[]{8});
    assertEquals(expGradData, gradData);
}
Example 2
Source file: ShapeOpValidation.java, from deeplearning4j (Apache License 2.0) | 6 votes
/**
 * Verifies {@code segment_mean} on a 6x3 matrix whose rows are grouped in consecutive pairs
 * (segment ids {0,0,1,1,2,2}): output row k must equal the mean of input rows 2k and 2k+1.
 */
@Test
public void testSegmentMean() {
    INDArray x = Nd4j.linspace(DataType.FLOAT, 1, 18, 1).reshape(6, 3);
    INDArray segmentIds = Nd4j.createFromArray(0, 0, 1, 1, 2, 2);
    INDArray out = Nd4j.create(DataType.FLOAT, 3, 3);

    DynamicCustomOp segmentMean = DynamicCustomOp.builder("segment_mean")
            .addInputs(x, segmentIds)
            .addOutputs(out)
            .build();
    Nd4j.exec(segmentMean);

    // Expected value: average each pair of rows by hand
    INDArray exp = out.like();
    for (int seg = 0; seg < 3; seg++) {
        INDArray firstRow = x.getRow(2 * seg);
        INDArray secondRow = x.getRow(2 * seg + 1);
        exp.putRow(seg, firstRow.add(secondRow).muli(0.5));
    }
    assertEquals(exp, out);
}
Example 3
Source file: MiscOpValidation.java, from deeplearning4j (Apache License 2.0) | 5 votes
@Test public void testMmulRank4() throws Exception { Nd4j.getRandom().setSeed(12345); INDArray arr1 = Nd4j.rand(DataType.FLOAT, 32, 12, 128, 64); INDArray arr2 = Nd4j.rand(DataType.FLOAT, 32, 12, 128, 64); DynamicCustomOp op = DynamicCustomOp.builder("matmul") .addInputs(arr1, arr2) .addIntegerArguments(0, 1) //Transpose arr2 only .build(); List<LongShapeDescriptor> shapes = op.calculateOutputShape(); assertEquals(1, shapes.size()); long[] shape = new long[]{32,12,128,128}; assertArrayEquals(shape, shapes.get(0).getShape()); INDArray out = Nd4j.create(DataType.FLOAT, shape); INDArray outExp = out.like(); for( int i=0; i<32; i++ ){ for( int j=0; j<12; j++ ){ INDArray sub1 = arr1.get(NDArrayIndex.point(i), NDArrayIndex.point(j), NDArrayIndex.all(), NDArrayIndex.all()); INDArray sub2 = arr2.get(NDArrayIndex.point(i), NDArrayIndex.point(j), NDArrayIndex.all(), NDArrayIndex.all()); INDArray mmul = sub1.mmul(sub2.transpose()); outExp.get(NDArrayIndex.point(i), NDArrayIndex.point(j), NDArrayIndex.all(), NDArrayIndex.all()).assign(mmul); } } op.setOutputArgument(0, out); Nd4j.exec(op); assertEquals(outExp, out); }
Example 4
Source file: RandomTests.java, from deeplearning4j (Apache License 2.0) | 5 votes
/**
 * Two Random instances created with the same seed must drive {@code GaussianDistribution}
 * to identical samples. The samples must also differ from an untouched array of the same
 * shape and have (approximately) mean 0 and standard deviation 1.
 */
@Test
public void testGaussianDistribution1() {
    final long seed = 119;
    Random rngA = Nd4j.getRandomFactory().getNewRandomInstance(seed);
    Random rngB = Nd4j.getRandomFactory().getNewRandomInstance(seed);

    INDArray samplesA = Nd4j.create(DataType.DOUBLE, 1000000);
    INDArray samplesB = Nd4j.create(DataType.DOUBLE, 1000000);
    // Fresh array of the same shape, taken before sampling, used as a "nothing happened" baseline
    INDArray untouched = samplesA.like();

    Nd4j.getExecutioner().exec(new GaussianDistribution(samplesA, 0.0, 1.0), rngA);
    Nd4j.getExecutioner().exec(new GaussianDistribution(samplesB, 0.0, 1.0), rngB);

    assertNotEquals(untouched, samplesA);
    assertEquals(0.0, samplesA.meanNumber().doubleValue(), 0.01);
    assertEquals(1.0, samplesA.stdNumber().doubleValue(), 0.01);

    double[] valuesA = samplesA.toDoubleVector();
    double[] valuesB = samplesB.toDoubleVector();
    assertArrayEquals(valuesA, valuesB, 1e-4);
    assertEquals(samplesA, samplesB);
}
Example 5
Source file: CustomOpsTests.java, from deeplearning4j (Apache License 2.0) | 5 votes
/**
 * {@code IsMax} over dimensions 2 and 3 must give the same result whether its input is a
 * permuted (non-contiguous) view or a contiguous {@code dup()} of that view.
 */
@Test
public void isMax4d_2dims() {
    Nd4j.getRandom().setSeed(12345);
    INDArray base = Nd4j.rand(DataType.FLOAT, 3, 3, 4, 4);
    INDArray permuted = base.permute(0, 2, 3, 1);

    INDArray resultFromPermuted = permuted.like();
    INDArray resultFromCopy = permuted.like();

    Nd4j.exec(new IsMax(permuted.dup(), resultFromCopy, 2, 3));
    Nd4j.exec(new IsMax(permuted, resultFromPermuted, 2, 3));

    assertEquals(resultFromCopy, resultFromPermuted);
}
Example 6
Source file: CustomOpsTests.java, from deeplearning4j (Apache License 2.0) | 4 votes
@Test public void testBatchNormBpNHWC(){ //Nd4j.getEnvironment().allowHelpers(false); //Passes if helpers/MKLDNN is disabled INDArray in = Nd4j.rand(DataType.FLOAT, 2, 4, 4, 3); INDArray eps = Nd4j.rand(DataType.FLOAT, in.shape()); INDArray epsStrided = eps.permute(1,0,2,3).dup().permute(1,0,2,3); INDArray mean = Nd4j.rand(DataType.FLOAT, 3); INDArray var = Nd4j.rand(DataType.FLOAT, 3); INDArray gamma = Nd4j.rand(DataType.FLOAT, 3); INDArray beta = Nd4j.rand(DataType.FLOAT, 3); assertEquals(eps, epsStrided); INDArray out1eps = in.like(); INDArray out1m = mean.like(); INDArray out1v = var.like(); INDArray out2eps = in.like(); INDArray out2m = mean.like(); INDArray out2v = var.like(); DynamicCustomOp op1 = DynamicCustomOp.builder("batchnorm_bp") .addInputs(in, mean, var, gamma, beta, eps) .addOutputs(out1eps, out1m, out1v) .addIntegerArguments(1, 1, 3) .addFloatingPointArguments(1e-5) .build(); DynamicCustomOp op2 = DynamicCustomOp.builder("batchnorm_bp") .addInputs(in, mean, var, gamma, beta, epsStrided) .addOutputs(out2eps, out2m, out2v) .addIntegerArguments(1, 1, 3) .addFloatingPointArguments(1e-5) .build(); Nd4j.exec(op1); Nd4j.exec(op2); assertEquals(out1eps, out2eps); //Fails here assertEquals(out1m, out2m); assertEquals(out1v, out2v); }
Example 7
Source file: TestBertIterator.java, from deeplearning4j (Apache License 2.0) | 4 votes
@Test(timeout = 20000L) public void testBertSequenceClassification() throws Exception { int minibatchSize = 2; TestSentenceHelper testHelper = new TestSentenceHelper(); BertIterator b = BertIterator.builder() .tokenizer(testHelper.getTokenizer()) .lengthHandling(BertIterator.LengthHandling.FIXED_LENGTH, 16) .minibatchSize(minibatchSize) .sentenceProvider(testHelper.getSentenceProvider()) .featureArrays(BertIterator.FeatureArrays.INDICES_MASK) .vocabMap(testHelper.getTokenizer().getVocab()) .task(BertIterator.Task.SEQ_CLASSIFICATION) .build(); MultiDataSet mds = b.next(); assertEquals(1, mds.getFeatures().length); System.out.println(mds.getFeatures(0)); System.out.println(mds.getFeaturesMaskArray(0)); INDArray expF = Nd4j.create(DataType.INT, 1, 16); INDArray expM = Nd4j.create(DataType.INT, 1, 16); Map<String, Integer> m = testHelper.getTokenizer().getVocab(); for (int i = 0; i < minibatchSize; i++) { INDArray expFTemp = Nd4j.create(DataType.INT, 1, 16); INDArray expMTemp = Nd4j.create(DataType.INT, 1, 16); List<String> tokens = testHelper.getTokenizedSentences().get(i); System.out.println(tokens); for (int j = 0; j < tokens.size(); j++) { String token = tokens.get(j); if (!m.containsKey(token)) { throw new IllegalStateException("Unknown token: \"" + token + "\""); } int idx = m.get(token); expFTemp.putScalar(0, j, idx); expMTemp.putScalar(0, j, 1); } if (i == 0) { expF = expFTemp.dup(); expM = expMTemp.dup(); } else { expF = Nd4j.vstack(expF, expFTemp); expM = Nd4j.vstack(expM, expMTemp); } } assertEquals(expF, mds.getFeatures(0)); assertEquals(expM, mds.getFeaturesMaskArray(0)); assertEquals(expF, b.featurizeSentences(testHelper.getSentences()).getFirst()[0]); assertEquals(expM, b.featurizeSentences(testHelper.getSentences()).getSecond()[0]); assertFalse(b.hasNext()); b.reset(); assertTrue(b.hasNext()); //Same thing, but with segment ID also b = BertIterator.builder() .tokenizer(testHelper.getTokenizer()) .lengthHandling(BertIterator.LengthHandling.FIXED_LENGTH, 
16) .minibatchSize(minibatchSize) .sentenceProvider(testHelper.getSentenceProvider()) .featureArrays(BertIterator.FeatureArrays.INDICES_MASK_SEGMENTID) .vocabMap(testHelper.getTokenizer().getVocab()) .task(BertIterator.Task.SEQ_CLASSIFICATION) .build(); mds = b.next(); assertEquals(2, mds.getFeatures().length); //Segment ID should be all 0s for single segment task INDArray segmentId = expM.like(); assertEquals(segmentId, mds.getFeatures(1)); assertEquals(segmentId, b.featurizeSentences(testHelper.getSentences()).getFirst()[1]); }