Java Code Examples for org.nd4j.linalg.api.shape.Shape#rankFromShape()
The following examples show how to use
org.nd4j.linalg.api.shape.Shape#rankFromShape() .
You can vote up the examples you find useful or vote down those you don't.
Follow the links above each example to visit the original project or source file, and check the sidebar for related API usage.
Example 1
Source File: ClipByNorm.java From nd4j with Apache License 2.0 | 6 votes |
@Override public List<SDVariable> doDiff(List<SDVariable> grad) { //dOut/dIn is ??? if clipped, 1 otherwise int origRank = Shape.rankFromShape(arg().getShape()); SDVariable l2norm = f().norm2(arg(), dimensions); SDVariable broadcastableNorm = f().reductionBroadcastableWithOrigShape(origRank, dimensions, l2norm); SDVariable isClippedBC = f().gte(broadcastableNorm, clipValue); SDVariable notClippedBC = isClippedBC.rsub(1.0); // SDVariable dnormdx = arg().div(broadcastableNorm); // SDVariable sqNorm = f().square(broadcastableNorm); // SDVariable dOutdInClipped = sqNorm.rdiv(-1).mul(dnormdx).mul(arg()) //-1/(norm2(x))^2 * x/norm2(x) // .add(broadcastableNorm.rdiv(1.0)) // .mul(clipValue); SDVariable dOutdInClipped = f().neg(f().square(arg()).div(f().cube(broadcastableNorm))) //-x^2/(norm2(x))^3 .add(broadcastableNorm.rdiv(1.0)) //+ 1/norm(x) .mul(clipValue).mul(isClippedBC); SDVariable ret = notClippedBC.add(dOutdInClipped).mul(grad.get(0)); return Arrays.asList(ret); }
Example 2
Source File: EuclideanDistance.java From nd4j with Apache License 2.0 | 6 votes |
@Override public List<SDVariable> doDiff(List<SDVariable> i_v1) { //ddist(x,y)/dxi = (xi-yi)/dist(x,y) SDVariable euc = outputVariables()[0]; SDVariable difference = larg().sub(rarg()); SDVariable divBroadcastable; int origRank = Shape.rankFromShape(arg().getShape()); //TODO shape may not always be defined? if(!(dimensions.length == 1 && dimensions[0] == Integer.MAX_VALUE) ){ //1x1 output case divBroadcastable = i_v1.get(0).div(euc); } else { divBroadcastable = f().reductionBroadcastableWithOrigShape(origRank, dimensions, i_v1.get(0).div(euc)); } SDVariable gradX = difference.mul(divBroadcastable); SDVariable gradY = f().neg(gradX); return Arrays.asList(gradX, gradY); }
Example 3
Source File: JaccardDistance.java From nd4j with Apache License 2.0 | 6 votes |
@Override public List<SDVariable> doDiff(List<SDVariable> f1) { //Jaccard distance: https://en.wikipedia.org/wiki/Jaccard_index#Generalized_Jaccard_similarity_and_distance //J(x,y) = 1 - sum_i min(x_i, y_i) / sum_i max(x_i, y_i) int rank = Shape.rankFromShape(larg().getShape()); SDVariable jSim = outputVariables()[0].rsub(1.0); //jaccard similarity = 1 - jaccard distance SDVariable min = f().min(larg(), rarg()); SDVariable max = f().max(larg(), rarg()); SDVariable sumMax = f().sum(max, dimensions); SDVariable broadcastableSumMax = f().reductionBroadcastableWithOrigShape(rank, dimensions, sumMax); SDVariable broadcastableJSim = f().reductionBroadcastableWithOrigShape(rank, dimensions, jSim); SDVariable xIsMin = f().eq(min, larg()); SDVariable xIsMax = f().eq(max, larg()); SDVariable yIsMin = f().eq(min, rarg()); SDVariable yIsMax = f().eq(max, rarg()); SDVariable dldx = xIsMax.mul(broadcastableJSim).sub(xIsMin).div(broadcastableSumMax); SDVariable dldy = yIsMax.mul(broadcastableJSim).sub(yIsMin).div(broadcastableSumMax); return Arrays.asList(dldx.mul(f1.get(0)), dldy.mul(f1.get(0))); }
Example 4
Source File: ManhattanDistance.java From nd4j with Apache License 2.0 | 6 votes |
@Override public List<SDVariable> doDiff(List<SDVariable> i_v1) { //ddist(x,y)/dxi = sign(xi-yi) SDVariable difference = larg().sub(rarg()); SDVariable gradBroadcastable; int origRank = Shape.rankFromShape(arg().getShape()); //TODO shape may not always be defined? if(!(dimensions.length == 1 && dimensions[0] == Integer.MAX_VALUE) ){ //1x1 output case gradBroadcastable = i_v1.get(0); } else { gradBroadcastable = f().reductionBroadcastableWithOrigShape(origRank, dimensions, i_v1.get(0)); } SDVariable gradX = sameDiff.sign(difference).mul(gradBroadcastable); SDVariable gradY = f().neg(gradX); return Arrays.asList(gradX, gradY); }
Example 5
Source File: CosineSimilarity.java From nd4j with Apache License 2.0 | 6 votes |
/**
 * Backprop for cosine similarity c = sum(x*y) / (norm2(x) * norm2(y)), reduced over
 * the given dimensions. Returns gradients for both inputs:
 *   dc/dx = (y - x * a / norm2(x)^2) / (norm2(x) * norm2(y)), and symmetrically for y,
 * where a = sum(x*y).
 *
 * @param sameDiff   owning SameDiff instance
 * @param f          function factory used to build the backprop graph
 * @param x          first input
 * @param y          second input
 * @param gradOut    gradient of loss w.r.t. the similarity output
 * @param dimensions reduction dimensions
 * @return [dL/dx, dL/dy]
 */
public static List<SDVariable> doDiff(SameDiff sameDiff, DifferentialFunctionFactory f, SDVariable x,
        SDVariable y, SDVariable gradOut, int... dimensions) {
    SDVariable dotProd = sameDiff.sum(x.mul(y), dimensions);
    SDVariable normX = f.norm2(x, dimensions);
    SDVariable normY = f.norm2(y, dimensions);
    SDVariable normProd = normX.mul(normY);

    int origRank = Shape.rankFromShape(x.getShape());

    // Broadcast every reduced quantity back to the input shape
    SDVariable dotProdBc = f.reductionBroadcastableWithOrigShape(origRank, dimensions, dotProd);
    SDVariable normProdBc = f.reductionBroadcastableWithOrigShape(origRank, dimensions, normProd);
    SDVariable normXSqBc = f.reductionBroadcastableWithOrigShape(origRank, dimensions, sameDiff.square(normX));
    SDVariable normYSqBc = f.reductionBroadcastableWithOrigShape(origRank, dimensions, sameDiff.square(normY));
    SDVariable gradBc = f.reductionBroadcastableWithOrigShape(origRank, dimensions, gradOut);

    SDVariable dcdx = y.sub(x.mul(dotProdBc).div(normXSqBc)).div(normProdBc);
    SDVariable dcdy = x.sub(y.mul(dotProdBc).div(normYSqBc)).div(normProdBc);

    return Arrays.asList(dcdx.mul(gradBc), dcdy.mul(gradBc));
}
Example 6
Source File: Min.java From nd4j with Apache License 2.0 | 6 votes |
@Override public List<SDVariable> doDiff(List<SDVariable> i_v1) { //TODO do we need to handle the "multiple equal minimums" case? //TODO code duplication (min/max) SDVariable out = outputVariables()[0]; int origRank = Shape.rankFromShape(arg().getShape()); SDVariable expandedOut = sameDiff.f().reductionBroadcastableWithOrigShape(origRank, dimensions, out); expandedOut = sameDiff.onesLike("temp0", arg()).mul("tempmul", expandedOut); SDVariable expandedGrad = sameDiff.f().reductionBroadcastableWithOrigShape(origRank, dimensions, i_v1.get(0)); SDVariable eq = sameDiff.eq(arg(), expandedOut); SDVariable ret = eq.mul(expandedGrad); return Arrays.asList(ret); }
Example 7
Source File: StandardDeviation.java From nd4j with Apache License 2.0 | 6 votes |
@Override public List<SDVariable> doDiff(List<SDVariable> i_v1) { //Here: calculating dL/dIn given dL/dOut (i.e., i_v1) and input/output //If out = stdev(in) then: //dL/dIn = dL/dOut * dOut/dIn //dOut/dIn_i = (in_i-mean)/(stdev * (n-1)) int origRank = Shape.rankFromShape(arg().getShape()); long n = f().getReductionLength(this); SDVariable broadcastableStdevOut = f().reductionBroadcastableWithOrigShape(origRank, dimensions, outputVariables()[0]); SDVariable broadcastableMean = f().reductionBroadcastableWithOrigShape(origRank, dimensions, f().mean(arg(), dimensions)); SDVariable diff = arg().sub(broadcastableMean); SDVariable dOutdIn = diff.div(broadcastableStdevOut); if (this.biasCorrected) { dOutdIn = dOutdIn.div(n - 1); } else { dOutdIn = dOutdIn.div(n); } SDVariable broadcastableGrad = f().reductionBroadcastableWithOrigShape(origRank, dimensions, i_v1.get(0)); SDVariable dLdIn = dOutdIn.mul(broadcastableGrad); return Arrays.asList(dLdIn); }
Example 8
Source File: NormMax.java From nd4j with Apache License 2.0 | 6 votes |
@Override public List<SDVariable> doDiff(List<SDVariable> i_v1) { //maxnorm(in) = max_i |x_i| //d maxnorm(in)/dx = 0 if x_i is not the max, or d|x|/dx otherwise SDVariable absIn = sameDiff.abs(arg()); SDVariable maxnorm = outputVariables()[0]; int origRank = Shape.rankFromShape(arg().getShape()); //TODO shape may not always be defined? SDVariable maxnormBc = f().reductionBroadcastableWithOrigShape(origRank, dimensions, maxnorm); maxnormBc = sameDiff.onesLike(arg()).mul(maxnormBc); SDVariable eq = sameDiff.eq(absIn, maxnormBc); SDVariable dAbsXdX = sameDiff.sign(arg()); SDVariable dNormmaxDx = eq.mul(dAbsXdX); SDVariable broadcastableGrad = f().reductionBroadcastableWithOrigShape(origRank, dimensions, i_v1.get(0)); SDVariable ret = dNormmaxDx.mul(broadcastableGrad); return Arrays.asList(ret); }
Example 9
Source File: Norm2.java From nd4j with Apache License 2.0 | 5 votes |
@Override public List<SDVariable> doDiff(List<SDVariable> i_v1) { //d norm2(in)/dx = x / norm2(in) SDVariable norm2 = outputVariables()[0]; int origRank = Shape.rankFromShape(arg().getShape()); //TODO shape may not always be defined? SDVariable broadcastableNorm2 = f().reductionBroadcastableWithOrigShape(origRank, dimensions, norm2); SDVariable broadcastableGradOut = f().reductionBroadcastableWithOrigShape(origRank, dimensions, i_v1.get(0)); SDVariable ret = arg().div(broadcastableNorm2).mul(broadcastableGradOut); return Arrays.asList(ret); }
Example 10
Source File: Variance.java From nd4j with Apache License 2.0 | 5 votes |
@Override public List<SDVariable> doDiff(List<SDVariable> i_v1) { //If out = var(in) then: //dL/dIn = dL/dOut * dOut/dIn // with dOut/dIn = (in-mean) * 2/(n-1) val n = f().getReductionLength(this); int origRank = Shape.rankFromShape(arg().getShape()); SDVariable broadcastableMean = f().reductionBroadcastableWithOrigShape(origRank, dimensions, f().mean(arg(), dimensions)); SDVariable broadcastableGrad = f().reductionBroadcastableWithOrigShape(origRank, dimensions, i_v1.get(0)); SDVariable dOutdIn = arg().sub(broadcastableMean).mul(2.0 / (biasCorrected ? (n - 1) : n)); SDVariable dLdIn = dOutdIn.mul(broadcastableGrad); return Arrays.asList(dLdIn); }
Example 11
Source File: Mean.java From nd4j with Apache License 2.0 | 5 votes |
@Override public List<SDVariable> doDiff(List<SDVariable> i_v1) { //If out = mean(in), then dL/dIn = 1/N * dL/dOut (broadcast to appropriate shape) //Note that N differs for "along dimension" vs. "whole array" reduce cases long n = f().getReductionLength(this); int rank = Shape.rankFromShape(arg().getShape()); SDVariable broadcastableGrad = f().reductionBroadcastableWithOrigShape(rank, dimensions, i_v1.get(0)); SDVariable ret = sameDiff.onesLike(arg()).div(n); //1/N with shape equal to input ret = ret.mul(broadcastableGrad); return Arrays.asList(ret); }
Example 12
Source File: Sum.java From nd4j with Apache License 2.0 | 5 votes |
@Override public List<SDVariable> doDiff(List<SDVariable> i_v1) { //Out = sum(in) // dL/dIn = dL/dOut * dOut/dIn // = dL/dOut * 1 // But broadcast to shape of the input int origRank = Shape.rankFromShape(arg().getShape()); //TODO shape may not always be defined? SDVariable broadcastable = sameDiff.f().reductionBroadcastableWithOrigShape(origRank, dimensions, i_v1.get(0)); SDVariable ret = sameDiff.onesLike(arg()).mul(broadcastable); return Arrays.asList(ret); }
Example 13
Source File: Norm1.java From nd4j with Apache License 2.0 | 5 votes |
@Override public List<SDVariable> doDiff(List<SDVariable> i_v1) { //d l1Norm(in)/dx = signum(x) SDVariable signum = sameDiff.sign(arg()); //Note that we need to expand the dimensions of the gradient - auto-broadcast won't work for all cases. int origRank = Shape.rankFromShape(arg().getShape()); //TODO shape may not always be defined? SDVariable bcGrad = sameDiff.f().reductionBroadcastableWithOrigShape(origRank, dimensions, i_v1.get(0)); return Arrays.asList(signum.mul(bcGrad)); }
Example 14
Source File: Prod.java From nd4j with Apache License 2.0 | 5 votes |
@Override public List<SDVariable> doDiff(List<SDVariable> i_v1) { SDVariable prod = outputVariables()[0]; int origRank = Shape.rankFromShape(arg().getShape()); //TODO shape may not always be defined? SDVariable broadcastableGrad = sameDiff.f().reductionBroadcastableWithOrigShape(origRank, dimensions, i_v1.get(0)); SDVariable broadcastableProd = sameDiff.f().reductionBroadcastableWithOrigShape(origRank, dimensions, prod); SDVariable mul = broadcastableGrad.div(arg()); SDVariable ret = broadcastableProd.mul(mul); return Arrays.asList(ret); }
Example 15
Source File: Max.java From nd4j with Apache License 2.0 | 5 votes |
@Override public List<SDVariable> doDiff(List<SDVariable> i_v1) { //TODO do we need to handle the "multiple equal maximums" case? //TODO code duplication (min/max) SDVariable out = outputVariables()[0]; int origRank = Shape.rankFromShape(arg().getShape()); SDVariable expandedOut = sameDiff.f().reductionBroadcastableWithOrigShape(origRank, dimensions, out); expandedOut = sameDiff.onesLike(arg()).mul(expandedOut); SDVariable expandedGrad = sameDiff.f().reductionBroadcastableWithOrigShape(origRank, dimensions, i_v1.get(0)); SDVariable eq = sameDiff.eq(arg(), expandedOut); SDVariable ret = eq.mul(expandedGrad); return Arrays.asList(ret); }