Java Code Examples for org.nd4j.autodiff.samediff.SameDiff#min()

The following examples show how to use org.nd4j.autodiff.samediff.SameDiff#min(). You can vote the examples up or down, and you can follow the link above each example to the original project or source file. Related API usage is listed in the sidebar.
Example 1
Source file: AttentionVertex.java — from deeplearning4j (Apache License 2.0), 6 votes
/**
 * Builds the attention subgraph for this vertex.
 *
 * <p>Reads the "queries", "keys" and "values" inputs. If {@code projectInput} is set, it runs
 * multi-head dot-product attention using the four projection weights from {@code paramTable};
 * otherwise it runs plain dot-product attention.
 *
 * <p>When {@code maskVars} is non-null:
 * <ul>
 *   <li>the key and value masks are combined and passed to the attention op;</li>
 *   <li>the result is multiplied by the query mask, expanded at axis 1.</li>
 * </ul>
 *
 * @param sameDiff   graph the ops are added to
 * @param layerInput vertex inputs keyed by "queries", "keys" and "values"
 * @param paramTable projection weights (only read when {@code projectInput} is true)
 * @param maskVars   optional masks keyed by "queries", "keys" and "values"; may be null
 * @return the attention output, multiplied by the query mask when masks are present
 */
@Override
public SDVariable defineVertex(SameDiff sameDiff, Map<String, SDVariable> layerInput, Map<String, SDVariable> paramTable, Map<String, SDVariable> maskVars) {
    final SDVariable q = layerInput.get("queries");
    final SDVariable k = layerInput.get("keys");
    final SDVariable v = layerInput.get("values");

    // Key and value masks are merged via sameDiff.min(...).
    // Presumably this is an element-wise minimum, so a position stays active only if both masks allow it — confirm.
    final boolean masked = maskVars != null;
    SDVariable kvMask = null;
    if (masked) {
        kvMask = sameDiff.min(maskVars.get("keys"), maskVars.get("values"));
    }

    final SDVariable out;
    if (!projectInput) {
        out = sameDiff.nn.dotProductAttention(getLayerName(), q, k, v, kvMask, true);
    } else {
        // Projection weights for queries, keys, values and the output, taken from the parameter table.
        final SDVariable queryW = paramTable.get(WEIGHT_KEY_QUERY_PROJECTION);
        final SDVariable keyW = paramTable.get(WEIGHT_KEY_KEY_PROJECTION);
        final SDVariable valueW = paramTable.get(WEIGHT_KEY_VALUE_PROJECTION);
        final SDVariable outW = paramTable.get(WEIGHT_KEY_OUT_PROJECTION);
        out = sameDiff.nn.multiHeadDotProductAttention(getLayerName(), q, k, v, queryW, keyW, valueW, outW, kvMask, true);
    }

    if (!masked) {
        return out;
    }
    // A size-1 axis is inserted at position 1 before multiplying.
    // Presumably this lets the query mask broadcast against the attention output — confirm the layout.
    final SDVariable queryMask = sameDiff.expandDims(maskVars.get("queries"), 1);
    return out.mul(queryMask);
}
 
Example 2
Source file: GradCheckReductions.java — from nd4j (Apache License 2.0), 4 votes
/**
 * Gradient-checks each reduction op when it is the final and only function in the graph.
 *
 * <p>Cases 0-9 are differentiable reductions and are verified with {@code GradCheckUtil.checkGradients}.
 * Cases 10-11 (countNonZero / countZero) have no meaningful gradient, so only their forward pass
 * is executed and checked for a non-null result.
 */
@Test
public void testReductionGradientsSimple() {
    //Test reductions: final and only function
    Nd4j.getRandom().setSeed(12345);

    // Must match the number of cases in the switch below.
    for (int i = 0; i < 12; i++) {

        SameDiff sd = SameDiff.create();

        boolean skipBackward = false;

        int nOut = 4;
        int minibatch = 10;
        SDVariable input = sd.var("in", new int[]{-1, nOut});

        // The reduction is registered in 'sd' as the graph output named "loss".
        SDVariable loss;
        String name;
        switch (i) {
            case 0:
                loss = sd.mean("loss", input);
                name = "mean";
                break;
            case 1:
                loss = sd.sum("loss", input);
                name = "sum";
                break;
            case 2:
                loss = sd.standardDeviation("loss", input, true);
                name = "stdev";
                break;
            case 3:
                loss = sd.min("loss", input);
                name = "min";
                break;
            case 4:
                loss = sd.max("loss", input);
                name = "max";
                break;
            case 5:
                loss = sd.variance("loss", input, true);
                name = "variance";
                break;
            case 6:
                loss = sd.prod("loss", input);
                name = "prod";
                break;
            case 7:
                loss = sd.norm1("loss", input);
                name = "norm1";
                break;
            case 8:
                loss = sd.norm2("loss", input);
                name = "norm2";
                break;
            case 9:
                loss = sd.normmax("loss", input);
                name = "normmax";
                break;
            case 10:
                loss = sd.countNonZero("loss", input);
                name = "countNonZero";
                skipBackward = true;
                break;
            case 11:
                loss = sd.countZero("loss", input);
                name = "countZero";
                skipBackward = true;
                break;
            default:
                throw new IllegalStateException("Unhandled test case index: " + i);
        }


        String msg = "test: " + i + " - " + name;
        log.info("*** Starting test: " + msg);

        INDArray inputArr = Nd4j.randn(minibatch, nOut).muli(100);
        sd.associateArrayWithVariable(inputArr, input);

        if (skipBackward) {
            // Counting ops have no useful gradient, but previously these cases asserted nothing at all.
            // Run the forward pass so the ops are at least exercised.
            INDArray out = sd.execAndEndResult();
            assertTrue(msg + " - forward pass returned null", out != null);
        } else {
            boolean ok = GradCheckUtil.checkGradients(sd);
            assertTrue(msg, ok);
        }
    }
}
 
Example 3
Source file: GradCheckReductions.java — from nd4j (Apache License 2.0), 4 votes
/**
 * Gradient-checks each reduction op when it is the final, but not the only, function in the graph.
 *
 * <p>The graph computes a per-example squared error, then applies the reduction under test along
 * dimension 0 and along Integer.MAX_VALUE. Each case is run to completion, including its shape
 * assertions. Every failure is collected, whether a failed gradient check, an exception or an
 * assertion error, and reported together at the end rather than stopping at the first one.
 */
@Test
public void testReductionGradients1() {
    //Test reductions: final, but *not* the only function
    Nd4j.getRandom().setSeed(12345);

    List<String> allFailed = new ArrayList<>();

    for (int dim : new int[]{0, Integer.MAX_VALUE}) {    //These two cases are equivalent here

        // Must match the number of cases in the switch below.
        for (int i = 0; i < 10; i++) {

            SameDiff sd = SameDiff.create();

            int nOut = 4;
            int minibatch = 10;
            SDVariable input = sd.var("in", new int[]{-1, nOut});
            SDVariable label = sd.var("label", new int[]{-1, nOut});

            SDVariable diff = input.sub(label);
            SDVariable sqDiff = diff.mul(diff);
            SDVariable msePerEx = sd.mean("msePerEx", sqDiff, 1);

            // The reduction is registered in 'sd' as the graph output named "loss".
            SDVariable loss;
            String name;
            switch (i) {
                case 0:
                    loss = sd.mean("loss", msePerEx, dim);
                    name = "mean";
                    break;
                case 1:
                    loss = sd.sum("loss", msePerEx, dim);
                    name = "sum";
                    break;
                case 2:
                    loss = sd.standardDeviation("loss", msePerEx, true, dim);
                    name = "stdev";
                    break;
                case 3:
                    loss = sd.min("loss", msePerEx, dim);
                    name = "min";
                    break;
                case 4:
                    loss = sd.max("loss", msePerEx, dim);
                    name = "max";
                    break;
                case 5:
                    loss = sd.variance("loss", msePerEx, true, dim);
                    name = "variance";
                    break;
                case 6:
                    loss = sd.prod("loss", msePerEx, dim);
                    name = "prod";
                    break;
                case 7:
                    loss = sd.norm1("loss", msePerEx, dim);
                    name = "norm1";
                    break;
                case 8:
                    loss = sd.norm2("loss", msePerEx, dim);
                    name = "norm2";
                    break;
                case 9:
                    loss = sd.normmax("loss", msePerEx, dim);
                    name = "normmax";
                    break;
                default:
                    throw new IllegalStateException("Unhandled test case index: " + i);
            }


            String msg = "(test " + i + " - " + name + ", dimension=" + dim + ")";
            log.info("*** Starting test: " + msg);

            INDArray inputArr = Nd4j.randn(minibatch, nOut).muli(100);
            INDArray labelArr = Nd4j.randn(minibatch, nOut).muli(100);

            sd.associateArrayWithVariable(inputArr, input);
            sd.associateArrayWithVariable(labelArr, label);

            try {
                INDArray out = sd.execAndEndResult();
                assertNotNull(out);
                assertArrayEquals(new int[]{1, 1}, out.shape());

                boolean ok = GradCheckUtil.checkGradients(sd);
                if (!ok) {
                    allFailed.add(msg);
                }
            } catch (Exception | AssertionError e) {
                // AssertionError is an Error, not an Exception. Catch it too so a failed shape check is
                // collected like every other failure instead of aborting the remaining cases.
                log.error("*** Test failed with exception: " + msg, e);
                allFailed.add(msg + " - EXCEPTION: " + e);
            }
        }
    }

    assertEquals("Failed: " + allFailed, 0, allFailed.size());
}
 
Example 4
Source file: ReductionOpValidation.java — from deeplearning4j (Apache License 2.0), 4 votes
/**
 * Runs the forward and backward pass for each reduction op applied to a per-example squared error.
 *
 * <p>For every case, the test does two things:
 * <ul>
 *   <li>it asserts that the reduced loss has exactly one element;</li>
 *   <li>it calls {@code calculateGradients} for all variables. Only the absence of an exception is
 *       checked here; gradient values are not asserted.</li>
 * </ul>
 */
@Test
public void testReductionsBackwards() {
    // Previously a debugging leftover ("int i=5;") restricted this test to the variance case.
    // Iterate over every case again; the bound must match the number of cases in the switch below.
    for (int i = 0; i < 7; i++) {

        SameDiff sd = SameDiff.create();

        int nOut = 4;
        int minibatch = 3;
        SDVariable input = sd.var("in", DataType.DOUBLE, new long[]{minibatch, nOut});
        SDVariable label = sd.var("label", DataType.DOUBLE, new long[]{minibatch, nOut});

        SDVariable diff = input.sub(label);
        SDVariable sqDiff = diff.mul(diff);
        SDVariable msePerEx = sd.mean("msePerEx", sqDiff, 1);

        SDVariable loss;    //Scalar value
        String name;
        switch (i) {
            case 0:
                loss = sd.mean("loss", msePerEx, 0);
                name = "mean";
                break;
            case 1:
                loss = sd.sum("loss", msePerEx, 0);
                name = "sum";
                break;
            case 2:
                loss = sd.standardDeviation("loss", msePerEx, true, 0);
                name = "stdev";
                break;
            case 3:
                loss = sd.min("loss", msePerEx, 0);
                name = "min";
                break;
            case 4:
                loss = sd.max("loss", msePerEx, 0);
                name = "max";
                break;
            case 5:
                loss = sd.variance("loss", msePerEx, true, 0);
                name = "variance";
                break;
            case 6:
                loss = sd.prod("loss", msePerEx, 0);
                name = "prod";
                break;
            default:
                throw new IllegalStateException("Unhandled test case index: " + i);
        }


        String msg = "test: " + i + " - " + name;
        log.info("*** Starting test: " + msg);

        INDArray inputArr = Nd4j.rand(DataType.DOUBLE, minibatch, nOut);
        INDArray labelArr = Nd4j.rand(DataType.DOUBLE, minibatch, nOut);

        sd.associateArrayWithVariable(inputArr, input);
        sd.associateArrayWithVariable(labelArr, label);

        INDArray result = loss.eval();
        assertEquals(1, result.length());

        // Backward pass over every variable in the graph; this only verifies that it completes.
        sd.calculateGradients(Collections.emptyMap(), sd.getVariables().keySet());
    }
}