Java Code Examples for org.deeplearning4j.util.ModelSerializer#restoreComputationGraph()
The following examples show how to use
org.deeplearning4j.util.ModelSerializer#restoreComputationGraph().
You can vote up the ones you like or vote down the ones you don't like,
and go to the original project or source file by following the links above each example. You may check out the related API usage on the sidebar.
Example 1
Source File: Vasttext.java From scava with Eclipse Public License 2.0 | 6 votes |
/**
 * Restores a Vasttext model from {@code file}.
 *
 * <p>Reads the "vasttext.config" object stored in the model zip, rebuilds the text vectorizer and
 * label list from it, then restores either a ComputationGraph ("textAndNumeric") or a
 * MultiLayerNetwork ("onlyText") from the same file.
 *
 * @param file the saved model zip
 * @throws UnsupportedOperationException if the stored model type is missing or not recognised
 * @throws FileNotFoundException if the file does not exist
 * @throws ClassNotFoundException if the stored configuration cannot be deserialized
 * @throws IOException if the file cannot be read
 */
@SuppressWarnings("unchecked")
public void loadModel(File file) throws FileNotFoundException, ClassNotFoundException, IOException {
    HashMap<String, Object> configuration =
            (HashMap<String, Object>) ModelSerializer.getObjectFromFile(file, "vasttext.config");
    multiLabel = (Boolean) configuration.get("multiLabel");

    vectorizer = new VasttextTextVectorizer();
    vectorizer.loadDictionary(configuration.get("dictionary"));
    labels = vectorizer.getLabels();
    labelsSize = labels.size();

    typeVasttext = (String) configuration.get("typeVasttext");
    multiLabelActivation = (Double) configuration.get("multiLabelActivation");

    // Constant-first comparison: a config without a "typeVasttext" entry now reaches the
    // UnsupportedOperationException below instead of failing with a NullPointerException.
    if ("textAndNumeric".equalsIgnoreCase(typeVasttext)) {
        vasttextTextAndNumeric = ModelSerializer.restoreComputationGraph(file);
    } else if ("onlyText".equalsIgnoreCase(typeVasttext)) {
        vasttextText = ModelSerializer.restoreMultiLayerNetwork(file);
    } else {
        throw new UnsupportedOperationException("Unknown type of model.");
    }
}
Example 2
Source File: TestUtils.java From deeplearning4j with Apache License 2.0 | 6 votes |
public static ComputationGraph testModelSerialization(ComputationGraph net){ ComputationGraph restored; try { ByteArrayOutputStream baos = new ByteArrayOutputStream(); ModelSerializer.writeModel(net, baos, true); byte[] bytes = baos.toByteArray(); ByteArrayInputStream bais = new ByteArrayInputStream(bytes); restored = ModelSerializer.restoreComputationGraph(bais, true); assertEquals(net.getConfiguration(), restored.getConfiguration()); assertEquals(net.params(), restored.params()); } catch (IOException e){ //Should never happen throw new RuntimeException(e); } //Also check the ComputationGraphConfiguration is serializable (required by Spark etc) ComputationGraphConfiguration conf = net.getConfiguration(); serializeDeserializeJava(conf); return restored; }
Example 3
Source File: DLModel.java From java-ml-projects with Apache License 2.0 | 6 votes |
/**
 * Loads a saved DL4J model from {@code file}. The file is tried as a ComputationGraph first and,
 * if that fails, as a MultiLayerNetwork.
 *
 * @param file the saved model zip
 * @return the loaded model wrapped in a {@code DLModel}
 * @throws Exception the ComputationGraph load failure if both attempts fail; the
 *                   MultiLayerNetwork failure is attached to it as a suppressed exception
 */
public static DLModel fromFile(File file) throws Exception {
    Model model;
    try {
        System.out.println("Trying to load file as computation graph: " + file);
        model = ModelSerializer.restoreComputationGraph(file);
        System.out.println("Loaded Computation Graph.");
    } catch (Exception graphFailure) {
        try {
            System.out.println("Failed to load computation graph. Trying to load model.");
            model = ModelSerializer.restoreMultiLayerNetwork(file);
            System.out.println("Loaded Multilayernetwork");
        } catch (Exception networkFailure) {
            System.out.println("Give up trying to load file: " + file);
            // Previously the second failure was discarded; keep it for diagnosis.
            graphFailure.addSuppressed(networkFailure);
            throw graphFailure;
        }
    }
    return new DLModel(model);
}
Example 4
Source File: TestUtils.java From deeplearning4j with Apache License 2.0 | 6 votes |
public static ComputationGraph testModelSerialization(ComputationGraph net){ ComputationGraph restored; try { ByteArrayOutputStream baos = new ByteArrayOutputStream(); ModelSerializer.writeModel(net, baos, true); byte[] bytes = baos.toByteArray(); ByteArrayInputStream bais = new ByteArrayInputStream(bytes); restored = ModelSerializer.restoreComputationGraph(bais, true); assertEquals(net.getConfiguration(), restored.getConfiguration()); assertEquals(net.params(), restored.params()); } catch (IOException e){ //Should never happen throw new RuntimeException(e); } //Also check the ComputationGraphConfiguration is serializable (required by Spark etc) ComputationGraphConfiguration conf = net.getConfiguration(); serializeDeserializeJava(conf); return restored; }
Example 5
Source File: Vgg16DeepLearning4jClassifier.java From vision4j-collection with MIT License | 5 votes |
/**
 * Restores the VGG16 graph from {@code computationGraph} and prepares the 224x224x3 image loader,
 * the VGG16 preprocessor and the ImageNet category list.
 *
 * @param computationGraph the saved VGG16 model zip
 * @throws IOException if the model cannot be read
 */
private void init(File computationGraph) throws IOException {
    this.vgg16 = ModelSerializer.restoreComputationGraph(computationGraph);
    this.scaler = new VGG16ImagePreProcessor();
    this.imageSize = new ImageSize(224, 224, 3);
    this.imageLoader = new NativeImageLoader(imageSize.getHeight(), imageSize.getWidth(), imageSize.channels());

    // NOTE(review): the returned label list is never used here; the call is kept only so
    // behaviour (including any loading it triggers) is unchanged.
    ImageNetLabels.getLabels();

    String[] names = Constants.IMAGENET_CATEGORIES;
    ArrayList<Category> categoryList = new ArrayList<>(names.length);
    for (int idx = 0; idx < names.length; idx++) {
        categoryList.add(new Category(names[idx], idx));
    }
    this.categories = new Categories(categoryList);
}
Example 6
Source File: RegressionTest060.java From deeplearning4j with Apache License 2.0 | 5 votes |
/**
 * Regression test: restores a stored ComputationGraph zip from the "060" regression resources
 * (presumably written by DL4J 0.6.0, per the path). It checks that the three LSTM/output
 * vertices come back with the expected activations, sizes, gradient clipping and loss.
 */
@Test
public void regressionTestCGLSTM1() throws Exception {
    File modelZip = Resources.asFile("regression_testing/060/060_ModelSerializer_Regression_CG_LSTM_1.zip");
    // true: restore the updater as well
    ComputationGraph restored = ModelSerializer.restoreComputationGraph(modelZip, true);
    ComputationGraphConfiguration graphConf = restored.getConfiguration();
    assertEquals(3, graphConf.getVertices().size());

    // Vertex "0": GravesLSTM, tanh, 3 -> 4, element-wise clipping at 1.5
    GravesLSTM first = (GravesLSTM) ((LayerVertex) graphConf.getVertices().get("0")).getLayerConf().getLayer();
    assertEquals("tanh", first.getActivationFn().toString());
    assertEquals(3, first.getNIn());
    assertEquals(4, first.getNOut());
    assertEquals(GradientNormalization.ClipElementWiseAbsoluteValue, first.getGradientNormalization());
    assertEquals(1.5, first.getGradientNormalizationThreshold(), 1e-5);

    // Vertex "1": GravesBidirectionalLSTM, softsign, 4 -> 4, same clipping
    GravesBidirectionalLSTM second =
            (GravesBidirectionalLSTM) ((LayerVertex) graphConf.getVertices().get("1")).getLayerConf().getLayer();
    assertEquals("softsign", second.getActivationFn().toString());
    assertEquals(4, second.getNIn());
    assertEquals(4, second.getNOut());
    assertEquals(GradientNormalization.ClipElementWiseAbsoluteValue, second.getGradientNormalization());
    assertEquals(1.5, second.getGradientNormalizationThreshold(), 1e-5);

    // Vertex "2": RnnOutputLayer, 4 -> 5, softmax with MCXENT loss
    RnnOutputLayer output = (RnnOutputLayer) ((LayerVertex) graphConf.getVertices().get("2")).getLayerConf().getLayer();
    assertEquals(4, output.getNIn());
    assertEquals(5, output.getNOut());
    assertEquals("softmax", output.getActivationFn().toString());
    assertTrue(output.getLossFn() instanceof LossMCXENT);
}
Example 7
Source File: Classifier.java From java-ml-projects with Apache License 2.0 | 5 votes |
/**
 * Reads the classifier settings (model path, labels, input format), builds the image loader for
 * that input format and restores the classifier ComputationGraph.
 *
 * @throws IOException if the model cannot be read
 */
public static void init() throws IOException {
    final String modelPath = Properties.classifierModelPath();
    labels = Properties.classifierLabels();

    final int[] inputFormat = Properties.classifierInputFormat();
    final int height = inputFormat[0];
    final int width = inputFormat[1];
    final int channels = inputFormat[2];
    loader = new NativeImageLoader(height, width, channels);

    model = ModelSerializer.restoreComputationGraph(modelPath);
    model.init();
}
Example 8
Source File: RegressionTest080.java From deeplearning4j with Apache License 2.0 | 5 votes |
/**
 * Regression test: restores a stored ComputationGraph zip from the "080" regression resources
 * (presumably written by DL4J 0.8.0, per the path). It checks that the three LSTM/output
 * vertices come back with the expected activation classes, sizes, gradient clipping and loss.
 */
@Test
public void regressionTestCGLSTM1() throws Exception {
    File modelZip = Resources.asFile("regression_testing/080/080_ModelSerializer_Regression_CG_LSTM_1.zip");
    // true: restore the updater as well
    ComputationGraph restored = ModelSerializer.restoreComputationGraph(modelZip, true);
    ComputationGraphConfiguration graphConf = restored.getConfiguration();
    assertEquals(3, graphConf.getVertices().size());

    // Vertex "0": GravesLSTM, tanh, 3 -> 4, element-wise clipping at 1.5
    GravesLSTM first = (GravesLSTM) ((LayerVertex) graphConf.getVertices().get("0")).getLayerConf().getLayer();
    assertTrue(first.getActivationFn() instanceof ActivationTanH);
    assertEquals(3, first.getNIn());
    assertEquals(4, first.getNOut());
    assertEquals(GradientNormalization.ClipElementWiseAbsoluteValue, first.getGradientNormalization());
    assertEquals(1.5, first.getGradientNormalizationThreshold(), 1e-5);

    // Vertex "1": GravesBidirectionalLSTM, softsign, 4 -> 4, same clipping
    GravesBidirectionalLSTM second =
            (GravesBidirectionalLSTM) ((LayerVertex) graphConf.getVertices().get("1")).getLayerConf().getLayer();
    assertTrue(second.getActivationFn() instanceof ActivationSoftSign);
    assertEquals(4, second.getNIn());
    assertEquals(4, second.getNOut());
    assertEquals(GradientNormalization.ClipElementWiseAbsoluteValue, second.getGradientNormalization());
    assertEquals(1.5, second.getGradientNormalizationThreshold(), 1e-5);

    // Vertex "2": RnnOutputLayer, 4 -> 5, softmax with MCXENT loss
    RnnOutputLayer output = (RnnOutputLayer) ((LayerVertex) graphConf.getVertices().get("2")).getLayerConf().getLayer();
    assertEquals(4, output.getNIn());
    assertEquals(5, output.getNOut());
    assertTrue(output.getActivationFn() instanceof ActivationSoftmax);
    assertTrue(output.getLossFn() instanceof LossMCXENT);
}
Example 9
Source File: LocalFileNetResultReference.java From deeplearning4j with Apache License 2.0 | 5 votes |
/**
 * Restores the saved result model from {@code modelFile}, without its updater. It is restored as a
 * ComputationGraph when {@code isGraph} is set, otherwise as a MultiLayerNetwork.
 *
 * @return the restored model
 * @throws IOException if the model file cannot be read
 */
@Override
public Object getResultModel() throws IOException {
    if (!isGraph) {
        return ModelSerializer.restoreMultiLayerNetwork(modelFile, false);
    }
    return ModelSerializer.restoreComputationGraph(modelFile, false);
}
Example 10
Source File: TrainUtil.java From FancyBing with GNU General Public License v3.0 | 5 votes |
/**
 * Restores the ComputationGraph stored at {@code <user.dir>/model/<fn>}. It then sets the
 * per-parameter learning rate of the "W" and "b" parameters on every layer to
 * {@code learningRate}.
 *
 * @param fn           model file name inside the model directory
 * @param learningRate learning rate to apply to the "W" and "b" parameters
 * @return the restored, re-configured graph
 * @throws Exception if the model cannot be restored
 */
public static ComputationGraph loadComputationGraph(String fn, double learningRate) throws Exception {
    System.err.println("Loading model...");
    String modelDir = System.getProperty("user.dir") + "/model/";
    ComputationGraph graph = ModelSerializer.restoreComputationGraph(new File(modelDir + fn));

    final int layerCount = graph.getNumLayers();
    int idx = 0;
    while (idx < layerCount) {
        graph.getLayer(idx).conf().setLearningRateByParam("W", learningRate);
        graph.getLayer(idx).conf().setLearningRateByParam("b", learningRate);
        idx++;
    }
    return graph;
}
Example 11
Source File: PolicyNetService.java From FancyBing with GNU General Public License v3.0 | 5 votes |
/**
 * Restores the ComputationGraph stored at {@code <user.dir>/model/<fn>}, printing the resolved
 * path first.
 *
 * @param fn model file name inside the model directory
 * @return the restored graph
 * @throws Exception if the model cannot be restored
 */
private static ComputationGraph loadComputationGraph(String fn) throws Exception {
    String modelPath = System.getProperty("user.dir") + "/model/" + fn;
    File modelFile = new File(modelPath);
    System.out.println("Loading model " + modelFile);
    return ModelSerializer.restoreComputationGraph(modelFile);
}
Example 12
Source File: PolicyNetUtil.java From FancyBing with GNU General Public License v3.0 | 5 votes |
public static ComputationGraph loadComputationGraph(String fn) throws Exception { System.err.println("Loading model..."); // File locationToSave = new File(System.getProperty("user.dir") + "/model/" + fn); File locationToSave = new File("D:\\workspace\\fancybing-train\\model\\" + fn); ComputationGraph model = ModelSerializer.restoreComputationGraph(locationToSave); return model; }
Example 13
Source File: ModelGuesser.java From konduit-serving with Apache License 2.0 | 5 votes |
/**
 * Load the model from the given file path by trying each supported format in turn. The order is:
 * DL4J MultiLayerNetwork, then ComputationGraph (both with updater state); then the same two
 * without updater state; then a Keras functional-API model; finally a Keras Sequential model.
 *
 * @param path the path of the file to "guess"
 * @return the first model that loads successfully
 * @throws Exception if every model load attempt fails; each attempt's failure is attached to the
 *                   thrown ModelGuesserException as a suppressed exception
 */
public static Model loadModelGuess(String path) throws Exception {
    // Insertion order is the attempt order (unchanged from the original nested try/catch chain).
    java.util.Map<String, java.util.concurrent.Callable<Model>> attempts = new java.util.LinkedHashMap<>();
    attempts.put("multi layer network", () -> ModelSerializer.restoreMultiLayerNetwork(new File(path), true));
    attempts.put("computation graph", () -> ModelSerializer.restoreComputationGraph(new File(path), true));
    attempts.put("multi layer network without updater",
            () -> ModelSerializer.restoreMultiLayerNetwork(new File(path), false));
    attempts.put("computation graph without updater",
            () -> ModelSerializer.restoreComputationGraph(new File(path), false));
    attempts.put("keras functional model", () -> KerasModelImport.importKerasModelAndWeights(path));
    attempts.put("keras sequential model", () -> KerasModelImport.importKerasSequentialModelAndWeights(path));

    java.util.List<Exception> failures = new java.util.ArrayList<>();
    for (java.util.Map.Entry<String, java.util.concurrent.Callable<Model>> attempt : attempts.entrySet()) {
        try {
            return attempt.getValue().call();
        } catch (Exception e) {
            log.warn("Tried " + attempt.getKey() + " for " + path + ": " + e.getMessage());
            failures.add(e);
        }
    }

    // Previously every cause was dropped; keep them all so the real reason is diagnosable.
    ModelGuesserException guessFailure = new ModelGuesserException("Unable to load model from path " + path
            + " (invalid model file or not a known model type)");
    for (Exception failure : failures) {
        guessFailure.addSuppressed(failure);
    }
    throw guessFailure;
}
Example 14
Source File: CatVsDogRecognition.java From Java-Machine-Learning-for-Computer-Vision with MIT License | 4 votes |
/**
 * Restores the trained ComputationGraph from {@code TRAINED_PATH_MODEL} and caches it in the
 * {@code computationGraph} field.
 *
 * @return the restored graph (also stored in {@code computationGraph})
 * @throws IOException if the model file cannot be read
 */
public ComputationGraph loadModel() throws IOException {
    File trainedModel = new File(TRAINED_PATH_MODEL);
    computationGraph = ModelSerializer.restoreComputationGraph(trainedModel);
    return computationGraph;
}
Example 15
Source File: BidirectionalTest.java From deeplearning4j with Apache License 2.0 | 4 votes |
/**
 * For every WorkspaceMode, builds a ComputationGraph from two Bidirectional(ADD) GravesLSTM layers
 * and an MSE RnnOutputLayer, fits it once, and round-trips it through an in-memory
 * ModelSerializer zip. It then asserts that the original and restored networks give equal
 * outputs, scores and gradients on fresh random data. Input/label shapes follow the test's
 * rnnDataFormat field.
 */
@Test public void testSerializationCompGraph() throws Exception {
    for(WorkspaceMode wsm : WorkspaceMode.values()) {
        log.info("*** Starting workspace mode: " + wsm);
        // Fixed seed: parameter init and the random inputs below depend on this call order
        Nd4j.getRandom().setSeed(12345);
        ComputationGraphConfiguration conf1 = new NeuralNetConfiguration.Builder()
                .activation(Activation.TANH)
                .weightInit(WeightInit.XAVIER)
                .trainingWorkspaceMode(wsm)
                .inferenceWorkspaceMode(wsm)
                .updater(new Adam())
                .graphBuilder()
                .addInputs("in")
                .layer("0", new Bidirectional(Bidirectional.Mode.ADD, new GravesLSTM.Builder().nIn(10).nOut(10).dataFormat(rnnDataFormat).build()), "in")
                .layer("1", new Bidirectional(Bidirectional.Mode.ADD, new GravesLSTM.Builder().nIn(10).nOut(10).dataFormat(rnnDataFormat).build()), "0")
                .layer("2", new RnnOutputLayer.Builder().lossFunction(LossFunctions.LossFunction.MSE).dataFormat(rnnDataFormat)
                        .nIn(10).nOut(10).build(), "1")
                .setOutputs("2")
                .build();

        ComputationGraph net1 = new ComputationGraph(conf1);
        net1.init();

        // The size-10 axis matches nIn/nOut; its position depends on rnnDataFormat
        long[] inshape = (rnnDataFormat == NCW)? new long[]{3, 10, 5}: new long[]{3, 5, 10};
        INDArray in = Nd4j.rand(inshape);
        INDArray labels = Nd4j.rand(inshape);
        // One training step before saving (presumably so params/updater state are non-trivial)
        net1.fit(new DataSet(in, labels));

        // Serialize to memory (true: include updater) and restore a second network from the bytes
        byte[] bytes;
        try (ByteArrayOutputStream baos = new ByteArrayOutputStream()) {
            ModelSerializer.writeModel(net1, baos, true);
            bytes = baos.toByteArray();
        }

        ComputationGraph net2 = ModelSerializer.restoreComputationGraph(new ByteArrayInputStream(bytes), true);

        // Fresh data: both networks must produce identical outputs
        in = Nd4j.rand(inshape);
        labels = Nd4j.rand(inshape);

        INDArray out1 = net1.outputSingle(in);
        INDArray out2 = net2.outputSingle(in);

        assertEquals(out1, out2);

        // ...and identical scores and gradients for the same input/labels
        net1.setInput(0, in);
        net2.setInput(0, in);
        net1.setLabels(labels);
        net2.setLabels(labels);

        net1.computeGradientAndScore();
        net2.computeGradientAndScore();

        assertEquals(net1.score(), net2.score(), 1e-6);
        assertEquals(net1.gradient().gradient(), net2.gradient().gradient());
    }
}
Example 16
Source File: KerasZooModel.java From wekaDeeplearning4j with GNU General Public License v3.0 | 4 votes |
@Override public ComputationGraph initPretrained(PretrainedType pretrainedType) throws IOException { String remoteUrl = pretrainedUrl(pretrainedType); if (remoteUrl == null) throw new UnsupportedOperationException( "Pretrained " + pretrainedType + " weights are not available for this model."); // Set up file locations String localFilename = modelPrettyName() + ".zip"; File rootCacheDir = DL4JResources.getDirectory(ResourceType.ZOO_MODEL, modelFamily()); File cachedFile = new File(rootCacheDir, localFilename); // Download the file if necessary if (!cachedFile.exists()) { log.info("Downloading model to " + cachedFile.toString()); FileUtils.copyURLToFile(new URL(remoteUrl), cachedFile); } else { log.info("Using cached model at " + cachedFile.toString()); } // Validate the checksum - ensure this is the correct file long expectedChecksum = pretrainedChecksum(pretrainedType); if (expectedChecksum != 0L) { log.info("Verifying download..."); Checksum adler = new Adler32(); FileUtils.checksum(cachedFile, adler); long localChecksum = adler.getValue(); log.info("Checksum local is " + localChecksum + ", expecting " + expectedChecksum); if (expectedChecksum != localChecksum) { log.error("Checksums do not match. Cleaning up files and failing..."); cachedFile.delete(); throw new IllegalStateException( String.format("Pretrained model file for model %s failed checksum.", this.modelPrettyName())); } } // Load the .zip file to a ComputationGraph try { return ModelSerializer.restoreComputationGraph(cachedFile); } catch (Exception ex) { System.err.println("Failed to load model"); ex.printStackTrace(); return null; } }
Example 17
Source File: Dl4jMlpClassifier.java From wekaDeeplearning4j with GNU General Public License v3.0 | 4 votes |
/** * Custom deserialization method * * @param ois the object input stream */ private void readObject(ObjectInputStream ois) throws ClassNotFoundException, IOException { ClassLoader origLoader = Thread.currentThread().getContextClassLoader(); try { Thread.currentThread().setContextClassLoader( this.getClass().getClassLoader()); // default deserialization ois.defaultReadObject(); // Restore the layers String[] layerConfigs = (String[]) ois.readObject(); layers = new Layer[layerConfigs.length]; for (int i = 0; i < layerConfigs.length; i++) { String layerConfigString = layerConfigs[i]; String[] split = layerConfigString.split("::"); String clsName = split[0]; String layerConfig = split[1]; String[] options = weka.core.Utils.splitOptions(layerConfig); layers[i] = (Layer) weka.core.Utils.forName(Layer.class, clsName, options); } // restore the network model if (isInitializationFinished) { File tmpFile = File.createTempFile("restore", "multiLayer"); tmpFile.deleteOnExit(); BufferedOutputStream bos = new BufferedOutputStream(new FileOutputStream(tmpFile)); long remaining = modelSize; while (remaining > 0) { int bsize = 10024; if (remaining < 10024) { bsize = (int) remaining; } byte[] buffer = new byte[bsize]; int len = ois.read(buffer); if (len == -1) { throw new IOException( "Reached end of network model prematurely during deserialization."); } bos.write(buffer, 0, len); remaining -= len; } bos.flush(); model = ModelSerializer.restoreComputationGraph(tmpFile, false); } } catch (Exception e) { log.error("Failed to restore serialized model. Error: " + e.getMessage()); e.printStackTrace(); } finally { Thread.currentThread().setContextClassLoader(origLoader); } }
Example 18
Source File: KerasZooModel.java From wekaDeeplearning4j with GNU General Public License v3.0 | 4 votes |
@Override public ComputationGraph initPretrained(PretrainedType pretrainedType) throws IOException { String remoteUrl = pretrainedUrl(pretrainedType); if (remoteUrl == null) throw new UnsupportedOperationException( "Pretrained " + pretrainedType + " weights are not available for this model."); // Set up file locations String localFilename = modelPrettyName() + ".zip"; File rootCacheDir = DL4JResources.getDirectory(ResourceType.ZOO_MODEL, modelFamily()); File cachedFile = new File(rootCacheDir, localFilename); // Download the file if necessary if (!cachedFile.exists()) { log.info("Downloading model to " + cachedFile.toString()); FileUtils.copyURLToFile(new URL(remoteUrl), cachedFile); } else { log.info("Using cached model at " + cachedFile.toString()); } // Validate the checksum - ensure this is the correct file long expectedChecksum = pretrainedChecksum(pretrainedType); if (expectedChecksum != 0L) { log.info("Verifying download..."); Checksum adler = new Adler32(); FileUtils.checksum(cachedFile, adler); long localChecksum = adler.getValue(); log.info("Checksum local is " + localChecksum + ", expecting " + expectedChecksum); if (expectedChecksum != localChecksum) { log.error("Checksums do not match. Cleaning up files and failing..."); cachedFile.delete(); throw new IllegalStateException( String.format("Pretrained model file for model %s failed checksum.", this.modelPrettyName())); } } // Load the .zip file to a ComputationGraph try { return ModelSerializer.restoreComputationGraph(cachedFile); } catch (Exception ex) { System.err.println("Failed to load model"); ex.printStackTrace(); return null; } }
Example 19
Source File: KerasYolo9000PredictTest.java From deeplearning4j with Apache License 2.0 | 3 votes |
/**
 * Manual (ignored) test. It imports the Keras YOLO-VOC model, replaces its head with a
 * Yolo2OutputLayer using the VOC prior boxes, saves it with ModelSerializer, restores it, prints
 * the summary and runs one forward pass.
 */
@Ignore
@Test
public void testYoloPredictionImport() throws Exception {
    final int HEIGHT = 416;
    final int WIDTH = 416;
    // NOTE(review): this input is rank 3 (HEIGHT, WIDTH, 3) with no batch dimension, while the
    // summary below describes a convolutional input; confirm the expected layout before
    // re-enabling this test.
    INDArray indArray = Nd4j.create(HEIGHT, WIDTH, 3);
    IMAGE_PREPROCESSING_SCALER.transform(indArray);

    KerasLayer.registerCustomLayer("Lambda", KerasSpaceToDepth.class);

    String h5_FILENAME = "modelimport/keras/examples/yolo/yolo-voc.h5";
    ComputationGraph graph = KerasModelImport.importKerasModelAndWeights(h5_FILENAME, false);

    double[][] priorBoxes = {{1.3221, 1.73145}, {3.19275, 4.00944}, {5.05587, 8.09892}, {9.47112, 4.84053}, {11.2364, 10.0071}};
    INDArray priors = Nd4j.create(priorBoxes);

    ComputationGraph model = new TransferLearning.GraphBuilder(graph)
            .addLayer("outputs",
                    new org.deeplearning4j.nn.conf.layers.objdetect.Yolo2OutputLayer.Builder()
                            .boundingBoxPriors(priors)
                            .build(),
                    "conv2d_23")
            .setOutputs("outputs")
            .build();

    ModelSerializer.writeModel(model, DL4J_MODEL_FILE_NAME, false);

    ComputationGraph computationGraph = ModelSerializer.restoreComputationGraph(new File(DL4J_MODEL_FILE_NAME));

    // Use the same dimensions as the input instead of repeating the literals
    System.out.println(computationGraph.summary(InputType.convolutional(HEIGHT, WIDTH, 3)));

    INDArray results = computationGraph.outputSingle(indArray);
}
Example 20
Source File: SolverDL4j.java From twse-captcha-solver-dl4j with MIT License | 2 votes |
/** * Creates a new <code>SolverDL4j</code> instance. * * @exception IOException if an error occurs */ public SolverDL4j() throws IOException { InputStream is = SolverDL4j.class.getClass().getResourceAsStream("/model.zip"); model = ModelSerializer.restoreComputationGraph(is); }