org.deeplearning4j.nn.transferlearning.FineTuneConfiguration Java Examples

The following examples show how to use org.deeplearning4j.nn.transferlearning.FineTuneConfiguration. Each example lists the project and source file it was taken from, along with that project's license.
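
Before the project examples, here is a minimal sketch of the usual pattern: build a FineTuneConfiguration holding the hyperparameters you want to override, then pass it to a TransferLearning builder. The pretrainedNet parameter and the Adam learning rate below are illustrative placeholders, not taken from any of the projects listed here.

import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.transferlearning.FineTuneConfiguration;
import org.deeplearning4j.nn.transferlearning.TransferLearning;
import org.nd4j.linalg.learning.config.Adam;

public static MultiLayerNetwork applyFineTuning(MultiLayerNetwork pretrainedNet) {
    // Overrides (updater, seed, activation, ...) applied to every layer that is not frozen.
    FineTuneConfiguration fineTuneConf = new FineTuneConfiguration.Builder()
            .updater(new Adam(1e-4))
            .seed(42)
            .build();

    // Freeze layers 0..1 and fine-tune the remaining layers with the configuration above.
    return new TransferLearning.Builder(pretrainedNet)
            .fineTuneConfiguration(fineTuneConf)
            .setFeatureExtractor(1)
            .build();
}
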
Example #1
Source File: TestTransferStatsCollection.java    From deeplearning4j with Apache License 2.0
@Test
public void test() throws IOException {

    MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().list()
                    .layer(0, new DenseLayer.Builder().nIn(10).nOut(10).build())
                    .layer(1, new OutputLayer.Builder().activation(Activation.SOFTMAX).nIn(10).nOut(10).build()).build();

    MultiLayerNetwork net = new MultiLayerNetwork(conf);
    net.init();


    MultiLayerNetwork net2 =
                    new TransferLearning.Builder(net)
                                    .fineTuneConfiguration(
                                                    new FineTuneConfiguration.Builder().updater(new Sgd(0.01)).build())
                                    .setFeatureExtractor(0).build();

    net2.setListeners(new StatsListener(new InMemoryStatsStorage()));

    //Previously: this failed on frozen layers
    net2.fit(new DataSet(Nd4j.rand(8, 10), Nd4j.rand(8, 10)));
}
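
The effect of setFeatureExtractor(0) can be made visible with a check like the following. This is a hypothetical addition, not part of the original test: layers up to and including the index passed to setFeatureExtractor() are wrapped in FrozenLayer, so fit() leaves their parameters untouched.

// Hypothetical follow-up to the test above.
for (int i = 0; i < net2.getnLayers(); i++) {
    boolean frozen = net2.getLayer(i) instanceof org.deeplearning4j.nn.layers.FrozenLayer;
    System.out.println("Layer " + i + " frozen: " + frozen);
}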
 
Example #2
Source File: AbstractZooModel.java    From wekaDeeplearning4j with GNU General Public License v3.0
/**
 * Creates the default fine-tuning configuration to be set on this zoo model.
 *
 * @return the default fine-tuning configuration
 */
protected FineTuneConfiguration getFineTuneConfig() {
    return new FineTuneConfiguration.Builder()
            .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
            .updater(new Nesterovs(5e-5))
            .seed(seed)
            .build();
}
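
The method above only builds the configuration; a sketch of how such a configuration is typically consumed follows. The pretrainedGraph variable, the vertex names and numClasses are illustrative placeholders, not taken from the wekaDeeplearning4j sources.

// Sketch only: apply the fine-tune config while swapping out the classifier layer.
ComputationGraph tuned = new TransferLearning.GraphBuilder(pretrainedGraph)
        .fineTuneConfiguration(getFineTuneConfig())
        .setFeatureExtractor("fc2")                       // freeze everything up to this vertex
        .removeVertexKeepConnections("predictions")       // drop the old output layer, keep its inputs
        .addLayer("predictions",
                new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD)
                        .nIn(4096).nOut(numClasses)
                        .activation(Activation.SOFTMAX)
                        .build(),
                "fc2")
        .build();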
 
Example #3
Source File: BatchNormalizationTest.java    From deeplearning4j with Apache License 2.0
@Test
public void testBatchNorm() throws Exception {

    MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
            .seed(12345)
            .updater(new Adam(1e-3))
            .activation(Activation.TANH)
            .list()
            .layer(new ConvolutionLayer.Builder().nOut(5).kernelSize(2, 2).build())
            .layer(new BatchNormalization())
            .layer(new ConvolutionLayer.Builder().nOut(5).kernelSize(2, 2).build())
            .layer(new OutputLayer.Builder().activation(Activation.SOFTMAX).lossFunction(LossFunctions.LossFunction.MCXENT).nOut(10).build())
            .setInputType(InputType.convolutionalFlat(28, 28, 1))
            .build();

    MultiLayerNetwork net = new MultiLayerNetwork(conf);
    net.init();

    DataSetIterator iter = new EarlyTerminationDataSetIterator(new MnistDataSetIterator(32, true, 12345), 10);

    net.fit(iter);

    MultiLayerNetwork net2 = new TransferLearning.Builder(net)
            .fineTuneConfiguration(FineTuneConfiguration.builder()
                    .updater(new AdaDelta())
                    .build())
            .removeOutputLayer()
            .addLayer(new BatchNormalization.Builder().nOut(3380).build())
            .addLayer(new OutputLayer.Builder().activation(Activation.SOFTMAX).lossFunction(LossFunctions.LossFunction.MCXENT).nIn(3380).nOut(10).build())
            .build();

    net2.fit(iter);
}
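
For reference, the nIn/nOut value of 3380 in the replacement layers follows from the input shape: a 28x28x1 input passed through two 2x2 convolutions with the default stride of 1 and no padding shrinks to 26x26, and with nOut(5) channels the flattened activation size is 5 * 26 * 26 = 3380.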
 
Example #4
Source File: TrainCifar10Model.java    From Java-Machine-Learning-for-Computer-Vision with MIT License
private void train() throws IOException {

        ZooModel zooModel = VGG16.builder().build();
        ComputationGraph vgg16 = (ComputationGraph) zooModel.initPretrained(PretrainedType.CIFAR10);
        log.info(vgg16.summary());

        IUpdater iUpdaterWithDefaultConfig = Updater.ADAM.getIUpdaterWithDefaultConfig();
        iUpdaterWithDefaultConfig.setLrAndSchedule(0.1, null);
        FineTuneConfiguration fineTuneConf = new FineTuneConfiguration.Builder()
                .seed(1234)
//                .weightInit(WeightInit.XAVIER)
                .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
                .activation(Activation.RELU)
                .updater(iUpdaterWithDefaultConfig)
                .cudnnAlgoMode(ConvolutionLayer.AlgoMode.NO_WORKSPACE)
                .miniBatch(true)
                .inferenceWorkspaceMode(WorkspaceMode.ENABLED)
                .trainingWorkspaceMode(WorkspaceMode.ENABLED)
                .pretrain(true)
                .backprop(true)
                .build();

        ComputationGraph cifar10 = new TransferLearning.GraphBuilder(vgg16)
                .setWorkspaceMode(WorkspaceMode.ENABLED)
                .fineTuneConfiguration(fineTuneConf)
                .setInputTypes(InputType.convolutionalFlat(ImageUtils.HEIGHT,
                        ImageUtils.WIDTH, 3))
                .removeVertexAndConnections("dense_2_loss")
                .removeVertexAndConnections("dense_2")
                .removeVertexAndConnections("dense_1")
                .removeVertexAndConnections("dropout_1")
                .removeVertexAndConnections("embeddings")
                .removeVertexAndConnections("flatten_1")
                .addLayer("dense_1", new DenseLayer.Builder()
                        .nIn(4096)
                        .nOut(EMBEDDINGS)
                        .activation(Activation.RELU).build(), "block3_pool")
                .addVertex("embeddings", new L2NormalizeVertex(new int[]{}, 1e-12), "dense_1")
                .addLayer("lossLayer", new CenterLossOutputLayer.Builder()
                                .lossFunction(LossFunctions.LossFunction.SQUARED_LOSS)
                                .activation(Activation.SOFTMAX).nIn(EMBEDDINGS).nOut(NUM_POSSIBLE_LABELS)
                                .lambda(LAMBDA).alpha(0.9)
                                .gradientNormalization(GradientNormalization.RenormalizeL2PerLayer).build(),
                        "embeddings")
                .setOutputs("lossLayer")
                .build();

        log.info(cifar10.summary());
        File rootDir = new File("CarTracking/train_from_video_" + NUM_POSSIBLE_LABELS);
        DataSetIterator dataSetIterator = ImageUtils.createDataSetIterator(rootDir, NUM_POSSIBLE_LABELS, BATCH_SIZE);
        DataSetIterator testIterator = ImageUtils.createDataSetIterator(rootDir, NUM_POSSIBLE_LABELS, BATCH_SIZE);
        cifar10.setListeners(new ScoreIterationListener(2));
        int iEpoch = I_EPOCH;
        while (iEpoch < EPOCH_TRAINING) {
            while (dataSetIterator.hasNext()) {
                DataSet trainMiniBatchData;
                try {
                    trainMiniBatchData = dataSetIterator.next();
                } catch (Exception e) {
                    e.printStackTrace();
                    continue; // skip this batch instead of calling fit(null) below
                }
                cifar10.fit(trainMiniBatchData);
            }
            iEpoch++;

            String modelName = PREFIX + NUM_POSSIBLE_LABELS + "_epoch_data_e" + EMBEDDINGS + "_b" + BATCH_SIZE + "_" + iEpoch + ".zip";
            saveProgress(cifar10, iEpoch, modelName);
            testResults(cifar10, testIterator, iEpoch, modelName);
            dataSetIterator.reset();
            log.info("iEpoch = " + iEpoch);
        }
    }
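
Design-wise, this example keeps the pretrained VGG16 graph only up to block3_pool, removes the original classifier vertices, and attaches a new dense layer, an L2NormalizeVertex and a CenterLossOutputLayer, so the fine-tuned network produces normalized embeddings trained with a center-loss objective rather than plain class scores.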
 
Example #5
Source File: TransferLearningVGG16.java    From Java-Machine-Learning-for-Computer-Vision with MIT License
public void train() throws IOException {
    ComputationGraph preTrainedNet = loadVGG16PreTrainedWeights();
    log.info("VGG 16 Architecture");
    log.info(preTrainedNet.summary());

    log.info("Start Downloading NeuralNetworkTrainingData...");

    downloadAndUnzipDataForTheFirstTime();

    log.info("NeuralNetworkTrainingData Downloaded and unzipped");

    neuralNetworkTrainingData = new DataStorage() {
    }.loadData();

    FineTuneConfiguration fineTuneConf = new FineTuneConfiguration.Builder()
            .learningRate(LEARNING_RATE)
            .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
            .updater(Updater.NESTEROVS)
            .seed(1234)
            .build();

    ComputationGraph vgg16Transfer = new TransferLearning.GraphBuilder(preTrainedNet)
            .fineTuneConfiguration(fineTuneConf)
            .setFeatureExtractor(FREEZE_UNTIL_LAYER)
            .removeVertexKeepConnections("predictions")
            .addLayer("predictions",
                    new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD)
                            .nIn(4096)
                            .nOut(NUM_POSSIBLE_LABELS)
                            .weightInit(WeightInit.XAVIER)
                            .activation(Activation.SOFTMAX)
                            .build(),
                    FREEZE_UNTIL_LAYER)
            .build();
    vgg16Transfer.setListeners(new ScoreIterationListener(5));

    log.info("Modified VGG 16 Architecture for transfer learning");
    log.info(vgg16Transfer.summary());

    int iEpoch = 0;
    int iIteration = 0;
    while (iEpoch < EPOCH) {
        while (neuralNetworkTrainingData.getTrainIterator().hasNext()) {
            DataSet trainMiniBatchData = neuralNetworkTrainingData.getTrainIterator().next();
            vgg16Transfer.fit(trainMiniBatchData);
            saveProgressEveryConfiguredInterval(vgg16Transfer, iEpoch, iIteration);
            iIteration++;
        }
        neuralNetworkTrainingData.getTrainIterator().reset();
        iEpoch++;

        evalOn(vgg16Transfer, neuralNetworkTrainingData.getTestIterator(), iEpoch);
    }
}
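
Note that this example uses the older DL4J configuration API, where the learning rate and the updater are set separately via learningRate(...) and the Updater enum; in more recent releases both are expressed through a single IUpdater instance, e.g. updater(new Nesterovs(LEARNING_RATE)), as in the other examples on this page.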
 
Example #6
Source File: FrozenLayerTest.java    From deeplearning4j with Apache License 2.0
@Test
public void testFrozen() {
    DataSet randomData = new DataSet(Nd4j.rand(10, 4), Nd4j.rand(10, 3));

    NeuralNetConfiguration.Builder overallConf = new NeuralNetConfiguration.Builder().updater(new Sgd(0.1))
                    .activation(Activation.IDENTITY);

    FineTuneConfiguration finetune = new FineTuneConfiguration.Builder().updater(new Sgd(0.1)).build();

    MultiLayerNetwork modelToFineTune = new MultiLayerNetwork(overallConf.clone().list()
                    .layer(0, new DenseLayer.Builder().nIn(4).nOut(3).build())
                    .layer(1, new DenseLayer.Builder().nIn(3).nOut(2).build())
                    .layer(2, new DenseLayer.Builder().nIn(2).nOut(3).build())
                    .layer(3, new org.deeplearning4j.nn.conf.layers.OutputLayer.Builder(
                                    LossFunctions.LossFunction.MCXENT).activation(Activation.SOFTMAX).nIn(3).nOut(3)
                                                    .build())
                    .build());

    modelToFineTune.init();
    List<INDArray> ff = modelToFineTune.feedForwardToLayer(2, randomData.getFeatures(), false);
    INDArray asFrozenFeatures = ff.get(2);

    MultiLayerNetwork modelNow = new TransferLearning.Builder(modelToFineTune).fineTuneConfiguration(finetune)
                    .setFeatureExtractor(1).build();

    INDArray paramsLastTwoLayers =
                    Nd4j.hstack(modelToFineTune.getLayer(2).params(), modelToFineTune.getLayer(3).params());
    MultiLayerNetwork notFrozen = new MultiLayerNetwork(overallConf.clone().list()
                    .layer(0, new DenseLayer.Builder().nIn(2).nOut(3).build())
                    .layer(1, new org.deeplearning4j.nn.conf.layers.OutputLayer.Builder(
                                    LossFunctions.LossFunction.MCXENT).activation(Activation.SOFTMAX).nIn(3).nOut(3)
                                                    .build())
                    .build(), paramsLastTwoLayers);

    //        assertEquals(modelNow.getLayer(2).conf(), notFrozen.getLayer(0).conf());  //Equal, other than names
    //        assertEquals(modelNow.getLayer(3).conf(), notFrozen.getLayer(1).conf());  //Equal, other than names

    //Check: forward pass
    INDArray outNow = modelNow.output(randomData.getFeatures());
    INDArray outNotFrozen = notFrozen.output(asFrozenFeatures);
    assertEquals(outNow, outNotFrozen);

    for (int i = 0; i < 5; i++) {
        notFrozen.fit(new DataSet(asFrozenFeatures, randomData.getLabels()));
        modelNow.fit(randomData);
    }

    INDArray expected = Nd4j.hstack(modelToFineTune.getLayer(0).params(), modelToFineTune.getLayer(1).params(),
                    notFrozen.params());
    INDArray act = modelNow.params();
    assertEquals(expected, act);
}
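
In short, the final assertion checks two things at once: the parameters of the frozen layers 0 and 1 of modelNow are still identical to those of the original modelToFineTune, and the trainable layers 2 and 3 have evolved exactly like the standalone notFrozen network that was trained on the precomputed layer-1 activations (asFrozenFeatures).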