org.deeplearning4j.util.ModelSerializer Java Examples
The following examples show how to use
org.deeplearning4j.util.ModelSerializer.
You can vote up the ones you like or vote down the ones you don't like,
and go to the original project or source file by following the links above each example. You may check out the related API usage on the sidebar.
Example #1
Source File: TestUtils.java From deeplearning4j with Apache License 2.0 | 6 votes |
public static MultiLayerNetwork testModelSerialization(MultiLayerNetwork net){ MultiLayerNetwork restored; try { ByteArrayOutputStream baos = new ByteArrayOutputStream(); ModelSerializer.writeModel(net, baos, true); byte[] bytes = baos.toByteArray(); ByteArrayInputStream bais = new ByteArrayInputStream(bytes); restored = ModelSerializer.restoreMultiLayerNetwork(bais, true); assertEquals(net.getLayerWiseConfigurations(), restored.getLayerWiseConfigurations()); assertEquals(net.params(), restored.params()); } catch (IOException e){ //Should never happen throw new RuntimeException(e); } //Also check the MultiLayerConfiguration is serializable (required by Spark etc) MultiLayerConfiguration conf = net.getLayerWiseConfigurations(); serializeDeserializeJava(conf); return restored; }
Example #2
Source File: TestUtils.java From deeplearning4j with Apache License 2.0 | 6 votes |
public static ComputationGraph testModelSerialization(ComputationGraph net){ ComputationGraph restored; try { ByteArrayOutputStream baos = new ByteArrayOutputStream(); ModelSerializer.writeModel(net, baos, true); byte[] bytes = baos.toByteArray(); ByteArrayInputStream bais = new ByteArrayInputStream(bytes); restored = ModelSerializer.restoreComputationGraph(bais, true); assertEquals(net.getConfiguration(), restored.getConfiguration()); assertEquals(net.params(), restored.params()); } catch (IOException e){ //Should never happen throw new RuntimeException(e); } //Also check the ComputationGraphConfiguration is serializable (required by Spark etc) ComputationGraphConfiguration conf = net.getConfiguration(); serializeDeserializeJava(conf); return restored; }
Example #3
Source File: YOLOModel.java From java-ml-projects with Apache License 2.0 | 6 votes |
/**
 * Prepares the YOLO model, its class names, the image loader and the input parameters.
 * <p>
 * When {@code modelPath} is {@code null} the pretrained {@code YOLO2} graph is used together
 * with {@code COCO_CLASSES}. Otherwise the graph is restored from {@code modelPath} and the
 * class names are obtained by splitting {@code classes} on commas.
 *
 * @throws Error if the restored graph's first output layer is not a {@code Yolo2OutputLayer},
 *               or if an {@link IOException} occurs during initialisation
 */
public void init() {
    try {
        if (modelPath != null) {
            yoloModel = ModelSerializer.restoreComputationGraph(modelPath);
            boolean hasYoloOutput = yoloModel.getOutputLayer(0) instanceof Yolo2OutputLayer;
            if (!hasYoloOutput) {
                throw new Error("The model is not an YOLO model (output layer is not Yolo2OutputLayer)");
            }
            setModelClasses(classes.split("\\,"));
        } else {
            yoloModel = (ComputationGraph) YOLO2.builder().build().initPretrained();
            setModelClasses(COCO_CLASSES);
        }

        imageLoader = new NativeImageLoader(
                getInputWidth(),
                getInputHeight(),
                getInputChannels(),
                new ColorConversionTransform(COLOR_BGR2RGB));
        loadInputParameters();
    } catch (IOException e) {
        throw new Error("Not able to init the model", e);
    }
}
Example #4
Source File: DLModel.java From java-ml-projects with Apache License 2.0 | 6 votes |
/**
 * Loads a saved model from {@code file}, first as a computation graph and, if that fails,
 * as a multi-layer network.
 *
 * @param file the saved model file
 * @return a {@code DLModel} wrapping whichever model type could be restored
 * @throws Exception the computation-graph failure when both attempts fail; the multi-layer
 *                   failure is attached to it as a suppressed exception
 */
public static DLModel fromFile(File file) throws Exception {
    Model model;
    try {
        System.out.println("Trying to load file as computation graph: " + file);
        model = ModelSerializer.restoreComputationGraph(file);
        System.out.println("Loaded Computation Graph.");
    } catch (Exception graphFailure) {
        try {
            System.out.println("Failed to load computation graph. Trying to load model.");
            model = ModelSerializer.restoreMultiLayerNetwork(file);
            System.out.println("Loaded Multilayernetwork");
        } catch (Exception networkFailure) {
            System.out.println("Give up trying to load file: " + file);
            // Keep the second failure visible rather than silently dropping it.
            // addSuppressed rejects self-suppression, hence the identity guard.
            if (networkFailure != graphFailure) {
                graphFailure.addSuppressed(networkFailure);
            }
            throw graphFailure;
        }
    }
    return new DLModel(model);
}
Example #5
Source File: TestUtils.java From deeplearning4j with Apache License 2.0 | 6 votes |
public static MultiLayerNetwork testModelSerialization(MultiLayerNetwork net){ MultiLayerNetwork restored; try { ByteArrayOutputStream baos = new ByteArrayOutputStream(); ModelSerializer.writeModel(net, baos, true); byte[] bytes = baos.toByteArray(); ByteArrayInputStream bais = new ByteArrayInputStream(bytes); restored = ModelSerializer.restoreMultiLayerNetwork(bais, true); assertEquals(net.getLayerWiseConfigurations(), restored.getLayerWiseConfigurations()); assertEquals(net.params(), restored.params()); } catch (IOException e){ //Should never happen throw new RuntimeException(e); } //Also check the MultiLayerConfiguration is serializable (required by Spark etc) MultiLayerConfiguration conf = net.getLayerWiseConfigurations(); serializeDeserializeJava(conf); return restored; }
Example #6
Source File: InferenceExecutionerFactoryTests.java From konduit-serving with Apache License 2.0 | 6 votes |
/**
 * Writes a trained network (converted to a {@link ComputationGraph}) to a zip file and checks
 * that {@link Dl4jInferenceExecutionerFactory} builds a
 * {@link MultiComputationGraphInferenceExecutioner} with a non-null model and model loader.
 */
@Test
public void testComputationGraph() throws Exception {
    Pair<MultiLayerNetwork, DataNormalization> trained = TrainUtils.getTrainedNetwork();
    ComputationGraph graph = trained.getLeft().toComputationGraph();

    File folder = testDir.newFolder();
    File modelZip = new File(folder, "dl4j_cg_model.zip");
    modelZip.deleteOnExit();
    ModelSerializer.writeModel(graph, modelZip, true);

    ModelStep step = Dl4jStep.builder()
            .inputName("default")
            .outputName("output")
            .path(modelZip.getAbsolutePath())
            .build();

    Dl4jInferenceExecutionerFactory executionerFactory = new Dl4jInferenceExecutionerFactory();
    InitializedInferenceExecutionerConfig created = executionerFactory.create(step);
    MultiComputationGraphInferenceExecutioner executioner =
            (MultiComputationGraphInferenceExecutioner) created.getInferenceExecutioner();

    assertNotNull(executioner);
    assertNotNull(executioner.model());
    assertNotNull(executioner.modelLoader());
}
Example #7
Source File: InferenceExecutionerFactoryTests.java From konduit-serving with Apache License 2.0 | 6 votes |
/**
 * Writes a trained {@link MultiLayerNetwork} to a zip file and checks that
 * {@link Dl4jInferenceExecutionerFactory} builds a
 * {@link MultiLayerNetworkInferenceExecutioner} with a non-null model and model loader.
 */
@Test
public void testMultiLayerNetwork() throws Exception {
    Pair<MultiLayerNetwork, DataNormalization> trained = TrainUtils.getTrainedNetwork();
    MultiLayerNetwork network = trained.getLeft();

    File folder = testDir.newFolder();
    File modelZip = new File(folder, "dl4j_mln_model.zip");
    modelZip.deleteOnExit();
    ModelSerializer.writeModel(network, modelZip, true);

    ModelStep step = Dl4jStep.builder()
            .inputName("default")
            .outputName("output")
            .path(modelZip.getAbsolutePath())
            .build();

    Dl4jInferenceExecutionerFactory executionerFactory = new Dl4jInferenceExecutionerFactory();
    InitializedInferenceExecutionerConfig created = executionerFactory.create(step);
    MultiLayerNetworkInferenceExecutioner executioner =
            (MultiLayerNetworkInferenceExecutioner) created.getInferenceExecutioner();

    assertNotNull(executioner);
    assertNotNull(executioner.model());
    assertNotNull(executioner.modelLoader());
}
Example #8
Source File: TestUtils.java From deeplearning4j with Apache License 2.0 | 6 votes |
public static MultiLayerNetwork testModelSerialization(MultiLayerNetwork net){ MultiLayerNetwork restored; try { ByteArrayOutputStream baos = new ByteArrayOutputStream(); ModelSerializer.writeModel(net, baos, true); byte[] bytes = baos.toByteArray(); ByteArrayInputStream bais = new ByteArrayInputStream(bytes); restored = ModelSerializer.restoreMultiLayerNetwork(bais, true); assertEquals(net.getLayerWiseConfigurations(), restored.getLayerWiseConfigurations()); assertEquals(net.params(), restored.params()); } catch (IOException e){ //Should never happen throw new RuntimeException(e); } //Also check the MultiLayerConfiguration is serializable (required by Spark etc) MultiLayerConfiguration conf = net.getLayerWiseConfigurations(); serializeDeserializeJava(conf); return restored; }
Example #9
Source File: TestUtils.java From deeplearning4j with Apache License 2.0 | 6 votes |
public static ComputationGraph testModelSerialization(ComputationGraph net){ ComputationGraph restored; try { ByteArrayOutputStream baos = new ByteArrayOutputStream(); ModelSerializer.writeModel(net, baos, true); byte[] bytes = baos.toByteArray(); ByteArrayInputStream bais = new ByteArrayInputStream(bytes); restored = ModelSerializer.restoreComputationGraph(bais, true); assertEquals(net.getConfiguration(), restored.getConfiguration()); assertEquals(net.params(), restored.params()); } catch (IOException e){ //Should never happen throw new RuntimeException(e); } //Also check the ComputationGraphConfiguration is serializable (required by Spark etc) ComputationGraphConfiguration conf = net.getConfiguration(); serializeDeserializeJava(conf); return restored; }
Example #10
Source File: Vasttext.java From scava with Eclipse Public License 2.0 | 6 votes |
public void storeModel(File file) throws IOException { HashMap<String, Object> configuration = new HashMap<String, Object>(); //We do not store the updaters if(vasttextText!=null) { ModelSerializer.writeModel(vasttextText, file, false); configuration.put("typeVasttext", "onlyText"); } else if(vasttextTextAndNumeric!=null) { ModelSerializer.writeModel(vasttextTextAndNumeric, file, false); configuration.put("typeVasttext", "textAndNumeric"); } else throw new UnsupportedOperationException("Train before store model"); configuration.put("multiLabelActivation", multiLabelActivation); configuration.put("multiLabel", multiLabel); configuration.put("dictionary", vectorizer.getDictionary()); ModelSerializer.addObjectToFile(file, "vasttext.config", configuration); }
Example #11
Source File: TestUtils.java From deeplearning4j with Apache License 2.0 | 6 votes |
public static ComputationGraph testModelSerialization(ComputationGraph net){ ComputationGraph restored; try { ByteArrayOutputStream baos = new ByteArrayOutputStream(); ModelSerializer.writeModel(net, baos, true); byte[] bytes = baos.toByteArray(); ByteArrayInputStream bais = new ByteArrayInputStream(bytes); restored = ModelSerializer.restoreComputationGraph(bais, true); assertEquals(net.getConfiguration(), restored.getConfiguration()); assertEquals(net.params(), restored.params()); } catch (IOException e){ //Should never happen throw new RuntimeException(e); } //Also check the ComputationGraphConfiguration is serializable (required by Spark etc) ComputationGraphConfiguration conf = net.getConfiguration(); serializeDeserializeJava(conf); return restored; }
Example #12
Source File: MultiLayerNetwork.java From deeplearning4j with Apache License 2.0 | 5 votes |
/**
 * Java deserialization hook. Restores a network from the stream via {@link ModelSerializer}
 * and copies its configurations, flattened parameters and (when present) updater state
 * into this instance, re-initialising this instance in between.
 *
 * @param ois stream to read the serialized network from
 * @throws ClassNotFoundException declared for the serialization contract
 * @throws IOException if the network cannot be restored from the stream
 */
private void readObject(ObjectInputStream ois) throws ClassNotFoundException, IOException {
    val restored = ModelSerializer.restoreMultiLayerNetwork(ois, true);

    // Configurations must be in place before init(), which must run before params are copied
    this.defaultConfiguration = restored.defaultConfiguration.clone();
    this.layerWiseConfigurations = restored.layerWiseConfigurations.clone();
    this.init();
    this.flattenedParams.assign(restored.flattenedParams);

    int numWorkingMem = 2 * (layerWiseConfigurations.getConfs().size()
            + layerWiseConfigurations.getInputPreProcessors().size());
    WS_LAYER_WORKING_MEM_CONFIG = getLayerWorkingMemWSConfig(numWorkingMem);
    WS_LAYER_ACT_X_CONFIG = getLayerActivationWSConfig(layerWiseConfigurations.getConfs().size());

    boolean hasUpdaterState = restored.getUpdater() != null
            && restored.getUpdater(false).getStateViewArray() != null;
    if (hasUpdaterState) {
        this.getUpdater(true).getStateViewArray().assign(restored.getUpdater(false).getStateViewArray());
    }
}
Example #13
Source File: TestUtils.java From konduit-serving with Apache License 2.0 | 5 votes |
/**
 * Builds an {@link InferenceConfiguration} for a freshly trained network.
 * <p>
 * The network is written (without updater state) to {@code model.zip} inside
 * {@code trainDir}; the returned configuration holds a single {@link Dl4jStep} that points
 * at that file and declares four double input columns and three double output columns.
 *
 * @param trainDir temporary folder in which the model file is created
 * @return the inference configuration referencing the saved model
 * @throws Exception if training, file creation or model writing fails
 */
public static InferenceConfiguration getConfig(TemporaryFolder trainDir) throws Exception {
    Pair<MultiLayerNetwork, DataNormalization> trained = TrainUtils.getTrainedNetwork();
    File modelZip = trainDir.newFile("model.zip");
    ModelSerializer.writeModel(trained.getFirst(), modelZip, false);

    Schema.Builder inputColumns = new Schema.Builder();
    inputColumns.addColumnDouble("petal_length")
            .addColumnDouble("petal_width")
            .addColumnDouble("sepal_width")
            .addColumnDouble("sepal_height");
    Schema inputSchema = inputColumns.build();

    Schema.Builder outputColumns = new Schema.Builder();
    outputColumns.addColumnDouble("setosa");
    outputColumns.addColumnDouble("versicolor");
    outputColumns.addColumnDouble("virginica");
    Schema outputSchema = outputColumns.build();

    ServingConfig serving = ServingConfig.builder()
            .createLoggingEndpoints(true)
            .build();

    Dl4jStep step = Dl4jStep.builder()
            .inputName("default")
            .inputColumnName("default", SchemaTypeUtils.columnNames(inputSchema))
            .inputSchema("default", SchemaTypeUtils.typesForSchema(inputSchema))
            .outputSchema("default", SchemaTypeUtils.typesForSchema(outputSchema))
            .path(modelZip.getAbsolutePath())
            .outputColumnName("default", SchemaTypeUtils.columnNames(outputSchema))
            .build();

    return InferenceConfiguration.builder()
            .servingConfig(serving)
            .step(step)
            .build();
}
Example #14
Source File: ModelGuesser.java From konduit-serving with Apache License 2.0 | 5 votes |
/**
 * Load the model from the given file path.
 * <p>
 * The loaders are tried in this order, returning the first one that succeeds:
 * multi-layer network with updater, computation graph with updater, multi-layer network
 * without updater, computation graph without updater, Keras model, Keras sequential model.
 *
 * @param path the path of the file to "guess"
 * @return the loaded model
 * @throws Exception a {@code ModelGuesserException} if every model load attempt fails; the
 *                   failure of each individual attempt is attached as a suppressed exception
 */
public static Model loadModelGuess(String path) throws Exception {
    // Collected so the final exception can report why each attempt failed.
    // Fully qualified to avoid depending on imports that may not be present in the file.
    java.util.List<Exception> failures = new java.util.ArrayList<Exception>();

    try {
        return ModelSerializer.restoreMultiLayerNetwork(new File(path), true);
    } catch (Exception e) {
        log.warn("Tried multi layer network");
        failures.add(e);
    }

    try {
        return ModelSerializer.restoreComputationGraph(new File(path), true);
    } catch (Exception e) {
        log.warn("Tried computation graph");
        failures.add(e);
    }

    try {
        return ModelSerializer.restoreMultiLayerNetwork(new File(path), false);
    } catch (Exception e) {
        failures.add(e);
    }

    try {
        return ModelSerializer.restoreComputationGraph(new File(path), false);
    } catch (Exception e) {
        failures.add(e);
    }

    try {
        return KerasModelImport.importKerasModelAndWeights(path);
    } catch (Exception e) {
        log.warn("Tried multi layer network keras");
        failures.add(e);
    }

    try {
        return KerasModelImport.importKerasSequentialModelAndWeights(path);
    } catch (Exception e) {
        failures.add(e);
    }

    ModelGuesserException failure = new ModelGuesserException("Unable to load model from path " + path
            + " (invalid model file or not a known model type)");
    for (Exception attempt : failures) {
        failure.addSuppressed(attempt);
    }
    throw failure;
}
Example #15
Source File: OCNNOutputLayerTest.java From deeplearning4j with Apache License 2.0 | 5 votes |
@Test public void testLabelProbabilities() throws Exception { Nd4j.getRandom().setSeed(42); DataSetIterator dataSetIterator = getNormalizedIterator(); MultiLayerNetwork network = getSingleLayer(); DataSet next = dataSetIterator.next(); DataSet filtered = next.filterBy(new int[]{0, 1}); for (int i = 0; i < 10; i++) { network.setEpochCount(i); network.getLayerWiseConfigurations().setEpochCount(i); network.fit(filtered); } DataSet anomalies = next.filterBy(new int[] {2}); INDArray output = network.output(anomalies.getFeatures()); INDArray normalOutput = network.output(anomalies.getFeatures(),false); assertEquals(output.lt(0.0).castTo(Nd4j.defaultFloatingPointType()).sumNumber().doubleValue(), normalOutput.eq(0.0).castTo(Nd4j.defaultFloatingPointType()).sumNumber().doubleValue(),1e-1); // System.out.println("Labels " + anomalies.getLabels()); // System.out.println("Anomaly output " + normalOutput); // System.out.println(output); INDArray normalProbs = network.output(filtered.getFeatures()); INDArray outputForNormalSamples = network.output(filtered.getFeatures(),false); System.out.println("Normal probabilities " + normalProbs); System.out.println("Normal raw output " + outputForNormalSamples); File tmpFile = new File(testDir.getRoot(),"tmp-file-" + UUID.randomUUID().toString()); ModelSerializer.writeModel(network,tmpFile,true); tmpFile.deleteOnExit(); MultiLayerNetwork multiLayerNetwork = ModelSerializer.restoreMultiLayerNetwork(tmpFile); assertEquals(network.params(),multiLayerNetwork.params()); assertEquals(network.numParams(),multiLayerNetwork.numParams()); }
Example #16
Source File: TrainCifar10Model.java From Java-Machine-Learning-for-Computer-Vision with MIT License | 5 votes |
/**
 * Restores a computation graph from {@code MODEL_SAVE_PATH + preTrainedCifarModel}, stores
 * it in {@code cifar10Transfer} and logs its summary.
 *
 * @param preTrainedCifarModel file name appended to {@code MODEL_SAVE_PATH}
 * @throws IOException if the model cannot be restored
 */
public void loadTrainedModel(String preTrainedCifarModel) throws IOException {
    File modelFile = new File(MODEL_SAVE_PATH + preTrainedCifarModel);
    log.info("loading model " + modelFile);

    cifar10Transfer = ModelSerializer.restoreComputationGraph(modelFile);

    log.info(cifar10Transfer.summary());
}
Example #17
Source File: TransferLearningVGG16.java From Java-Machine-Learning-for-Computer-Vision with MIT License | 5 votes |
/**
 * Writes a checkpoint of the graph (without updater state) and evaluates it on the dev
 * iterator whenever {@code iIteration} is a non-zero multiple of {@code SAVING_INTERVAL}.
 *
 * @param vgg16Transfer the graph being trained
 * @param iEpoch current epoch, used in the checkpoint file name
 * @param iIteration current iteration, used for the interval check and the file name
 * @throws IOException if the checkpoint cannot be written
 */
private void saveProgressEveryConfiguredInterval(ComputationGraph vgg16Transfer, int iEpoch, int iIteration) throws IOException {
    if (iIteration % SAVING_INTERVAL != 0 || iIteration == 0) {
        return;
    }

    File checkpoint = new File(SAVING_PATH + iIteration + "_epoch_" + iEpoch + ".zip");
    ModelSerializer.writeModel(vgg16Transfer, checkpoint, false);
    evalOn(vgg16Transfer, neuralNetworkTrainingData.getDevIterator(), iIteration);
}
Example #18
Source File: Model.java From gluon-samples with BSD 3-Clause "New" or "Revised" License | 5 votes |
/**
 * Loads the bundled network from the classpath resource {@code /mymodel.zip} and publishes
 * it through {@code nnModel}.
 * <p>
 * This is best-effort: any failure is printed to standard error and {@code nnModel} is left
 * untouched.
 */
private void loadModelLocal() {
    System.out.println("******LOAD TRAINED MODEL (local)******");
    // try-with-resources closes the stream even when restoring the network fails
    try (InputStream is = Model.class.getResourceAsStream("/mymodel.zip")) {
        if (is == null) {
            // getResourceAsStream returns null for a missing resource; report that clearly
            throw new IllegalStateException("Classpath resource /mymodel.zip not found");
        }
        MultiLayerNetwork network = ModelSerializer.restoreMultiLayerNetwork(is);
        nnModel.set(network);
    } catch (Throwable t) {
        t.printStackTrace();
    }
}
Example #19
Source File: TestComputationGraphNetwork.java From deeplearning4j with Apache License 2.0 | 5 votes |
/**
 * Verifies that the configuration's epoch count starts at zero, grows by one with each
 * {@code fit} call on the iterator, and is preserved by a {@link ModelSerializer}
 * write/restore cycle.
 */
@Test
public void testEpochCounter() throws Exception {
    ComputationGraphConfiguration graphConf = new NeuralNetConfiguration.Builder()
            .graphBuilder()
            .addInputs("in")
            .addLayer("out", new OutputLayer.Builder().nIn(4).nOut(3).activation(Activation.SOFTMAX).build(), "in")
            .setOutputs("out")
            .build();

    ComputationGraph graph = new ComputationGraph(graphConf);
    graph.init();
    assertEquals(0, graph.getConfiguration().getEpochCount());

    DataSetIterator iris = new IrisDataSetIterator(150, 150);
    for (int epoch = 0; epoch < 4; epoch++) {
        assertEquals(epoch, graph.getConfiguration().getEpochCount());
        graph.fit(iris);
        assertEquals(epoch + 1, graph.getConfiguration().getEpochCount());
    }
    assertEquals(4, graph.getConfiguration().getEpochCount());

    ByteArrayOutputStream sink = new ByteArrayOutputStream();
    ModelSerializer.writeModel(graph, sink, true);
    ByteArrayInputStream source = new ByteArrayInputStream(sink.toByteArray());
    ComputationGraph reloaded = ModelSerializer.restoreComputationGraph(source, true);

    assertEquals(4, reloaded.getConfiguration().getEpochCount());
}
Example #20
Source File: DeepAutoEncoderExample.java From Java-for-Data-Science with MIT License | 5 votes |
/**
 * Points {@code modelFile} at {@code savedModel} and restores a multi-layer network from it.
 * <p>
 * The restored network is only held in a local variable, so it is discarded when the method
 * returns. An {@link IOException} is printed rather than propagated.
 */
public void retrieveModel() {
    modelFile = new File("savedModel");
    try {
        MultiLayerNetwork restored = ModelSerializer.restoreMultiLayerNetwork(modelFile);
    } catch (IOException ioFailure) {
        ioFailure.printStackTrace();
    }
}
Example #21
Source File: ModelSavingCallback.java From deeplearning4j with Apache License 2.0 | 5 votes |
/**
 * Writes the given model, together with its updater, to the named file.
 *
 * @param model the model to persist
 * @param filename path of the file to write
 * @throws RuntimeException wrapping the {@link IOException} if writing fails
 */
protected void save(Model model, String filename) {
    final boolean saveUpdater = true;
    try {
        ModelSerializer.writeModel(model, filename, saveUpdater);
    } catch (IOException ioFailure) {
        throw new RuntimeException(ioFailure);
    }
}
Example #22
Source File: ParallelInferenceTest.java From deeplearning4j with Apache License 2.0 | 5 votes |
/**
 * Loads the LeNet MNIST model and creates the MNIST iterator the first time it runs; later
 * invocations are no-ops once {@code model} has been set.
 */
@Before
public void setUp() throws Exception {
    if (model != null) {
        return;
    }

    File modelZip = Resources.asFile("models/LenetMnistMLN.zip");
    model = ModelSerializer.restoreMultiLayerNetwork(modelZip, true);
    iterator = new MnistDataSetIterator(1, false, 12345);
}
Example #23
Source File: ModelGuesser.java From deeplearning4j with Apache License 2.0 | 5 votes |
/**
 * Convenience wrapper around {@link ModelSerializer#restoreNormalizerFromFile(File)} that
 * accepts a path string and reports I/O failures unchecked.
 *
 * @param path the path to the file
 * @return the loaded normalizer
 * @throws RuntimeException wrapping the {@link IOException} if the file cannot be read
 */
public static Normalizer<?> loadNormalizer(String path) {
    File normalizerFile = new File(path);
    try {
        return ModelSerializer.restoreNormalizerFromFile(normalizerFile);
    } catch (IOException ioFailure){
        throw new RuntimeException(ioFailure);
    }
}
Example #24
Source File: Classifier.java From java-ml-projects with Apache License 2.0 | 5 votes |
/**
 * Initialises the classifier from {@code Properties}: reads the labels, builds the image
 * loader from the three-element input format, then restores and initialises the model.
 *
 * @throws IOException if the model cannot be restored from the configured path
 */
public static void init() throws IOException {
    String classifierPath = Properties.classifierModelPath();
    labels = Properties.classifierLabels();

    int[] inputFormat = Properties.classifierInputFormat();
    loader = new NativeImageLoader(inputFormat[0], inputFormat[1], inputFormat[2]);

    model = ModelSerializer.restoreComputationGraph(classifierPath);
    model.init();
}
Example #25
Source File: RegressionTest071.java From deeplearning4j with Apache License 2.0 | 5 votes |
/**
 * Regression test: restores the stored 0.7.1 LSTM model and checks the configuration of its
 * three layers (GravesLSTM, GravesBidirectionalLSTM, RnnOutputLayer).
 */
@Test
public void regressionTestLSTM1() throws Exception {
    File modelZip = Resources.asFile("regression_testing/071/071_ModelSerializer_Regression_LSTM_1.zip");
    MultiLayerNetwork net = ModelSerializer.restoreMultiLayerNetwork(modelZip, true);

    MultiLayerConfiguration conf = net.getLayerWiseConfigurations();
    assertEquals(3, conf.getConfs().size());

    // Layer 0: GravesLSTM
    GravesLSTM lstm = (GravesLSTM) conf.getConf(0).getLayer();
    assertEquals("tanh", lstm.getActivationFn().toString());
    assertEquals(3, lstm.getNIn());
    assertEquals(4, lstm.getNOut());
    assertEquals(GradientNormalization.ClipElementWiseAbsoluteValue, lstm.getGradientNormalization());
    assertEquals(1.5, lstm.getGradientNormalizationThreshold(), 1e-5);

    // Layer 1: GravesBidirectionalLSTM
    GravesBidirectionalLSTM biLstm = (GravesBidirectionalLSTM) conf.getConf(1).getLayer();
    assertEquals("softsign", biLstm.getActivationFn().toString());
    assertEquals(4, biLstm.getNIn());
    assertEquals(4, biLstm.getNOut());
    assertEquals(GradientNormalization.ClipElementWiseAbsoluteValue, biLstm.getGradientNormalization());
    assertEquals(1.5, biLstm.getGradientNormalizationThreshold(), 1e-5);

    // Layer 2: RnnOutputLayer
    RnnOutputLayer rnnOut = (RnnOutputLayer) conf.getConf(2).getLayer();
    assertEquals(4, rnnOut.getNIn());
    assertEquals(5, rnnOut.getNOut());
    assertEquals("softmax", rnnOut.getActivationFn().toString());
    assertTrue(rnnOut.getLossFn() instanceof LossMCXENT);
}
Example #26
Source File: RegressionTest071.java From deeplearning4j with Apache License 2.0 | 5 votes |
/**
 * Regression test: restores the stored 0.7.1 MLP model and checks the configuration of its
 * two layers as well as the restored parameter and updater-state arrays.
 */
@Test
public void regressionTestMLP1() throws Exception {
    File modelZip = Resources.asFile("regression_testing/071/071_ModelSerializer_Regression_MLP_1.zip");
    MultiLayerNetwork net = ModelSerializer.restoreMultiLayerNetwork(modelZip, true);

    MultiLayerConfiguration conf = net.getLayerWiseConfigurations();
    assertEquals(2, conf.getConfs().size());

    // Layer 0: dense layer
    DenseLayer dense = (DenseLayer) conf.getConf(0).getLayer();
    assertEquals("relu", dense.getActivationFn().toString());
    assertEquals(3, dense.getNIn());
    assertEquals(4, dense.getNOut());
    assertEquals(new WeightInitXavier(), dense.getWeightInitFn());
    assertEquals(new Nesterovs(0.15, 0.9), dense.getIUpdater());
    assertEquals(0.15, ((Nesterovs)dense.getIUpdater()).getLearningRate(), 1e-6);

    // Layer 1: output layer
    OutputLayer out = (OutputLayer) conf.getConf(1).getLayer();
    assertEquals("softmax", out.getActivationFn().toString());
    assertTrue(out.getLossFn() instanceof LossMCXENT);
    assertEquals(4, out.getNIn());
    assertEquals(5, out.getNOut());
    assertEquals(new WeightInitXavier(), out.getWeightInitFn());
    assertEquals(0.9, ((Nesterovs)out.getIUpdater()).getMomentum(), 1e-6);
    assertEquals(0.9, ((Nesterovs)out.getIUpdater()).getMomentum(), 1e-6);
    assertEquals(0.15, ((Nesterovs)out.getIUpdater()).getLearningRate(), 1e-6);

    // Parameters are expected to be the sequence 1..numParams
    long numParams = (int)net.numParams();
    assertEquals(Nd4j.linspace(1, numParams, numParams).reshape(1,numParams), net.params());

    // Updater state is expected to be the sequence 1..updaterSize
    int updaterSize = (int) new Nesterovs().stateSize(numParams);
    assertEquals(Nd4j.linspace(1, updaterSize, updaterSize).reshape(1,numParams), net.getUpdater().getStateViewArray());
}
Example #27
Source File: SaveLoadMultiLayerNetwork.java From dl4j-tutorials with MIT License | 5 votes |
public static void main(String[] args) throws Exception { //Define a simple MultiLayerNetwork: MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() .weightInit(WeightInit.XAVIER) .updater(new Nesterovs(0.01, 0.9)) .list() .layer(0, new DenseLayer.Builder().nIn(4).nOut(3).activation(Activation.TANH).build()) .layer(1, new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD).activation(Activation.SOFTMAX).nIn(3).nOut(3).build()) .backprop(true).pretrain(false).build(); MultiLayerNetwork net = new MultiLayerNetwork(conf); net.init(); //Save the model File locationToSave = new File("model/MyMultiLayerNetwork.zip"); //Where to save the network. Note: the file is in .zip format - can be opened externally /** * 主要是用于保存模型的更新器信息 * 如果模型保存之后还打算继续训练,则进行保存 -> true 才能根据后面的数据进行增量更新 * 如果不打算继续训练 -> 模型定型之后,false */ boolean saveUpdater = true; //Updater: i.e., the state for Momentum, RMSProp, Adagrad etc. Save this if you want to train your network more in the future ModelSerializer.writeModel(net, locationToSave, saveUpdater); //Load the model MultiLayerNetwork restored = ModelSerializer.restoreMultiLayerNetwork(locationToSave); System.out.println("Saved and loaded parameters are equal: " + net.params().equals(restored.params())); System.out.println("Saved and loaded configurations are equal: " + net.getLayerWiseConfigurations().equals(restored.getLayerWiseConfigurations())); }
Example #28
Source File: CheckpointListener.java From deeplearning4j with Apache License 2.0 | 5 votes |
/**
 * Restores the {@link MultiLayerNetwork} stored for a given checkpoint number.
 *
 * @param rootDir       The directory that the checkpoint resides in
 * @param checkpointNum Checkpoint model to load
 * @return The loaded model
 * @throws RuntimeException wrapping the {@link IOException} if the checkpoint cannot be read
 */
public static MultiLayerNetwork loadCheckpointMLN(File rootDir, int checkpointNum){
    File checkpointFile = getFileForCheckpoint(rootDir, checkpointNum);
    MultiLayerNetwork loaded;
    try {
        loaded = ModelSerializer.restoreMultiLayerNetwork(checkpointFile, true);
    } catch (IOException ioFailure){
        throw new RuntimeException(ioFailure);
    }
    return loaded;
}
Example #29
Source File: Gan11Exemple.java From dl4j-tutorials with MIT License | 5 votes |
/**
 * Writes {@code net}, including its updater, to the file named {@code modelName}.
 * An {@link IOException} is printed to standard error instead of being propagated.
 *
 * @param modelName path of the file to write the model to
 */
public void saveModel(String modelName) {
    final boolean saveUpdater = true;
    try {
        ModelSerializer.writeModel(net, modelName, saveUpdater);
    } catch (IOException ioFailure) {
        ioFailure.printStackTrace();
    }
}
Example #30
Source File: Main.java From gluon-samples with BSD 3-Clause "New" or "Revised" License | 5 votes |
/**
 * Restores the model from {@code savedModelLocation} when that file exists, otherwise
 * creates, trains and saves a new one. In both cases the model is evaluated, run through
 * the tests, evaluated again, corrected with a sample image and evaluated a final time.
 *
 * @param args unused
 * @throws Exception if restoring, training, saving or testing the model fails
 */
public static void main(String[] args) throws Exception {
    MultiLayerNetwork model = null;
    File savedModel = new File(savedModelLocation);

    if (!savedModel.exists()) {
        LOGGER.info("Create model");
        model = utils.createModel();
        LOGGER.info("Train model");
        utils.trainModel(model, true, null, -1);
        LOGGER.info("Save model");
        utils.saveModel(model, savedModelLocation);
        LOGGER.info("Eval model");
    } else {
        LOGGER.info("Model exists, restore it");
        model = ModelSerializer.restoreMultiLayerNetwork(savedModelLocation);
    }
    // Both paths finish with an evaluation of the model they produced
    utils.evaluateModel(model);

    LOGGER.info("Run tests");
    runTests(model);

    LOGGER.info("Evaluate model after tests");
    utils.evaluateModel(model);

    LOGGER.info("Correct Image");
    correctImage(model, Main.class.getResourceAsStream("/mytestdata/3b.png"), 3);
    utils.evaluateModel(model);
}