Java Code Examples for org.nd4j.linalg.api.ndarray.INDArray#assign()
The following examples show how to use
org.nd4j.linalg.api.ndarray.INDArray#assign().
You can vote up the ones you like or vote down the ones you don't like,
and go to the original project or source file by following the links above each example. You may check out the related API usage on the sidebar.
Example 1
Source File: NadamUpdater.java From deeplearning4j with Apache License 2.0 | 6 votes |
@Override public void setStateViewArray(INDArray viewArray, long[] gradientShape, char gradientOrder, boolean initialize) { if (!viewArray.isRowVector()) throw new IllegalArgumentException("Invalid input: expect row vector input"); if (initialize) viewArray.assign(0); long length = viewArray.length(); this.m = viewArray.get(NDArrayIndex.point(0), NDArrayIndex.interval(0, length / 2)); this.v = viewArray.get(NDArrayIndex.point(0), NDArrayIndex.interval(length / 2, length)); //Reshape to match the expected shape of the input gradient arrays this.m = Shape.newShapeNoCopy(this.m, gradientShape, gradientOrder == 'f'); this.v = Shape.newShapeNoCopy(this.v, gradientShape, gradientOrder == 'f'); if (m == null || v == null) throw new IllegalStateException("Could not correctly reshape gradient view arrays"); this.gradientReshapeOrder = gradientOrder; }
Example 2
Source File: SlicingTestsC.java From deeplearning4j with Apache License 2.0 | 6 votes |
/**
 * Checks that getRow and point/all indexing select the same row, that assigning into the indexed
 * row reads back correctly, and that SpecifiedIndex selects the requested rows.
 */
@Test
public void testGetRow() {
    INDArray matrix = Nd4j.linspace(1, 6, 6, DataType.DOUBLE).reshape(2, 3);
    INDArray rowFromGetRow = matrix.getRow(1);
    INDArray rowFromIndex = matrix.get(NDArrayIndex.point(1), NDArrayIndex.all());
    INDArray expectedRow = Nd4j.create(new double[] {4, 5, 6});

    assertEquals(expectedRow, rowFromGetRow);
    assertEquals(rowFromGetRow, rowFromIndex);

    // After assigning [1, 2, 3] into the indexed row, it must read back as [1, 2, 3]
    rowFromIndex.assign(Nd4j.linspace(1, 3, 3, DataType.DOUBLE));
    assertEquals(Nd4j.linspace(1, 3, 3, DataType.DOUBLE), rowFromIndex);

    // Rows 1 and 2 of a 3x3 matrix selected via SpecifiedIndex
    INDArray square = Nd4j.linspace(1, 9, 9, DataType.DOUBLE).reshape(3, 3);
    INDArray selectedRows = square.get(new SpecifiedIndex(1, 2), NDArrayIndex.all());
    INDArray expectedRows = Nd4j.create(new double[][] {{4, 5, 6}, {7, 8, 9}});
    assertEquals(expectedRows, selectedRows);
}
Example 3
Source File: NadamUpdater.java From nd4j with Apache License 2.0 | 6 votes |
@Override public void setStateViewArray(INDArray viewArray, long[] gradientShape, char gradientOrder, boolean initialize) { if (!viewArray.isRowVector()) throw new IllegalArgumentException("Invalid input: expect row vector input"); if (initialize) viewArray.assign(0); long length = viewArray.length(); this.m = viewArray.get(NDArrayIndex.point(0), NDArrayIndex.interval(0, length / 2)); this.v = viewArray.get(NDArrayIndex.point(0), NDArrayIndex.interval(length / 2, length)); //Reshape to match the expected shape of the input gradient arrays this.m = Shape.newShapeNoCopy(this.m, gradientShape, gradientOrder == 'f'); this.v = Shape.newShapeNoCopy(this.v, gradientShape, gradientOrder == 'f'); if (m == null || v == null) throw new IllegalStateException("Could not correctly reshape gradient view arrays"); this.gradientReshapeOrder = gradientOrder; }
Example 4
Source File: EndlessWorkspaceTests.java From deeplearning4j with Apache License 2.0 | 6 votes |
/**
 * Soak test that never terminates: with periodic GC enabled, it repeatedly allocates a
 * 2M-element array, fills it with 1.0 and checks the mean. It logs and requests a GC every
 * 1000 iterations.
 */
@Test
public void endlessValidation1() {
    Nd4j.getMemoryManager().togglePeriodicGc(true);

    AtomicLong iterations = new AtomicLong(0);
    for (;;) {
        INDArray filled = Nd4j.create(2 * 1024 * 1024);
        filled.assign(1.0);
        assertEquals(1.0f, filled.meanNumber().floatValue(), 0.01);

        boolean reportNow = iterations.incrementAndGet() % 1000 == 0;
        if (reportNow) {
            log.info("{} iterations passed...", iterations.get());
            System.gc();
        }
    }
}
Example 5
Source File: ActivationPReLU.java From deeplearning4j with Apache License 2.0 | 6 votes |
/**
 * Backpropagates through PReLU with the native "prelu_bp" op.
 *
 * <p>The op's first output is copied into {@code in}, so {@code in} is overwritten. That array is
 * returned together with the op's second output (the gradient for {@code alpha}).
 *
 * @param in      activation input; overwritten with the op's first output
 * @param epsilon incoming gradient, validated against {@code in} via assertShape
 * @return pair of (overwritten {@code in}, alpha gradient)
 */
@Override
public Pair<INDArray, INDArray> backprop(INDArray in, INDArray epsilon) {
    assertShape(in, epsilon);

    INDArray alphaGradient = alpha.ulike();
    INDArray inputGradient = in.ulike();

    DynamicCustomOp.DynamicCustomOpsBuilder builder = DynamicCustomOp.builder("prelu_bp")
            .addInputs(in, alpha, epsilon)
            .addOutputs(inputGradient, alphaGradient);

    // Shared axes, when configured, are forwarded to the op as integer arguments
    if (sharedAxes != null) {
        for (long sharedAxis : sharedAxes) {
            builder.addIntegerArguments(sharedAxis);
        }
    }

    Nd4j.exec(builder.build());
    in.assign(inputGradient);
    return new Pair<>(in, alphaGradient);
}
Example 6
Source File: BasicWorkspaceTests.java From nd4j with Apache License 2.0 | 6 votes |
@Test public void testMmap2() throws Exception { // we don't support MMAP on cuda yet if (Nd4j.getExecutioner().getClass().getName().toLowerCase().contains("cuda")) return; File tmp = File.createTempFile("tmp", "fdsfdf"); tmp.deleteOnExit(); Nd4jWorkspace.fillFile(tmp, 100000); WorkspaceConfiguration mmap = WorkspaceConfiguration.builder() .policyLocation(LocationPolicy.MMAP) .tempFilePath(tmp.getAbsolutePath()) .build(); MemoryWorkspace ws = Nd4j.getWorkspaceManager().getAndActivateWorkspace(mmap, "M3"); INDArray mArray = Nd4j.create(100); mArray.assign(10f); assertEquals(1000f, mArray.sumNumber().floatValue(), 1e-5); ws.notifyScopeLeft(); }
Example 7
Source File: BasicWorkspaceTests.java From nd4j with Apache License 2.0 | 6 votes |
/**
 * Inside workspace "ITER", creates one attached array and one detached array and copies the
 * former into the latter. The workspace host offset must reflect only the attached allocation,
 * and the two arrays must be equal.
 */
@Test
public void testCreateDetached1() throws Exception {
    try (Nd4jWorkspace workspace = (Nd4jWorkspace) Nd4j.getWorkspaceManager().getAndActivateWorkspace(basicConfig, "ITER")) {
        INDArray attached = Nd4j.create(new float[] {1f, 2f, 3f, 4f, 5f});
        INDArray detached = Nd4j.createUninitializedDetached(5);
        detached.assign(attached);

        // Expected offset uses the original padding formula: bytes + bytes % 8
        long attachedBytes = 5 * Nd4j.sizeOfDataType();
        long expectedOffset = attachedBytes + attachedBytes % 8;
        assertEquals(expectedOffset, workspace.getHostOffset());

        assertEquals(attached, detached);
    }
}
Example 8
Source File: CpuLapack.java From nd4j with Apache License 2.0 | 5 votes |
/**
 * Single-precision QR factorization.
 *
 * Calls LAPACKE_sgeqrf on {@code A}. If {@code R} is non-null, copies the first
 * {@code A.columns()} rows of the factorized {@code A} into {@code R} and zeroes their
 * strictly-lower-triangular entries. Finally calls LAPACKE_sorgqr, which overwrites {@code A}
 * with Q.
 *
 * @param M    number of rows passed to LAPACK
 * @param N    number of columns passed to LAPACK (also the length of the tau buffer)
 * @param A    input matrix; overwritten (by sorgqr) with Q
 * @param R    optional output for R; NOTE(review): assumed to be A.columns() x A.columns() — confirm
 * @param INFO not used by this implementation
 * @throws BlasException if LAPACKE_sgeqrf or LAPACKE_sorgqr returns a non-zero status
 */
@Override
public void sgeqrf(int M, int N, INDArray A, INDArray R, INDArray INFO) {
    // Scalar factors written by sgeqrf and read back by sorgqr.
    // NOTE(review): created with the default data type but handed to LAPACK as a FloatPointer;
    // this presumably assumes the default dtype is FLOAT — confirm.
    INDArray tau = Nd4j.create( N ) ;

    int status = LAPACKE_sgeqrf(getColumnOrder(A), M, N,
            (FloatPointer)A.data().addressPointer(), getLda(A),
            (FloatPointer)tau.data().addressPointer()
    );
    if( status != 0 ) {
        throw new BlasException( "Failed to execute sgeqrf", status ) ;
    }

    // Copy R (the upper triangle of A after sgeqrf) into the result
    if( R != null ) {
        R.assign( A.get( NDArrayIndex.interval( 0, A.columns() ), NDArrayIndex.all() ) ) ;
        INDArrayIndex ix[] = new INDArrayIndex[ 2 ] ;
        // Zero R[i, 0..i-1] for each row i >= 1, leaving only the upper triangle
        for( int i=1 ; i<Math.min( A.rows(), A.columns() ) ; i++ ) {
            ix[0] = NDArrayIndex.point( i ) ;
            ix[1] = NDArrayIndex.interval( 0, i ) ;
            R.put(ix, 0) ;
        }
    }

    // Build Q explicitly in A from the factorization data left in A and tau
    status = LAPACKE_sorgqr( getColumnOrder(A), M, N, N,
            (FloatPointer)A.data().addressPointer(), getLda(A),
            (FloatPointer)tau.data().addressPointer()
    );
    if( status != 0 ) {
        throw new BlasException( "Failed to execute sorgqr", status ) ;
    }
}
Example 9
Source File: LogSoftMax.java From nd4j with Apache License 2.0 | 5 votes |
public LogSoftMax(INDArray x, INDArray y, INDArray z, long n) { super(x, y, z, n); //ensure the result is the same //do a reference check here because it's cheaper if (x != z) z.assign(x); }
Example 10
Source File: OpExecutionerTestsC.java From deeplearning4j with Apache License 2.0 | 5 votes |
/**
 * Runs SoftMax on a 2x3 DOUBLE matrix and checks that the first row of the result sums to 1.
 */
@Test
public void testDimensionSoftMax() {
    INDArray linspace = Nd4j.linspace(1, 6, 6, DataType.DOUBLE).reshape(2, 3);
    val max = new SoftMax(linspace);
    Nd4j.getExecutioner().exec((CustomOp) max);
    linspace.assign(max.outputArguments().get(0));
    // JUnit's signature is assertEquals(message, expected, actual, delta): expected is 1.0.
    // DOUBLE softmax rows sum to 1 far more precisely than the previous 1e-1 tolerance implied.
    assertEquals(getFailureMessage(), 1.0, linspace.getRow(0).sumNumber().doubleValue(), 1e-5);
}
Example 11
Source File: InvertMatrix.java From nd4j with Apache License 2.0 | 5 votes |
/**
 * Compute the right pseudo inverse. Input matrix must have full row rank.
 *
 * See also: <a href="https://en.wikipedia.org/wiki/Moore%E2%80%93Penrose_inverse#Definition">Moore–Penrose inverse</a>
 *
 * @param arr Input matrix
 * @param inPlace Whether to store the result in {@code arr}. This is only possible when the
 *                pseudo inverse has the same shape as {@code arr}, i.e. when {@code arr} is square.
 * @return Right pseudo inverse of {@code arr}
 * @exception IllegalArgumentException Input matrix {@code arr} did not have full row rank, or
 *            {@code inPlace} was requested but the result's shape differs from {@code arr}'s.
 */
public static INDArray pRightInvert(INDArray arr, boolean inPlace) {
    try {
        // A^T (A A^T)^-1. The inner product comes from mmul, so inverting it in place does not touch arr.
        final INDArray inv = arr.transpose().mmul(invert(arr.mmul(arr.transpose()), inPlace));
        if (inPlace) {
            // An m x n input yields an n x m pseudo inverse; refuse to write it into a differently shaped arr
            if (!java.util.Arrays.equals(arr.shape(), inv.shape())) {
                throw new IllegalArgumentException("Cannot store right pseudo inverse of shape "
                        + java.util.Arrays.toString(inv.shape()) + " in place into array of shape "
                        + java.util.Arrays.toString(arr.shape()));
            }
            arr.assign(inv);
        }
        return inv;
    } catch (SingularMatrixException e) {
        throw new IllegalArgumentException(
                "Full row rank condition for right pseudo inverse was not met.", e);
    }
}
Example 12
Source File: EndlessWorkspaceTests.java From deeplearning4j with Apache License 2.0 | 5 votes |
/**
 * Endless test for nested workspaces and cross-workspace memory use.
 *
 * <p>Each iteration fills an array in workspace "WS_1" and another in the nested workspace
 * "WS_2" with 1.0, adds the inner array into the outer one and checks that the mean is 2.0.
 * Periodic GC is disabled; progress is logged and a GC is requested every 1000 iterations.
 */
@Test
public void endlessTest3() {
    Nd4j.getWorkspaceManager().setDefaultWorkspaceConfiguration(
            WorkspaceConfiguration.builder().initialSize(10 * 1024L * 1024L).build());

    Nd4j.getMemoryManager().togglePeriodicGc(false);

    AtomicLong iterations = new AtomicLong(0);
    for (;;) {
        try (MemoryWorkspace outerWs = Nd4j.getWorkspaceManager().getAndActivateWorkspace("WS_1")) {
            INDArray outerArray = Nd4j.create(2 * 1024 * 1024);
            outerArray.assign(1.0);

            try (MemoryWorkspace innerWs = Nd4j.getWorkspaceManager().getAndActivateWorkspace("WS_2")) {
                INDArray innerArray = Nd4j.create(2 * 1024 * 1024);
                innerArray.assign(1.0);

                outerArray.addi(innerArray);
                assertEquals(2.0f, outerArray.meanNumber().floatValue(), 0.01);

                boolean reportNow = iterations.incrementAndGet() % 1000 == 0;
                if (reportNow) {
                    log.info("{} iterations passed...", iterations.get());
                    System.gc();
                }
            }
        }
    }
}
Example 13
Source File: OperationProfilerTests.java From deeplearning4j with Apache License 2.0 | 5 votes |
/**
 * Expects the op profiler to report exactly two invocations after an assign and an in-place divide.
 */
@Test
public void testCounter1() {
    INDArray values = Nd4j.createUninitialized(100);

    values.assign(10f); // first profiled op
    values.divi(2f);    // second profiled op

    assertEquals(2, OpProfiler.getInstance().getInvocationsCount());
}
Example 14
Source File: CustomOpsTests.java From deeplearning4j with Apache License 2.0 | 5 votes |
/**
 * Runs FusedBatchNorm on a 2x2x3x4 DOUBLE input (scale 0.5, offset 2.0) and checks both the
 * shapes and the values of y, the batch mean and the batch variance against reference numbers.
 */
@Test
public void testFusedBatchNorm() {
    INDArray x = Nd4j.linspace(DataType.DOUBLE, 1.0, 1.0, 2*2*3*4).reshape(2,2,3,4);
    INDArray scale = Nd4j.create(DataType.DOUBLE, 4);
    scale.assign(0.5);
    INDArray offset = Nd4j.create(DataType.DOUBLE, 4);
    offset.assign(2.0);

    INDArray y = Nd4j.createUninitialized(DataType.DOUBLE, x.shape());
    INDArray batchMean = Nd4j.create(4);
    INDArray batchVar = Nd4j.create(4);

    FusedBatchNorm op = new FusedBatchNorm(x,scale,offset,0,1,
            y, batchMean, batchVar);

    INDArray expectedY = Nd4j.createFromArray(new double[]{1.20337462, 1.20337462, 1.20337462,
            1.20337462, 1.34821558, 1.34821558, 1.34821558, 1.34821558, 1.49305654, 1.49305654,
            1.49305654, 1.49305654, 1.63789749, 1.63789749, 1.63789749, 1.63789749, 1.78273857,
            1.78273857, 1.78273857, 1.78273857, 1.92757952, 1.92757952, 1.92757952, 1.92757952,
            2.0724206 , 2.0724206 , 2.0724206 , 2.0724206 , 2.21726155, 2.21726155, 2.21726155,
            2.21726155, 2.36210251, 2.36210251, 2.36210251, 2.36210251, 2.50694346, 2.50694346,
            2.50694346, 2.50694346, 2.65178442, 2.65178442, 2.65178442, 2.65178442, 2.79662538,
            2.79662538, 2.79662538, 2.79662538}).reshape(x.shape());
    INDArray expectedBatchMean = Nd4j.createFromArray(new double[]{23., 24., 25., 26.});
    INDArray expectedBatchVar = Nd4j.createFromArray(new double[]{208.00001526, 208.00001526, 208.00001526, 208.00001526});
    Nd4j.exec(op);
    assertArrayEquals(expectedY.shape(), y.shape());
    assertArrayEquals(expectedBatchMean.shape(), batchMean.shape());
    assertArrayEquals(expectedBatchVar.shape(), batchVar.shape());

    // Shapes alone do not validate the op: also compare the reference values element-wise.
    // getDouble reads batchMean/batchVar regardless of their (default) data type.
    for (long i = 0; i < expectedY.length(); i++) {
        assertEquals(expectedY.getDouble(i), y.getDouble(i), 1e-4);
    }
    for (long i = 0; i < expectedBatchMean.length(); i++) {
        assertEquals(expectedBatchMean.getDouble(i), batchMean.getDouble(i), 1e-4);
        // The reference variance carries float32 noise (208.00001526), hence the looser tolerance
        assertEquals(expectedBatchVar.getDouble(i), batchVar.getDouble(i), 1e-3);
    }
}
Example 15
Source File: CpuNDArrayFactory.java From deeplearning4j with Apache License 2.0 | 4 votes |
/**
 * Averages {@code arrays} element-wise using the native {@code average} op.
 *
 * Each input is auto-decompressed, then checked. It must have the first array's data type,
 * element-wise stride 1, and the same length as {@code target}, or as {@code arrays[0]} when
 * {@code target} is null. The native call receives {@code true} as its last argument.
 * NOTE(review): that flag presumably propagates the result back into the inputs; confirm against
 * the native implementation.
 *
 * @param target array receiving the average, or {@code null}
 * @param arrays input arrays; must be non-null and non-empty
 * @return {@code target} (so {@code null} when no target was given); with a single input,
 *         the result of {@code target.assign(arrays[0])}
 * @throws RuntimeException if {@code arrays} is null/empty, or the native op reports an error
 * @throws ND4JIllegalStateException if an input is non-contiguous or has a mismatched length
 */
@Override
public INDArray average(INDArray target, INDArray[] arrays) {
    if (arrays == null || arrays.length == 0)
        throw new RuntimeException("Input arrays are missing");

    if (arrays.length == 1) {
        //Edge case - average 1 array - no op
        // With a target, the single input is simply copied into it
        if(target == null){
            return null;
        }
        return target.assign(arrays[0]);
    }

    long len = target != null ? target.length() : arrays[0].length();

    PointerPointer dataPointers = new PointerPointer(arrays.length);
    val firstType = arrays[0].dataType();

    for (int i = 0; i < arrays.length; i++) {
        Nd4j.getCompressor().autoDecompress(arrays[i]);

        Preconditions.checkArgument(arrays[i].dataType() == firstType, "All arrays must have the same data type");

        if (arrays[i].elementWiseStride() != 1)
            throw new ND4JIllegalStateException("Native averaging is applicable only to continuous INDArrays");

        if (arrays[i].length() != len)
            throw new ND4JIllegalStateException("All arrays should have equal length for averaging");

        // Collect the raw data pointer of each input for the native call
        dataPointers.put(i, arrays[i].data().addressPointer());
    }

    // Shape info is taken from arrays[0]; target pointers are null when no target was supplied
    nativeOps.average(null,
            dataPointers, (LongPointer) arrays[0].shapeInfoDataBuffer().addressPointer(),
            null, null,
            target == null ? null : target.data().addressPointer(),
            target == null ? null : (LongPointer) target.shapeInfoDataBuffer().addressPointer(),
            null, null,
            arrays.length,
            len,
            true);

    if (nativeOps.lastErrorCode() != 0)
        throw new RuntimeException(nativeOps.lastErrorMessage());

    return target;
}
Example 16
Source File: JcublasLapack.java From nd4j with Apache License 2.0 | 4 votes |
/**
 * Single-precision symmetric eigen-decomposition on the GPU via cuSOLVER syevd.
 *
 * If {@code A} is C-ordered, a Fortran-ordered copy is decomposed and copied back into {@code A}
 * on success. Eigenvalues are written to {@code R}.
 *
 * @param _jobz 'V' selects CUSOLVER_EIG_MODE_VECTOR, anything else CUSOLVER_EIG_MODE_NOVECTOR
 * @param _uplo 'L' selects the lower triangle, anything else the upper triangle
 * @param N     not used; the order M is taken from {@code A.rows()}
 * @param A     input matrix (copied back from the Fortran-ordered working copy on success)
 * @param R     output buffer for the eigenvalues
 * @return 0 on success; otherwise the non-zero cuSOLVER status, or the non-zero INFO value
 *         reported by syevd, or -1 only if no cuSOLVER call ran
 */
public int ssyev( char _jobz, char _uplo, int N, INDArray A, INDArray R ) {

    int status = -1 ;

    int jobz = _jobz == 'V' ? CUSOLVER_EIG_MODE_VECTOR : CUSOLVER_EIG_MODE_NOVECTOR ;
    int uplo = _uplo == 'L' ? CUBLAS_FILL_MODE_LOWER : CUBLAS_FILL_MODE_UPPER ;

    // NOTE(review): the warning says DOUBLE but fires for any non-FLOAT default type
    if (Nd4j.dataType() != DataBuffer.Type.FLOAT)
        log.warn("FLOAT ssyev called in DOUBLE environment");

    // cuSOLVER is called on a Fortran-ordered array; C-ordered input is duplicated
    INDArray a = A;

    if (A.ordering() == 'c')
        a = A.dup('f');

    // FIXME: int cast
    int M = (int) A.rows() ;

    // Flush any queued ops before touching device memory directly
    if (Nd4j.getExecutioner() instanceof GridExecutioner)
        ((GridExecutioner) Nd4j.getExecutioner()).flushQueue();

    // Get context for current thread
    CudaContext ctx = (CudaContext) allocator.getDeviceContext().getContext();

    // setup the solver handles for cuSolver calls
    cusolverDnHandle_t handle = ctx.getSolverHandle();
    cusolverDnContext solverDn = new cusolverDnContext(handle);

    // synchronized on the solver
    synchronized (handle) {
        status = cusolverDnSetStream(new cusolverDnContext(handle), new CUstream_st(ctx.getOldStream()));
        if( status == 0 ) {
            // transfer the INDArray into GPU memory
            CublasPointer xAPointer = new CublasPointer(a, ctx);
            CublasPointer xRPointer = new CublasPointer(R, ctx);

            // this output - indicates how much memory we'll need for the real operation
            DataBuffer worksizeBuffer = Nd4j.getDataBufferFactory().createInt(1);
            status = cusolverDnSsyevd_bufferSize (
                solverDn, jobz, uplo, M,
                (FloatPointer) xAPointer.getDevicePointer(), M,
                (FloatPointer) xRPointer.getDevicePointer(),
                (IntPointer)worksizeBuffer.addressPointer() ) ;

            if (status == CUSOLVER_STATUS_SUCCESS) {
                int worksize = worksizeBuffer.getInt(0);

                // allocate the solver workspace and a 1x1 int array for the INFO return code
                // NOTE(review): sized with Nd4j.sizeOfDataType(), i.e. the default type, not necessarily float
                Pointer workspace = new Workspace(worksize * Nd4j.sizeOfDataType());
                INDArray INFO = Nd4j.createArrayFromShapeBuffer(Nd4j.getDataBufferFactory().createInt(1),
                    Nd4j.getShapeInfoProvider().createShapeInformation(new int[] {1, 1}));

                // Do the actual decomp
                status = cusolverDnSsyevd(solverDn, jobz, uplo, M,
                    (FloatPointer) xAPointer.getDevicePointer(), M,
                    (FloatPointer) xRPointer.getDevicePointer(),
                    new CudaPointer(workspace).asFloatPointer(), worksize,
                    new CudaPointer(allocator.getPointer(INFO, ctx)).asIntPointer());

                allocator.registerAction(ctx, INFO);
                // On a successful launch, report the solver's own INFO code instead
                if( status == 0 ) status = INFO.getInt(0) ;
            }
        }
    }
    if( status == 0 ) {
        allocator.registerAction(ctx, R);
        allocator.registerAction(ctx, a);

        // Copy the result back when a Fortran-ordered working copy was used
        if (a != A)
            A.assign(a);
    }
    return status ;
}
Example 17
Source File: Nd4jTestsComparisonFortran.java From nd4j with Apache License 2.0 | 4 votes |
/**
 * Cross-checks gemm for every combination of test matrices, transpose flags, alpha and beta
 * against the CheckUtil.checkGemm reference.
 */
@Test
public void testGemmWithOpsCommonsMath() {
    List<Pair<INDArray, String>> first = NDArrayCreationUtil.getAllTestMatricesWithShape(3, 5, SEED);
    List<Pair<INDArray, String>> firstT = NDArrayCreationUtil.getAllTestMatricesWithShape(5, 3, SEED);
    List<Pair<INDArray, String>> second = NDArrayCreationUtil.getAllTestMatricesWithShape(5, 4, SEED);
    List<Pair<INDArray, String>> secondT = NDArrayCreationUtil.getAllTestMatricesWithShape(4, 5, SEED);
    double[] alpha = {1.0, -0.5, 2.5};
    double[] beta = {0.0, -0.25, 1.5};
    INDArray cOrig = Nd4j.create(new int[] {3, 4});
    Random r = new Random(12345);
    for (int i = 0; i < cOrig.size(0); i++) {
        for (int j = 0; j < cOrig.size(1); j++) {
            cOrig.putScalar(new int[] {i, j}, r.nextDouble());
        }
    }

    for (int i = 0; i < first.size(); i++) {
        for (int j = 0; j < second.size(); j++) {
            for (int k = 0; k < alpha.length; k++) {
                for (int m = 0; m < beta.length; m++) {
                    System.out.println((String.format("Running iteration %d %d %d %d", i, j, k, m)));

                    // Fresh copies of C for each transpose combination, since gemm updates C
                    INDArray cff = Nd4j.create(cOrig.shape(), 'f');
                    cff.assign(cOrig);
                    INDArray cft = Nd4j.create(cOrig.shape(), 'f');
                    cft.assign(cOrig);
                    INDArray ctf = Nd4j.create(cOrig.shape(), 'f');
                    ctf.assign(cOrig);
                    INDArray ctt = Nd4j.create(cOrig.shape(), 'f');
                    ctt.assign(cOrig);

                    double a = alpha[k];
                    // Index beta by its own loop variable so every (alpha, beta) pair is exercised
                    double b = beta[m];
                    Pair<INDArray, String> p1 = first.get(i);
                    Pair<INDArray, String> p1T = firstT.get(i);
                    Pair<INDArray, String> p2 = second.get(j);
                    Pair<INDArray, String> p2T = secondT.get(j);
                    String errorMsgff = getGemmErrorMsg(i, j, false, false, a, b, p1, p2);
                    String errorMsgft = getGemmErrorMsg(i, j, false, true, a, b, p1, p2T);
                    String errorMsgtf = getGemmErrorMsg(i, j, true, false, a, b, p1T, p2);
                    String errorMsgtt = getGemmErrorMsg(i, j, true, true, a, b, p1T, p2T);

                    assertTrue(errorMsgff, CheckUtil.checkGemm(p1.getFirst(), p2.getFirst(), cff, false, false, a,
                            b, 1e-4, 1e-6));
                    assertTrue(errorMsgft, CheckUtil.checkGemm(p1.getFirst(), p2T.getFirst(), cft, false, true, a,
                            b, 1e-4, 1e-6));
                    assertTrue(errorMsgtf, CheckUtil.checkGemm(p1T.getFirst(), p2.getFirst(), ctf, true, false, a,
                            b, 1e-4, 1e-6));
                    assertTrue(errorMsgtt, CheckUtil.checkGemm(p1T.getFirst(), p2T.getFirst(), ctt, true, true, a,
                            b, 1e-4, 1e-6));
                }
            }
        }
    }
}
Example 18
Source File: JcublasLapack.java From deeplearning4j with Apache License 2.0 | 4 votes |
/**
 * Single-precision LU factorization on the GPU via cuSOLVER getrf.
 *
 * If {@code A} is C-ordered, a Fortran-ordered copy is factorized and then copied back into
 * {@code A}. Pivot indices go to {@code IPIV} and the solver status code to {@code INFO}.
 * NOTE(review): the leading dimension passed to cuSOLVER is {@code M}, which assumes the
 * working array's leading dimension equals M — confirm.
 *
 * @param M    number of rows
 * @param N    number of columns
 * @param A    matrix to factorize in place
 * @param IPIV output pivot indices
 * @param INFO output solver return code
 * @throws BlasException if setting the stream, the buffer-size query or getrf itself fails
 */
@Override
public void sgetrf(int M, int N, INDArray A, INDArray IPIV, INDArray INFO) {
    INDArray a = A;
    // NOTE(review): the warning says DOUBLE but fires for any non-FLOAT default type
    if (Nd4j.dataType() != DataType.FLOAT)
        log.warn("FLOAT getrf called in DOUBLE environment");

    // cuSOLVER is called on a Fortran-ordered array; C-ordered input is duplicated
    if (A.ordering() == 'c')
        a = A.dup('f');

    // Flush any queued ops before touching device memory directly
    if (Nd4j.getExecutioner() instanceof GridExecutioner)
        ((GridExecutioner) Nd4j.getExecutioner()).flushQueue();

    // Get context for current thread
    val ctx = allocator.getDeviceContext();

    // setup the solver handles for cuSolver calls
    cusolverDnHandle_t handle = ctx.getSolverHandle();
    cusolverDnContext solverDn = new cusolverDnContext(handle);

    // synchronized on the solver
    synchronized (handle) {
        int result = cusolverDnSetStream(new cusolverDnContext(handle), new CUstream_st(ctx.getCublasStream()));
        if (result != 0)
            throw new BlasException("solverSetStream failed");

        // transfer the INDArray into GPU memory
        CublasPointer xAPointer = new CublasPointer(a, ctx);

        // this output - indicates how much memory we'll need for the real operation
        val worksizeBuffer = (BaseCudaDataBuffer) Nd4j.getDataBufferFactory().createInt(1);
        worksizeBuffer.lazyAllocateHostPointer();

        int stat = cusolverDnSgetrf_bufferSize(solverDn, M, N, (FloatPointer) xAPointer.getDevicePointer(), M,
                (IntPointer) worksizeBuffer.addressPointer() // we intentionally use host pointer here
        );

        if (stat != CUSOLVER_STATUS_SUCCESS) {
            throw new BlasException("cusolverDnSgetrf_bufferSize failed", stat);
        }

        int worksize = worksizeBuffer.getInt(0);
        // Now allocate memory for the workspace, the permutation matrix and a return code
        // NOTE(review): only the workspace is allocated here; IPIV and INFO are caller-supplied.
        // It is sized with Nd4j.sizeOfDataType(), i.e. the default type, not necessarily float.
        Pointer workspace = new Workspace(worksize * Nd4j.sizeOfDataType());

        // Do the actual LU decomp
        stat = cusolverDnSgetrf(solverDn, M, N, (FloatPointer) xAPointer.getDevicePointer(), M,
                new CudaPointer(workspace).asFloatPointer(),
                new CudaPointer(allocator.getPointer(IPIV, ctx)).asIntPointer(),
                new CudaPointer(allocator.getPointer(INFO, ctx)).asIntPointer());

        // we do sync to make sure getrf is finished
        // NOTE(review): the sync below is commented out, so no explicit sync happens here
        //ctx.syncOldStream();

        if (stat != CUSOLVER_STATUS_SUCCESS) {
            throw new BlasException("cusolverDnSgetrf failed", stat);
        }
    }
    allocator.registerAction(ctx, a);
    allocator.registerAction(ctx, INFO);
    allocator.registerAction(ctx, IPIV);

    // Copy the result back when a Fortran-ordered working copy was used
    if (a != A)
        A.assign(a);
}
Example 19
Source File: BaseUnderSamplingPreProcessor.java From deeplearning4j with Apache License 2.0 | 4 votes |
public INDArray adjustMasks(INDArray label, INDArray labelMask, int minorityLabel, double targetDist) { if (labelMask == null) { labelMask = Nd4j.ones(label.size(0), label.size(2)); } validateData(label, labelMask); INDArray bernoullis = Nd4j.zeros(labelMask.shape()); long currentTimeSliceEnd = label.size(2); //iterate over each tbptt window while (currentTimeSliceEnd > 0) { long currentTimeSliceStart = Math.max(currentTimeSliceEnd - tbpttWindowSize, 0); //get views for current time slice INDArray currentWindowBernoulli = bernoullis.get(NDArrayIndex.all(), NDArrayIndex.interval(currentTimeSliceStart, currentTimeSliceEnd)); INDArray currentMask = labelMask.get(NDArrayIndex.all(), NDArrayIndex.interval(currentTimeSliceStart, currentTimeSliceEnd)); INDArray currentLabel; if (label.size(1) == 2) { //if one hot grab the right index currentLabel = label.get(NDArrayIndex.all(), NDArrayIndex.point(minorityLabel), NDArrayIndex.interval(currentTimeSliceStart, currentTimeSliceEnd)); } else { currentLabel = label.get(NDArrayIndex.all(), NDArrayIndex.point(0), NDArrayIndex.interval(currentTimeSliceStart, currentTimeSliceEnd)); if (minorityLabel == 0) { currentLabel = currentLabel.rsub(1.0); //rsub(1.0) is equivalent to swapping 0s and 1s } } //calculate required probabilities and write into the view currentWindowBernoulli.assign(calculateBernoulli(currentLabel, currentMask, targetDist)); currentTimeSliceEnd = currentTimeSliceStart; } return Nd4j.getExecutioner().exec( new BernoulliDistribution(Nd4j.createUninitialized(bernoullis.shape()), bernoullis), Nd4j.getRandom()); }
Example 20
Source File: WeightInitUtil.java From deeplearning4j with Apache License 2.0 | 4 votes |
/**
 * Fills {@code paramView} according to {@code initScheme} and returns it reshaped to
 * {@code shape} in {@code order}.
 *
 * @param fanIn      fan-in used by the scaling formulas
 * @param fanOut     fan-out used by the scaling formulas
 * @param shape      target parameter shape (IDENTITY requires a square 2D shape; XAVIER_LEGACY reads shape[0] and shape[1])
 * @param initScheme initialization scheme to apply
 * @param dist       distribution sampled for DISTRIBUTION
 * @param order      ordering used for the final reshape (and for IDENTITY's flattening)
 * @param paramView  array whose contents are overwritten
 * @return {@code paramView.reshape(order, shape)}
 * @throws IllegalStateException for IDENTITY with a non-square or non-2D shape, or an unhandled scheme
 */
public static INDArray initWeights(double fanIn, double fanOut, long[] shape, WeightInit initScheme,
                Distribution dist, char order, INDArray paramView) {
    switch (initScheme) {
        case DISTRIBUTION:
            // Orthogonal sampling is given the reshaped view rather than the flat one
            if (dist instanceof OrthogonalDistribution) {
                dist.sample(paramView.reshape(order, shape));
            } else {
                dist.sample(paramView);
            }
            break;
        case RELU:
            Nd4j.randn(paramView).muli(FastMath.sqrt(2.0 / fanIn)); //N(0, 2/nIn)
            break;
        case RELU_UNIFORM:
            double u = Math.sqrt(6.0 / fanIn);
            Nd4j.rand(paramView, Nd4j.getDistributions().createUniform(-u, u)); //U(-sqrt(6/fanIn), sqrt(6/fanIn))
            break;
        case SIGMOID_UNIFORM:
            double r = 4.0 * Math.sqrt(6.0 / (fanIn + fanOut));
            Nd4j.rand(paramView, Nd4j.getDistributions().createUniform(-r, r));
            break;
        case UNIFORM:
            double a = 1.0 / Math.sqrt(fanIn);
            Nd4j.rand(paramView, Nd4j.getDistributions().createUniform(-a, a));
            break;
        case LECUN_UNIFORM:
            // NOTE(review): limit is 3/sqrt(fanIn), not the sqrt(3/fanIn) used by some other libraries; kept for compatibility
            double b = 3.0 / Math.sqrt(fanIn);
            Nd4j.rand(paramView, Nd4j.getDistributions().createUniform(-b, b));
            break;
        case XAVIER:
            Nd4j.randn(paramView).muli(FastMath.sqrt(2.0 / (fanIn + fanOut)));
            break;
        case XAVIER_UNIFORM:
            //As per Glorot and Bengio 2010: Uniform distribution U(-s,s) with s = sqrt(6/(fanIn + fanOut))
            //Eq 16: http://jmlr.org/proceedings/papers/v9/glorot10a/glorot10a.pdf
            double s = Math.sqrt(6.0) / Math.sqrt(fanIn + fanOut);
            Nd4j.rand(paramView, Nd4j.getDistributions().createUniform(-s, s));
            break;
        case LECUN_NORMAL: //Fall through: these 3 are equivalent
        case NORMAL:
        case XAVIER_FAN_IN:
            Nd4j.randn(paramView).divi(FastMath.sqrt(fanIn));
            break;
        case XAVIER_LEGACY:
            Nd4j.randn(paramView).divi(FastMath.sqrt(shape[0] + shape[1]));
            break;
        case ZERO:
            paramView.assign(0.0);
            break;
        case ONES:
            paramView.assign(1.0);
            break;
        case IDENTITY:
            if(shape.length != 2 || shape[0] != shape[1]){
                throw new IllegalStateException("Cannot use IDENTITY init with parameters of shape "
                                + Arrays.toString(shape) + ": weights must be a square matrix for identity");
            }
            // Build the identity in the requested order, then flatten it in that same order
            INDArray ret;
            if(order == Nd4j.order()){
                ret = Nd4j.eye(shape[0]);
            } else {
                ret = Nd4j.createUninitialized(shape, order).assign(Nd4j.eye(shape[0]));
            }
            INDArray flat = Nd4j.toFlattened(order, ret);
            paramView.assign(flat);
            break;
        case VAR_SCALING_NORMAL_FAN_IN:
            Nd4j.exec(new TruncatedNormalDistribution(paramView, 0.0, Math.sqrt(1.0 / fanIn)));
            break;
        case VAR_SCALING_NORMAL_FAN_OUT:
            Nd4j.exec(new TruncatedNormalDistribution(paramView, 0.0, Math.sqrt(1.0 / fanOut)));
            break;
        case VAR_SCALING_NORMAL_FAN_AVG:
            Nd4j.exec(new TruncatedNormalDistribution(paramView, 0.0, Math.sqrt(2.0 / (fanIn + fanOut))));
            break;
        // NOTE(review): the three VAR_SCALING_UNIFORM limits below are 3/sqrt(n), not sqrt(3/n); confirm this is intended
        case VAR_SCALING_UNIFORM_FAN_IN:
            double scalingFanIn = 3.0 / Math.sqrt(fanIn);
            Nd4j.rand(paramView, Nd4j.getDistributions().createUniform(-scalingFanIn, scalingFanIn));
            break;
        case VAR_SCALING_UNIFORM_FAN_OUT:
            double scalingFanOut = 3.0 / Math.sqrt(fanOut);
            Nd4j.rand(paramView, Nd4j.getDistributions().createUniform(-scalingFanOut, scalingFanOut));
            break;
        case VAR_SCALING_UNIFORM_FAN_AVG:
            double scalingFanAvg = 3.0 / Math.sqrt((fanIn + fanOut) / 2);
            Nd4j.rand(paramView, Nd4j.getDistributions().createUniform(-scalingFanAvg, scalingFanAvg));
            break;
        default:
            throw new IllegalStateException("Illegal weight init value: " + initScheme);
    }

    return paramView.reshape(order, shape);
}