org.deeplearning4j.nn.transferlearning.FineTuneConfiguration Java Examples
The following examples show how to use
org.deeplearning4j.nn.transferlearning.FineTuneConfiguration.
Each example notes the project it comes from, its source file, and its license.
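FineTuneConfiguration holds the hyperparameters (updater, seed, optimization algorithm, and so on) that should override a pretrained network's existing settings during transfer learning; it is handed to a TransferLearning builder, which applies it to the layers that remain trainable. The sketch below shows the basic pattern; it is not taken from any single example, and the network variable, updater, seed, and frozen-layer index are illustrative assumptions.

// A minimal sketch, not from any of the examples below.
// Assumptions: 'pretrained' is an already-initialized MultiLayerNetwork,
// and the updater, seed, and frozen-layer index are illustrative values.
FineTuneConfiguration fineTuneConf = new FineTuneConfiguration.Builder()
        .updater(new Sgd(0.01))   // overrides the updater for the layers that stay trainable
        .seed(12345)
        .build();

MultiLayerNetwork transferred = new TransferLearning.Builder(pretrained)
        .fineTuneConfiguration(fineTuneConf)  // apply the fine-tune overrides
        .setFeatureExtractor(0)               // freeze layers up to and including index 0
        .build();

The ComputationGraph examples below follow the same pattern through TransferLearning.GraphBuilder.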
Example #1
Source File: TestTransferStatsCollection.java From deeplearning4j with Apache License 2.0
@Test
public void test() throws IOException {
    MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().list()
            .layer(0, new DenseLayer.Builder().nIn(10).nOut(10).build())
            .layer(1, new OutputLayer.Builder().activation(Activation.SOFTMAX).nIn(10).nOut(10).build())
            .build();

    MultiLayerNetwork net = new MultiLayerNetwork(conf);
    net.init();

    MultiLayerNetwork net2 = new TransferLearning.Builder(net)
            .fineTuneConfiguration(new FineTuneConfiguration.Builder().updater(new Sgd(0.01)).build())
            .setFeatureExtractor(0)
            .build();

    net2.setListeners(new StatsListener(new InMemoryStatsStorage()));

    //Previously: failed on frozen layers
    net2.fit(new DataSet(Nd4j.rand(8, 10), Nd4j.rand(8, 10)));
}
Example #2
Source File: AbstractZooModel.java From wekaDeeplearning4j with GNU General Public License v3.0
/**
 * We need to create and set the fine tuning config
 * @return Default fine tuning config
 */
protected FineTuneConfiguration getFineTuneConfig() {
    return new FineTuneConfiguration.Builder()
            .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
            .updater(new Nesterovs(5e-5))
            .seed(seed)
            .build();
}
Example #3
Source File: BatchNormalizationTest.java From deeplearning4j with Apache License 2.0
@Test
public void testBatchNorm() throws Exception {
    MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
            .seed(12345)
            .updater(new Adam(1e-3))
            .activation(Activation.TANH)
            .list()
            .layer(new ConvolutionLayer.Builder().nOut(5).kernelSize(2, 2).build())
            .layer(new BatchNormalization())
            .layer(new ConvolutionLayer.Builder().nOut(5).kernelSize(2, 2).build())
            .layer(new OutputLayer.Builder().activation(Activation.SOFTMAX).lossFunction(LossFunctions.LossFunction.MCXENT).nOut(10).build())
            .setInputType(InputType.convolutionalFlat(28, 28, 1))
            .build();

    MultiLayerNetwork net = new MultiLayerNetwork(conf);
    net.init();

    DataSetIterator iter = new EarlyTerminationDataSetIterator(new MnistDataSetIterator(32, true, 12345), 10);

    net.fit(iter);

    MultiLayerNetwork net2 = new TransferLearning.Builder(net)
            .fineTuneConfiguration(FineTuneConfiguration.builder()
                    .updater(new AdaDelta())
                    .build())
            .removeOutputLayer()
            .addLayer(new BatchNormalization.Builder().nOut(3380).build())
            .addLayer(new OutputLayer.Builder().activation(Activation.SOFTMAX).lossFunction(LossFunctions.LossFunction.MCXENT).nIn(3380).nOut(10).build())
            .build();

    net2.fit(iter);
}
Example #4
Source File: TrainCifar10Model.java From Java-Machine-Learning-for-Computer-Vision with MIT License
private void train() throws IOException {
    ZooModel zooModel = VGG16.builder().build();
    ComputationGraph vgg16 = (ComputationGraph) zooModel.initPretrained(PretrainedType.CIFAR10);
    log.info(vgg16.summary());

    IUpdater iUpdaterWithDefaultConfig = Updater.ADAM.getIUpdaterWithDefaultConfig();
    iUpdaterWithDefaultConfig.setLrAndSchedule(0.1, null);
    FineTuneConfiguration fineTuneConf = new FineTuneConfiguration.Builder()
            .seed(1234)
//            .weightInit(WeightInit.XAVIER)
            .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
            .activation(Activation.RELU)
            .updater(iUpdaterWithDefaultConfig)
            .cudnnAlgoMode(ConvolutionLayer.AlgoMode.NO_WORKSPACE)
            .miniBatch(true)
            .inferenceWorkspaceMode(WorkspaceMode.ENABLED)
            .trainingWorkspaceMode(WorkspaceMode.ENABLED)
            .pretrain(true)
            .backprop(true)
            .build();

    ComputationGraph cifar10 = new TransferLearning.GraphBuilder(vgg16)
            .setWorkspaceMode(WorkspaceMode.ENABLED)
            .fineTuneConfiguration(fineTuneConf)
            .setInputTypes(InputType.convolutionalFlat(ImageUtils.HEIGHT, ImageUtils.WIDTH, 3))
            .removeVertexAndConnections("dense_2_loss")
            .removeVertexAndConnections("dense_2")
            .removeVertexAndConnections("dense_1")
            .removeVertexAndConnections("dropout_1")
            .removeVertexAndConnections("embeddings")
            .removeVertexAndConnections("flatten_1")
            .addLayer("dense_1", new DenseLayer.Builder()
                    .nIn(4096)
                    .nOut(EMBEDDINGS)
                    .activation(Activation.RELU).build(), "block3_pool")
            .addVertex("embeddings", new L2NormalizeVertex(new int[]{}, 1e-12), "dense_1")
            .addLayer("lossLayer", new CenterLossOutputLayer.Builder()
                    .lossFunction(LossFunctions.LossFunction.SQUARED_LOSS)
                    .activation(Activation.SOFTMAX).nIn(EMBEDDINGS).nOut(NUM_POSSIBLE_LABELS)
                    .lambda(LAMBDA).alpha(0.9)
                    .gradientNormalization(GradientNormalization.RenormalizeL2PerLayer).build(), "embeddings")
            .setOutputs("lossLayer")
            .build();
    log.info(cifar10.summary());

    File rootDir = new File("CarTracking/train_from_video_" + NUM_POSSIBLE_LABELS);
    DataSetIterator dataSetIterator = ImageUtils.createDataSetIterator(rootDir, NUM_POSSIBLE_LABELS, BATCH_SIZE);
    DataSetIterator testIterator = ImageUtils.createDataSetIterator(rootDir, NUM_POSSIBLE_LABELS, BATCH_SIZE);
    cifar10.setListeners(new ScoreIterationListener(2));
    int iEpoch = I_EPOCH;
    while (iEpoch < EPOCH_TRAINING) {
        while (dataSetIterator.hasNext()) {
            DataSet trainMiniBatchData = null;
            try {
                trainMiniBatchData = dataSetIterator.next();
            } catch (Exception e) {
                e.printStackTrace();
            }
            cifar10.fit(trainMiniBatchData);
        }
        iEpoch++;

        String modelName = PREFIX + NUM_POSSIBLE_LABELS + "_epoch_data_e" + EMBEDDINGS + "_b" + BATCH_SIZE + "_" + iEpoch + ".zip";
        saveProgress(cifar10, iEpoch, modelName);
        testResults(cifar10, testIterator, iEpoch, modelName);
        dataSetIterator.reset();
        log.info("iEpoch = " + iEpoch);
    }
}
Example #5
Source File: TransferLearningVGG16.java From Java-Machine-Learning-for-Computer-Vision with MIT License
public void train() throws IOException {
    ComputationGraph preTrainedNet = loadVGG16PreTrainedWeights();
    log.info("VGG 16 Architecture");
    log.info(preTrainedNet.summary());

    log.info("Start Downloading NeuralNetworkTrainingData...");
    downloadAndUnzipDataForTheFirstTime();
    log.info("NeuralNetworkTrainingData Downloaded and unzipped");

    neuralNetworkTrainingData = new DataStorage() {
    }.loadData();

    FineTuneConfiguration fineTuneConf = new FineTuneConfiguration.Builder()
            .learningRate(LEARNING_RATE)
            .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
            .updater(Updater.NESTEROVS)
            .seed(1234)
            .build();

    ComputationGraph vgg16Transfer = new TransferLearning.GraphBuilder(preTrainedNet)
            .fineTuneConfiguration(fineTuneConf)
            .setFeatureExtractor(FREEZE_UNTIL_LAYER)
            .removeVertexKeepConnections("predictions")
            .addLayer("predictions",
                    new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD)
                            .nIn(4096)
                            .nOut(NUM_POSSIBLE_LABELS)
                            .weightInit(WeightInit.XAVIER)
                            .activation(Activation.SOFTMAX)
                            .build(),
                    FREEZE_UNTIL_LAYER)
            .build();
    vgg16Transfer.setListeners(new ScoreIterationListener(5));
    log.info("Modified VGG 16 Architecture for transfer learning");
    log.info(vgg16Transfer.summary());

    int iEpoch = 0;
    int iIteration = 0;
    while (iEpoch < EPOCH) {
        while (neuralNetworkTrainingData.getTrainIterator().hasNext()) {
            DataSet trainMiniBatchData = neuralNetworkTrainingData.getTrainIterator().next();
            vgg16Transfer.fit(trainMiniBatchData);
            saveProgressEveryConfiguredInterval(vgg16Transfer, iEpoch, iIteration);
            iIteration++;
        }
        neuralNetworkTrainingData.getTrainIterator().reset();
        iEpoch++;

        evalOn(vgg16Transfer, neuralNetworkTrainingData.getTestIterator(), iEpoch);
    }
}
Example #6
Source File: FrozenLayerTest.java From deeplearning4j with Apache License 2.0
@Test
public void testFrozen() {
    DataSet randomData = new DataSet(Nd4j.rand(10, 4), Nd4j.rand(10, 3));

    NeuralNetConfiguration.Builder overallConf = new NeuralNetConfiguration.Builder().updater(new Sgd(0.1))
            .activation(Activation.IDENTITY);

    FineTuneConfiguration finetune = new FineTuneConfiguration.Builder().updater(new Sgd(0.1)).build();

    MultiLayerNetwork modelToFineTune = new MultiLayerNetwork(overallConf.clone().list()
            .layer(0, new DenseLayer.Builder().nIn(4).nOut(3).build())
            .layer(1, new DenseLayer.Builder().nIn(3).nOut(2).build())
            .layer(2, new DenseLayer.Builder().nIn(2).nOut(3).build())
            .layer(3, new org.deeplearning4j.nn.conf.layers.OutputLayer.Builder(
                    LossFunctions.LossFunction.MCXENT).activation(Activation.SOFTMAX).nIn(3).nOut(3)
                    .build())
            .build());

    modelToFineTune.init();
    List<INDArray> ff = modelToFineTune.feedForwardToLayer(2, randomData.getFeatures(), false);
    INDArray asFrozenFeatures = ff.get(2);

    MultiLayerNetwork modelNow = new TransferLearning.Builder(modelToFineTune).fineTuneConfiguration(finetune)
            .setFeatureExtractor(1).build();

    INDArray paramsLastTwoLayers = Nd4j.hstack(modelToFineTune.getLayer(2).params(),
            modelToFineTune.getLayer(3).params());
    MultiLayerNetwork notFrozen = new MultiLayerNetwork(overallConf.clone().list()
            .layer(0, new DenseLayer.Builder().nIn(2).nOut(3).build())
            .layer(1, new org.deeplearning4j.nn.conf.layers.OutputLayer.Builder(
                    LossFunctions.LossFunction.MCXENT).activation(Activation.SOFTMAX).nIn(3).nOut(3)
                    .build())
            .build(), paramsLastTwoLayers);

//    assertEquals(modelNow.getLayer(2).conf(), notFrozen.getLayer(0).conf());  //Equal, other than names
//    assertEquals(modelNow.getLayer(3).conf(), notFrozen.getLayer(1).conf());  //Equal, other than names

    //Check: forward pass
    INDArray outNow = modelNow.output(randomData.getFeatures());
    INDArray outNotFrozen = notFrozen.output(asFrozenFeatures);

    assertEquals(outNow, outNotFrozen);

    for (int i = 0; i < 5; i++) {
        notFrozen.fit(new DataSet(asFrozenFeatures, randomData.getLabels()));
        modelNow.fit(randomData);
    }

    INDArray expected = Nd4j.hstack(modelToFineTune.getLayer(0).params(), modelToFineTune.getLayer(1).params(),
            notFrozen.params());
    INDArray act = modelNow.params();
    assertEquals(expected, act);
}