Java Code Examples for org.nd4j.linalg.ops.transforms.Transforms#tanh()
The following examples show how to use org.nd4j.linalg.ops.transforms.Transforms#tanh(). Each example notes the project and license of the source file it was taken from.
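All of the examples below call the two-argument overload Transforms.tanh(INDArray, boolean), where the boolean chooses between returning the result as a new array (true) and overwriting the input in place (false). The short sketch that follows is a minimal illustration of that pattern, not taken from any of the projects below; the class name and input values are made up for demonstration.

import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.ops.transforms.Transforms;

public class TanhUsageSketch {
    public static void main(String[] args) {
        // Illustrative input values only
        INDArray x = Nd4j.create(new double[]{-2.0, -0.5, 0.0, 0.5, 2.0});

        // dup = true: tanh is computed into a new array, x is left unchanged
        INDArray y = Transforms.tanh(x, true);

        // dup = false: tanh is applied in place, modifying x directly
        Transforms.tanh(x, false);

        System.out.println(y);   // same values as the in-place result
        System.out.println(x);
    }
}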
Example 1
Source File: FailingSameDiffTests.java From deeplearning4j with Apache License 2.0
@Test
public void testExecutionDifferentShapesTransform(){
    OpValidationSuite.ignoreFailing();
    SameDiff sd = SameDiff.create();
    SDVariable in = sd.var("in", Nd4j.linspace(1, 12, 12, DataType.DOUBLE).reshape(3, 4));

    SDVariable tanh = sd.math().tanh(in);
    INDArray exp = Transforms.tanh(in.getArr(), true);

    INDArray out = tanh.eval();
    assertEquals(exp, out);

    //Now, replace with minibatch 5:
    in.setArray(Nd4j.linspace(1, 20, 20, DataType.DOUBLE).reshape(5, 4));
    INDArray out2 = tanh.eval();
    assertArrayEquals(new long[]{5, 4}, out2.shape());
    exp = Transforms.tanh(in.getArr(), true);
    assertEquals(exp, out2);
}
Example 2
Source File: SameDiffTests.java From nd4j with Apache License 2.0
@Test
public void testActivationBackprop() {
    Activation[] afns = new Activation[]{
            Activation.TANH,
            Activation.SIGMOID,
            Activation.ELU,
            Activation.SOFTPLUS,
            Activation.SOFTSIGN,
            Activation.HARDTANH,
            Activation.CUBE,        //WRONG output - see issue https://github.com/deeplearning4j/nd4j/issues/2426
            Activation.RELU,        //JVM crash
            Activation.LEAKYRELU    //JVM crash
    };

    for (Activation a : afns) {
        SameDiff sd = SameDiff.create();
        INDArray inArr = Nd4j.linspace(-3, 3, 7);
        INDArray labelArr = Nd4j.linspace(-3, 3, 7).muli(0.5);
        SDVariable in = sd.var("in", inArr.dup());
        // System.out.println("inArr: " + inArr);

        INDArray outExp;
        SDVariable out;
        switch (a) {
            case ELU:
                out = sd.elu("out", in);
                outExp = Transforms.elu(inArr, true);
                break;
            case HARDTANH:
                out = sd.hardTanh("out", in);
                outExp = Transforms.hardTanh(inArr, true);
                break;
            case LEAKYRELU:
                out = sd.leakyRelu("out", in, 0.01);
                outExp = Transforms.leakyRelu(inArr, true);
                break;
            case RELU:
                out = sd.relu("out", in, 0.0);
                outExp = Transforms.relu(inArr, true);
                break;
            case SIGMOID:
                out = sd.sigmoid("out", in);
                outExp = Transforms.sigmoid(inArr, true);
                break;
            case SOFTPLUS:
                out = sd.softplus("out", in);
                outExp = Transforms.softPlus(inArr, true);
                break;
            case SOFTSIGN:
                out = sd.softsign("out", in);
                outExp = Transforms.softsign(inArr, true);
                break;
            case TANH:
                out = sd.tanh("out", in);
                outExp = Transforms.tanh(inArr, true);
                break;
            case CUBE:
                out = sd.cube("out", in);
                outExp = Transforms.pow(inArr, 3, true);
                break;
            default:
                throw new RuntimeException(a.toString());
        }

        //Sum squared error loss:
        SDVariable label = sd.var("label", labelArr.dup());
        SDVariable diff = label.sub("diff", out);
        SDVariable sqDiff = diff.mul("sqDiff", diff);
        SDVariable totSum = sd.sum("totSum", sqDiff, Integer.MAX_VALUE);    //Loss function...

        sd.exec();
        INDArray outAct = sd.getVariable("out").getArr();
        assertEquals(a.toString(), outExp, outAct);

        // L = sum_i (label - out)^2
        //dL/dOut = 2(out - label)
        INDArray dLdOutExp = outExp.sub(labelArr).mul(2);
        INDArray dLdInExp = a.getActivationFunction().backprop(inArr.dup(), dLdOutExp.dup()).getFirst();

        sd.execBackwards();
        SameDiff gradFn = sd.getFunction("grad");
        INDArray dLdOutAct = gradFn.getVariable("out-grad").getArr();
        INDArray dLdInAct = gradFn.getVariable("in-grad").getArr();

        assertEquals(a.toString(), dLdOutExp, dLdOutAct);
        assertEquals(a.toString(), dLdInExp, dLdInAct);
    }
}
Example 3
Source File: SameDiffTests.java From deeplearning4j with Apache License 2.0
@Test
public void testActivationBackprop() {
    Activation[] afns = new Activation[]{
            Activation.TANH,
            Activation.SIGMOID,
            Activation.ELU,
            Activation.SOFTPLUS,
            Activation.SOFTSIGN,
            Activation.HARDTANH,
            Activation.CUBE,        //WRONG output - see issue https://github.com/deeplearning4j/nd4j/issues/2426
            Activation.RELU,        //JVM crash
            Activation.LEAKYRELU    //JVM crash
    };

    for (Activation a : afns) {
        SameDiff sd = SameDiff.create();
        INDArray inArr = Nd4j.linspace(-3, 3, 7);
        INDArray labelArr = Nd4j.linspace(-3, 3, 7).muli(0.5);
        SDVariable in = sd.var("in", inArr.dup());
        // System.out.println("inArr: " + inArr);

        INDArray outExp;
        SDVariable out;
        switch (a) {
            case ELU:
                out = sd.nn().elu("out", in);
                outExp = Transforms.elu(inArr, true);
                break;
            case HARDTANH:
                out = sd.nn().hardTanh("out", in);
                outExp = Transforms.hardTanh(inArr, true);
                break;
            case LEAKYRELU:
                out = sd.nn().leakyRelu("out", in, 0.01);
                outExp = Transforms.leakyRelu(inArr, true);
                break;
            case RELU:
                out = sd.nn().relu("out", in, 0.0);
                outExp = Transforms.relu(inArr, true);
                break;
            case SIGMOID:
                out = sd.nn().sigmoid("out", in);
                outExp = Transforms.sigmoid(inArr, true);
                break;
            case SOFTPLUS:
                out = sd.nn().softplus("out", in);
                outExp = Transforms.softPlus(inArr, true);
                break;
            case SOFTSIGN:
                out = sd.nn().softsign("out", in);
                outExp = Transforms.softsign(inArr, true);
                break;
            case TANH:
                out = sd.math().tanh("out", in);
                outExp = Transforms.tanh(inArr, true);
                break;
            case CUBE:
                out = sd.math().cube("out", in);
                outExp = Transforms.pow(inArr, 3, true);
                break;
            default:
                throw new RuntimeException(a.toString());
        }

        //Sum squared error loss:
        SDVariable label = sd.var("label", labelArr.dup());
        SDVariable diff = label.sub("diff", out);
        SDVariable sqDiff = diff.mul("sqDiff", diff);
        SDVariable totSum = sd.sum("totSum", sqDiff, Integer.MAX_VALUE);    //Loss function...

        Map<String,INDArray> m = sd.output(Collections.emptyMap(), "out");
        INDArray outAct = m.get("out");
        assertEquals(a.toString(), outExp, outAct);

        // L = sum_i (label - out)^2
        //dL/dOut = 2(out - label)
        INDArray dLdOutExp = outExp.sub(labelArr).mul(2);
        INDArray dLdInExp = a.getActivationFunction().backprop(inArr.dup(), dLdOutExp.dup()).getFirst();

        Map<String,INDArray> grads = sd.calculateGradients(null, "out", "in");
        // sd.execBackwards(Collections.emptyMap());
        // SameDiff gradFn = sd.getFunction("grad");
        INDArray dLdOutAct = grads.get("out");
        INDArray dLdInAct = grads.get("in");

        assertEquals(a.toString(), dLdOutExp, dLdOutAct);
        assertEquals(a.toString(), dLdInExp, dLdInAct);
    }
}
Example 4
Source File: RnnOpValidation.java From deeplearning4j with Apache License 2.0
@Test
public void testRnnBlockCell(){
    Nd4j.getRandom().setSeed(12345);
    int mb = 2;
    int nIn = 3;
    int nOut = 4;

    SameDiff sd = SameDiff.create();
    SDVariable x = sd.constant(Nd4j.rand(DataType.FLOAT, mb, nIn));
    SDVariable cLast = sd.constant(Nd4j.rand(DataType.FLOAT, mb, nOut));
    SDVariable yLast = sd.constant(Nd4j.rand(DataType.FLOAT, mb, nOut));
    SDVariable W = sd.constant(Nd4j.rand(DataType.FLOAT, (nIn+nOut), 4*nOut));
    SDVariable Wci = sd.constant(Nd4j.rand(DataType.FLOAT, nOut));
    SDVariable Wcf = sd.constant(Nd4j.rand(DataType.FLOAT, nOut));
    SDVariable Wco = sd.constant(Nd4j.rand(DataType.FLOAT, nOut));
    SDVariable b = sd.constant(Nd4j.rand(DataType.FLOAT, 4*nOut));

    double fb = 1.0;
    LSTMConfiguration conf = LSTMConfiguration.builder()
            .peepHole(true)
            .forgetBias(fb)
            .clippingCellValue(0.0)
            .build();

    LSTMWeights weights = LSTMWeights.builder().weights(W).bias(b)
            .inputPeepholeWeights(Wci).forgetPeepholeWeights(Wcf).outputPeepholeWeights(Wco).build();

    LSTMCellOutputs v = new LSTMCellOutputs(sd.rnn().lstmCell(x, cLast, yLast, weights, conf));  //Output order: i, c, f, o, z, h, y
    List<String> toExec = new ArrayList<>();
    for(SDVariable sdv : v.getAllOutputs()){
        toExec.add(sdv.name());
    }

    //Test forward pass:
    Map<String,INDArray> m = sd.output(null, toExec);

    //Weights and bias order: [i, f, z, o]

    //Block input (z) - post tanh:
    INDArray wz_x = W.getArr().get(NDArrayIndex.interval(0,nIn), NDArrayIndex.interval(nOut, 2*nOut));          //Input weights
    INDArray wz_r = W.getArr().get(NDArrayIndex.interval(nIn,nIn+nOut), NDArrayIndex.interval(nOut, 2*nOut));   //Recurrent weights
    INDArray bz = b.getArr().get(NDArrayIndex.interval(nOut, 2*nOut));

    INDArray zExp = x.getArr().mmul(wz_x).addiRowVector(bz);   //[mb,nIn]*[nIn, nOut] + [nOut]
    zExp.addi(yLast.getArr().mmul(wz_r));                      //[mb,nOut]*[nOut,nOut]
    Transforms.tanh(zExp, false);

    INDArray zAct = m.get(toExec.get(4));
    assertEquals(zExp, zAct);

    //Input modulation gate (post sigmoid) - i: (note: peephole input - last time step)
    INDArray wi_x = W.getArr().get(NDArrayIndex.interval(0,nIn), NDArrayIndex.interval(0, nOut));           //Input weights
    INDArray wi_r = W.getArr().get(NDArrayIndex.interval(nIn,nIn+nOut), NDArrayIndex.interval(0, nOut));    //Recurrent weights
    INDArray bi = b.getArr().get(NDArrayIndex.interval(0, nOut));

    INDArray iExp = x.getArr().mmul(wi_x).addiRowVector(bi);   //[mb,nIn]*[nIn, nOut] + [nOut]
    iExp.addi(yLast.getArr().mmul(wi_r));                      //[mb,nOut]*[nOut,nOut]
    iExp.addi(cLast.getArr().mulRowVector(Wci.getArr()));      //Peephole
    Transforms.sigmoid(iExp, false);
    assertEquals(iExp, m.get(toExec.get(0)));

    //Forget gate (post sigmoid): (note: peephole input - last time step)
    INDArray wf_x = W.getArr().get(NDArrayIndex.interval(0,nIn), NDArrayIndex.interval(2*nOut, 3*nOut));            //Input weights
    INDArray wf_r = W.getArr().get(NDArrayIndex.interval(nIn,nIn+nOut), NDArrayIndex.interval(2*nOut, 3*nOut));     //Recurrent weights
    INDArray bf = b.getArr().get(NDArrayIndex.interval(2*nOut, 3*nOut));

    INDArray fExp = x.getArr().mmul(wf_x).addiRowVector(bf);   //[mb,nIn]*[nIn, nOut] + [nOut]
    fExp.addi(yLast.getArr().mmul(wf_r));                      //[mb,nOut]*[nOut,nOut]
    fExp.addi(cLast.getArr().mulRowVector(Wcf.getArr()));      //Peephole
    fExp.addi(fb);
    Transforms.sigmoid(fExp, false);
    assertEquals(fExp, m.get(toExec.get(2)));

    //Cell state (pre tanh): tanh(z) .* sigmoid(i) + sigmoid(f) .* cLast
    INDArray cExp = zExp.mul(iExp).add(fExp.mul(cLast.getArr()));
    INDArray cAct = m.get(toExec.get(1));
    assertEquals(cExp, cAct);

    //Output gate (post sigmoid): (note: peephole input: current time step)
    INDArray wo_x = W.getArr().get(NDArrayIndex.interval(0,nIn), NDArrayIndex.interval(3*nOut, 4*nOut));            //Input weights
    INDArray wo_r = W.getArr().get(NDArrayIndex.interval(nIn,nIn+nOut), NDArrayIndex.interval(3*nOut, 4*nOut));     //Recurrent weights
    INDArray bo = b.getArr().get(NDArrayIndex.interval(3*nOut, 4*nOut));

    INDArray oExp = x.getArr().mmul(wo_x).addiRowVector(bo);   //[mb,nIn]*[nIn, nOut] + [nOut]
    oExp.addi(yLast.getArr().mmul(wo_r));                      //[mb,nOut]*[nOut,nOut]
    oExp.addi(cExp.mulRowVector(Wco.getArr()));                //Peephole
    Transforms.sigmoid(oExp, false);
    assertEquals(oExp, m.get(toExec.get(3)));

    //Cell state, post tanh
    INDArray hExp = Transforms.tanh(cExp, true);
    assertEquals(hExp, m.get(toExec.get(5)));

    //Final output
    INDArray yExp = hExp.mul(oExp);
    assertEquals(yExp, m.get(toExec.get(6)));
}
Example 5
Source File: RnnOpValidation.java From deeplearning4j with Apache License 2.0
@Test
public void testGRUCell(){
    Nd4j.getRandom().setSeed(12345);
    int mb = 2;
    int nIn = 3;
    int nOut = 4;

    SameDiff sd = SameDiff.create();
    SDVariable x = sd.constant(Nd4j.rand(DataType.FLOAT, mb, nIn));
    SDVariable hLast = sd.constant(Nd4j.rand(DataType.FLOAT, mb, nOut));
    SDVariable Wru = sd.constant(Nd4j.rand(DataType.FLOAT, (nIn+nOut), 2*nOut));
    SDVariable Wc = sd.constant(Nd4j.rand(DataType.FLOAT, (nIn+nOut), nOut));
    SDVariable bru = sd.constant(Nd4j.rand(DataType.FLOAT, 2*nOut));
    SDVariable bc = sd.constant(Nd4j.rand(DataType.FLOAT, nOut));

    double fb = 1.0;
    GRUWeights weights = GRUWeights.builder()
            .ruWeight(Wru)
            .cWeight(Wc)
            .ruBias(bru)
            .cBias(bc)
            .build();

    SDVariable[] v = sd.rnn().gruCell(x, hLast, weights);
    List<String> toExec = new ArrayList<>();
    for(SDVariable sdv : v){
        toExec.add(sdv.name());
    }

    //Test forward pass:
    Map<String,INDArray> m = sd.output(null, toExec);

    //Weights and bias order: [r, u], [c]

    //Reset gate:
    INDArray wr_x = Wru.getArr().get(NDArrayIndex.interval(0,nIn), NDArrayIndex.interval(0, nOut));         //Input weights
    INDArray wr_r = Wru.getArr().get(NDArrayIndex.interval(nIn,nIn+nOut), NDArrayIndex.interval(0, nOut));  //Recurrent weights
    INDArray br = bru.getArr().get(NDArrayIndex.interval(0, nOut));

    INDArray rExp = x.getArr().mmul(wr_x).addiRowVector(br);   //[mb,nIn]*[nIn, nOut] + [nOut]
    rExp.addi(hLast.getArr().mmul(wr_r));                      //[mb,nOut]*[nOut,nOut]
    Transforms.sigmoid(rExp,false);

    INDArray rAct = m.get(toExec.get(0));
    assertEquals(rExp, rAct);

    //Update gate:
    INDArray wu_x = Wru.getArr().get(NDArrayIndex.interval(0,nIn), NDArrayIndex.interval(nOut, 2*nOut));            //Input weights
    INDArray wu_r = Wru.getArr().get(NDArrayIndex.interval(nIn,nIn+nOut), NDArrayIndex.interval(nOut, 2*nOut));     //Recurrent weights
    INDArray bu = bru.getArr().get(NDArrayIndex.interval(nOut, 2*nOut));

    INDArray uExp = x.getArr().mmul(wu_x).addiRowVector(bu);   //[mb,nIn]*[nIn, nOut] + [nOut]
    uExp.addi(hLast.getArr().mmul(wu_r));                      //[mb,nOut]*[nOut,nOut]
    Transforms.sigmoid(uExp,false);

    INDArray uAct = m.get(toExec.get(1));
    assertEquals(uExp, uAct);

    //c = tanh(x * Wcx + Wcr * (hLast .* r))
    INDArray Wcx = Wc.getArr().get(NDArrayIndex.interval(0,nIn), NDArrayIndex.all());
    INDArray Wcr = Wc.getArr().get(NDArrayIndex.interval(nIn, nIn+nOut), NDArrayIndex.all());
    INDArray cExp = x.getArr().mmul(Wcx);
    cExp.addi(hLast.getArr().mul(rExp).mmul(Wcr));
    cExp.addiRowVector(bc.getArr());
    Transforms.tanh(cExp, false);

    assertEquals(cExp, m.get(toExec.get(2)));

    //h = u * hLast + (1-u) * c
    INDArray hExp = uExp.mul(hLast.getArr()).add(uExp.rsub(1.0).mul(cExp));
    assertEquals(hExp, m.get(toExec.get(3)));
}
Example 6
Source File: TestSimpleRnn.java From deeplearning4j with Apache License 2.0
@Test
public void testSimpleRnn(){
    Nd4j.getRandom().setSeed(12345);

    int m = 3;
    int nIn = 5;
    int layerSize = 6;
    int tsLength = 7;
    INDArray in;
    if (rnnDataFormat == RNNFormat.NCW){
        in = Nd4j.rand(DataType.FLOAT, m, nIn, tsLength);
    }
    else{
        in = Nd4j.rand(DataType.FLOAT, m, tsLength, nIn);
    }

    MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
            .updater(new NoOp())
            .weightInit(WeightInit.XAVIER)
            .activation(Activation.TANH)
            .list()
            .layer(new SimpleRnn.Builder().nIn(nIn).nOut(layerSize).dataFormat(rnnDataFormat).build())
            .build();

    MultiLayerNetwork net = new MultiLayerNetwork(conf);
    net.init();

    INDArray out = net.output(in);

    INDArray w = net.getParam("0_W");
    INDArray rw = net.getParam("0_RW");
    INDArray b = net.getParam("0_b");

    INDArray outLast = null;
    for( int i=0; i<tsLength; i++ ){
        INDArray inCurrent;
        if (rnnDataFormat == RNNFormat.NCW){
            inCurrent = in.get(all(), all(), point(i));
        }
        else{
            inCurrent = in.get(all(), point(i), all());
        }

        INDArray outExpCurrent = inCurrent.mmul(w);
        if(outLast != null){
            outExpCurrent.addi(outLast.mmul(rw));
        }

        outExpCurrent.addiRowVector(b);

        Transforms.tanh(outExpCurrent, false);

        INDArray outActCurrent;
        if (rnnDataFormat == RNNFormat.NCW){
            outActCurrent = out.get(all(), all(), point(i));
        }
        else{
            outActCurrent = out.get(all(), point(i), all());
        }

        assertEquals(String.valueOf(i), outExpCurrent, outActCurrent);

        outLast = outExpCurrent;
    }

    TestUtils.testModelSerialization(net);
}