org.deeplearning4j.nn.api.Model Java Examples
The following examples show how to use
org.deeplearning4j.nn.api.Model.
You can vote up the ones you like or vote down the ones you don't like,
and go to the original project or source file by following the links above each example. You may check out the related API usage on the sidebar.
Example #1
Source File: InplaceParallelInference.java From deeplearning4j with Apache License 2.0 | 6 votes |
/**
 * Runs a forward pass on the model selected for the calling thread and returns
 * whatever the supplied adapter produces from it.
 *
 * @param adapter     adapter applied to the acquired model together with the arrays below
 * @param input       input arrays, handed to the adapter unchanged
 * @param inputMasks  input mask arrays, handed to the adapter unchanged
 * @param labelsMasks label mask arrays, handed to the adapter unchanged
 * @param <T>         result type produced by the adapter
 * @return the value returned by {@code adapter.apply(model, input, inputMasks, labelsMasks)}
 * @throws RuntimeException wrapping the {@link InterruptedException} raised while acquiring
 *                          or using the model; the thread's interrupt flag is restored first
 */
public <T> T output(@NonNull ModelAdapter<T> adapter, INDArray[] input, INDArray[] inputMasks, INDArray[] labelsMasks) {
    val holder = selector.getModelForThisThread();
    Model model = null;
    try {
        model = holder.acquireModel();
        return adapter.apply(model, input, inputMasks, labelsMasks);
    } catch (InterruptedException e) {
        // Re-assert the interrupt so code further up the stack can still observe it.
        Thread.currentThread().interrupt();
        throw new RuntimeException(e);
    } finally {
        // model is only non-null once acquireModel() has returned, so this check alone
        // is enough to decide whether there is something to hand back.
        if (model != null) {
            holder.releaseModel(model);
        }
    }
}
Example #2
Source File: EvaluationRunner.java From deeplearning4j with Apache License 2.0 | 6 votes |
/**
 * Runs the given evaluations on the model, feeding it from whichever iterator is non-null.
 * <p>
 * A {@code MultiLayerNetwork} is evaluated directly; any other model is cast to
 * {@code ComputationGraph}. When {@code ds} is null the {@code mds} iterator is used instead.
 *
 * @param m             model to evaluate
 * @param e             evaluation instances passed through to {@code doEvaluation}
 * @param ds            DataSet source, or null to fall back to {@code mds}
 * @param mds           MultiDataSet source, used only when {@code ds} is null
 * @param evalBatchSize batch size given to the wrapping iterator
 */
private static void doEval(Model m, IEvaluation[] e, Iterator<DataSet> ds, Iterator<MultiDataSet> mds, int evalBatchSize) {
    final boolean useDataSets = ds != null;

    if (m instanceof MultiLayerNetwork) {
        final MultiLayerNetwork network = (MultiLayerNetwork) m;
        if (useDataSets) {
            network.doEvaluation(new IteratorDataSetIterator(ds, evalBatchSize), e);
            return;
        }
        network.doEvaluation(new IteratorMultiDataSetIterator(mds, evalBatchSize), e);
        return;
    }

    final ComputationGraph graph = (ComputationGraph) m;
    if (useDataSets) {
        graph.doEvaluation(new IteratorDataSetIterator(ds, evalBatchSize), e);
        return;
    }
    graph.doEvaluation(new IteratorMultiDataSetIterator(mds, evalBatchSize), e);
}
Example #3
Source File: ModelTupleStreamTest.java From deeplearning4j with Apache License 2.0 | 6 votes |
/**
 * Exercises {@code doTest} for every combination of 1..5 inputs and 1..5 outputs,
 * once per model flavour, and checks that all 50 combinations were run.
 */
@Test
public void test() throws Exception {
    int executed = 0;
    for (int numInputs = 1; numInputs <= 5; ++numInputs) {
        for (int numOutputs = 1; numOutputs <= 5; ++numOutputs) {
            // Both models are built up front, before either one is tested.
            final Model[] candidates = {
                buildMultiLayerNetworkModel(numInputs, numOutputs),
                buildComputationGraphModel(numInputs, numOutputs)
            };
            for (int i = 0; i < candidates.length; i++) {
                doTest(candidates[i], numInputs, numOutputs);
                executed++;
            }
        }
    }
    assertEquals(50, executed);
}
Example #4
Source File: EpochListener.java From wekaDeeplearning4j with GNU General Public License v3.0 | 6 votes |
@Override public void onEpochEnd(Model model) { currentEpoch++; // Skip if this is not an evaluation epoch if (currentEpoch % n != 0) { return; } String s = "Epoch [" + currentEpoch + "/" + numEpochs + "]\n"; if (isIntermediateEvaluationsEnabled) { s += "Train Set: \n" + evaluateDataSetIterator(model, trainIterator, true); if (validationIterator != null) { s += "Validation Set: \n" + evaluateDataSetIterator(model, validationIterator, false); } } log(s); }
Example #5
Source File: ModelGuesserTest.java From deeplearning4j with Apache License 2.0 | 6 votes |
/**
 * Writes a network together with a fitted min-max normalizer into one file, reads both
 * back through {@code ModelGuesser}, and checks that each equals what was written.
 */
@Test
public void testNormalizerInPlace() throws Exception {
    MultiLayerNetwork net = getNetwork();
    File tempFile = testDir.newFile("testNormalizerInPlace.bin");

    NormalizerMinMaxScaler normalizer = new NormalizerMinMaxScaler(0, 1);
    normalizer.fit(new DataSet(Nd4j.rand(new int[] {2, 2}), Nd4j.rand(new int[] {2, 2})));
    ModelSerializer.writeModel(net, tempFile, true, normalizer);

    String path = tempFile.getAbsolutePath();
    Model model = ModelGuesser.loadModelGuess(path);
    Normalizer<?> normalizer1 = ModelGuesser.loadNormalizer(path);

    // assertEquals takes (expected, actual): the objects that were written are the
    // expectation, the reloaded ones are what is under test.
    assertEquals(net, model);
    assertEquals(normalizer, normalizer1);
}
Example #6
Source File: IntegrationTestRunner.java From deeplearning4j with Apache License 2.0 | 6 votes |
/**
 * Collects duplicates of every parameter array that belongs to a {@code FrozenLayer}.
 * <p>
 * Keys are the parameter name prefixed with the layer index (for a
 * {@code MultiLayerNetwork}) or the configured layer name (otherwise), followed by "_".
 * Any model that is not a {@code MultiLayerNetwork} is cast to {@code ComputationGraph}.
 *
 * @param m model whose layers are inspected
 * @return insertion-ordered map from prefixed parameter name to a {@code dup()} of its array
 */
private static Map<String,INDArray> getFrozenLayerParamCopies(Model m){
    final boolean multiLayer = m instanceof MultiLayerNetwork;
    final org.deeplearning4j.nn.api.Layer[] layers = multiLayer
            ? ((MultiLayerNetwork) m).getLayers()
            : ((ComputationGraph) m).getLayers();

    final Map<String,INDArray> copies = new LinkedHashMap<>();
    for (org.deeplearning4j.nn.api.Layer layer : layers) {
        if (!(layer instanceof FrozenLayer)) {
            continue;
        }
        final String prefix = multiLayer
                ? layer.getIndex() + "_"
                : layer.conf().getLayer().getLayerName() + "_";
        layer.paramTable().forEach((paramName, values) -> copies.put(prefix + paramName, values.dup()));
    }
    return copies;
}
Example #7
Source File: BaseEarlyStoppingTrainer.java From deeplearning4j with Apache License 2.0 | 6 votes |
/**
 * Records the epoch number on the model and notifies its training listeners.
 * Models that are neither a {@code MultiLayerNetwork} nor a {@code ComputationGraph}
 * are ignored entirely.
 *
 * @param epochStart true to call {@code onEpochStart}, false to call {@code onEpochEnd}
 * @param model      model whose listeners are notified
 * @param epochNum   epoch count stored on the model before listeners run
 */
protected void triggerEpochListeners(boolean epochStart, Model model, int epochNum){
    final Collection<TrainingListener> registered;
    if (model instanceof MultiLayerNetwork) {
        final MultiLayerNetwork network = (MultiLayerNetwork) model;
        registered = network.getListeners();
        network.setEpochCount(epochNum);
    } else if (model instanceof ComputationGraph) {
        final ComputationGraph graph = (ComputationGraph) model;
        registered = graph.getListeners();
        graph.getConfiguration().setEpochCount(epochNum);
    } else {
        return;
    }

    if (registered == null || registered.isEmpty()) {
        return;
    }
    for (TrainingListener listener : registered) {
        if (epochStart) {
            listener.onEpochStart(model);
        } else {
            listener.onEpochEnd(model);
        }
    }
}
Example #8
Source File: BaseStatsListener.java From deeplearning4j with Apache License 2.0 | 6 votes |
/**
 * Captures gradient statistics for the current iteration when reporting is due.
 * <p>
 * Work is done only if {@code calcFromGradients()} is true, the reporting frequency is
 * positive, and the iteration count is zero or a multiple of that frequency. Each
 * statistic (histograms, mean, stdev, mean magnitude) is collected only if enabled
 * in {@code updateConfig} for {@code StatsType.Gradients}.
 *
 * @param model model whose gradient is read
 */
@Override
public void onGradientCalculation(Model model) {
    final int iterCount = getModelInfo(model).iterCount;

    if (!calcFromGradients() || !(updateConfig.reportingFrequency() > 0)) {
        return;
    }
    if (iterCount != 0 && iterCount % updateConfig.reportingFrequency() != 0) {
        return;
    }

    final Gradient gradient = model.gradient();

    if (updateConfig.collectHistograms(StatsType.Gradients)) {
        gradientHistograms = getHistograms(gradient.gradientForVariable(),
                updateConfig.numHistogramBins(StatsType.Gradients));
    }
    if (updateConfig.collectMean(StatsType.Gradients)) {
        meanGradients = calculateSummaryStats(gradient.gradientForVariable(), StatType.Mean);
    }
    if (updateConfig.collectStdev(StatsType.Gradients)) {
        stdevGradient = calculateSummaryStats(gradient.gradientForVariable(), StatType.Stdev);
    }
    if (updateConfig.collectMeanMagnitudes(StatsType.Gradients)) {
        meanMagGradients = calculateSummaryStats(gradient.gradientForVariable(), StatType.MeanMagnitude);
    }
}
Example #9
Source File: TrainUtil.java From FancyBing with GNU General Public License v3.0 | 5 votes |
public static String saveModel(String name, Model model, int index, int accuracy) throws Exception { System.err.println("Saving model, don't shutdown..."); try { String fn = name + "_idx_" + index + "_" + accuracy + ".zip"; File locationToSave = new File(System.getProperty("user.dir") + "/model/" + fn); boolean saveUpdater = true; //Updater: i.e., the state for Momentum, RMSProp, Adagrad etc. Save this if you want to train your network more in the future ModelSerializer.writeModel(model, locationToSave, saveUpdater); System.err.println("Model saved"); return fn; } catch (IOException e) { System.err.println("Save model failed"); e.printStackTrace(); throw e; } }
Example #10
Source File: BaseOptimizer.java From deeplearning4j with Apache License 2.0 | 5 votes |
/**
 * Reads the epoch count from the configuration that matches the model's type:
 * the layer-wise configuration of a {@code MultiLayerNetwork}, the configuration of a
 * {@code ComputationGraph}, or {@code model.conf()} for anything else.
 *
 * @param model model to query
 * @return the epoch count stored in that configuration
 */
public static int getEpochCount(Model model){
    if (model instanceof MultiLayerNetwork) {
        final MultiLayerNetwork network = (MultiLayerNetwork) model;
        return network.getLayerWiseConfigurations().getEpochCount();
    }
    if (model instanceof ComputationGraph) {
        final ComputationGraph graph = (ComputationGraph) model;
        return graph.getConfiguration().getEpochCount();
    }
    return model.conf().getEpochCount();
}
Example #11
Source File: IntegrationTestRunner.java From deeplearning4j with Apache License 2.0 | 5 votes |
private static void validateLayerIterCounts(Model m, int expEpoch, int expIter){ //Check that the iteration and epoch counts - on the layers - are synced org.deeplearning4j.nn.api.Layer[] layers; if (m instanceof MultiLayerNetwork) { layers = ((MultiLayerNetwork) m).getLayers(); } else { layers = ((ComputationGraph) m).getLayers(); } for(org.deeplearning4j.nn.api.Layer l : layers){ assertEquals("Epoch count", expEpoch, l.getEpochCount()); assertEquals("Iteration count", expIter, l.getIterationCount()); } }
Example #12
Source File: SystemInfoFilePrintListener.java From deeplearning4j with Apache License 2.0 | 5 votes |
/**
 * Writes a "forward pass" message to the target file after a forward pass, provided
 * forward-pass printing is enabled and a target file has been configured.
 *
 * @param model       the model that ran the forward pass (not used here)
 * @param activations the activations produced by the pass (not used here)
 */
@Override
public void onForwardPass(Model model, List<INDArray> activations) {
    // This callback reports forward passes, so it must be gated by the forward-pass
    // flag; checking the backward-pass flag here tied the output to the wrong setting.
    if (!printOnForwardPass || printFileTarget == null) {
        return;
    }
    writeFileWithMessage("forward pass");
}
Example #13
Source File: ModelTupleStreamIntegrationTest.java From deeplearning4j with Apache License 2.0 | 5 votes |
private static Model buildModel() throws Exception { final int numInputs = 3; final int numOutputs = 2; final MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() .list( new OutputLayer.Builder() .nIn(numInputs) .nOut(numOutputs) .activation(Activation.IDENTITY) .lossFunction(LossFunctions.LossFunction.MSE) .build() ) .build(); final MultiLayerNetwork model = new MultiLayerNetwork(conf); model.init(); final float[] floats = new float[]{ +1, +1, +1, -1, -1, -1, 0, 0 }; // positive weight for first output, negative weight for second output, no biases assertEquals((numInputs+1)*numOutputs, floats.length); final INDArray params = Nd4j.create(floats); model.setParams(params); return model; }
Example #14
Source File: SleepyTrainingListener.java From deeplearning4j with Apache License 2.0 | 5 votes |
/**
 * Calls {@code sleep} with the previously recorded iteration timestamp and the
 * iteration timer, then records the current time as the new timestamp.
 *
 * @param model     unused
 * @param iteration unused
 * @param epoch     unused
 */
@Override
public void iterationDone(Model model, int iteration, int epoch) {
    sleep(lastIteration.get(), timerIteration);

    final AtomicLong recorded = lastIteration.get();
    if (recorded != null) {
        recorded.set(System.currentTimeMillis());
    } else {
        // First call seen through this holder: create the timestamp container.
        lastIteration.set(new AtomicLong(System.currentTimeMillis()));
    }
}
Example #15
Source File: SystemInfoFilePrintListener.java From deeplearning4j with Apache License 2.0 | 5 votes |
/**
 * Writes a "backward pass" message to the target file, provided backward-pass printing
 * is enabled and a target file has been configured.
 *
 * @param model the model that ran the backward pass (not used here)
 */
@Override
public void onBackwardPass(Model model) {
    final boolean enabled = printOnBackwardPass && printFileTarget != null;
    if (enabled) {
        writeFileWithMessage("backward pass");
    }
}
Example #16
Source File: DL4JArbiterStatusReportingListener.java From deeplearning4j with Apache License 2.0 | 5 votes |
/**
 * Forwards the completed iteration to every registered status listener.
 * Does nothing when no listener collection has been set.
 *
 * @param model     model passed on to each listener
 * @param iteration iteration number passed on to each listener
 * @param epoch     unused
 */
@Override
public void iterationDone(Model model, int iteration, int epoch) {
    if (statusListeners != null) {
        for (StatusListener listener : statusListeners) {
            listener.onCandidateIteration(candidateInfo, model, iteration);
        }
    }
}
Example #17
Source File: UpdaterCreator.java From deeplearning4j with Apache License 2.0 | 5 votes |
/**
 * Creates the updater that corresponds to the runtime type of the argument:
 * {@code MultiLayerUpdater} for a {@code MultiLayerNetwork}, {@code ComputationGraphUpdater}
 * for a {@code ComputationGraph}, and {@code LayerUpdater} (after a cast to {@code Layer})
 * for anything else.
 *
 * @param layer the model or layer to create an updater for
 * @return a new updater instance
 */
public static org.deeplearning4j.nn.api.Updater getUpdater(Model layer) {
    if (layer instanceof MultiLayerNetwork) {
        return new MultiLayerUpdater((MultiLayerNetwork) layer);
    }
    if (layer instanceof ComputationGraph) {
        return new ComputationGraphUpdater((ComputationGraph) layer);
    }
    return new LayerUpdater((Layer) layer);
}
Example #18
Source File: ModelGuesser.java From konduit-serving with Apache License 2.0 | 5 votes |
/**
 * Loads a dl4j zip file, restoring it as a computation graph or a multi layer network.
 * <p>
 * The archive is treated as a computation graph only when it contains both the
 * coefficients and configuration entries and the configuration JSON has a
 * {@code "vertexInputs"} key; in every other case it is restored as a multi layer network.
 *
 * @param path the path to the file to load
 * @return the restored model, or {@code null} when the file is not a zip file
 * @throws Exception if reading the archive or restoring the model fails
 */
public static Model loadDl4jGuess(String path) throws Exception {
    final File modelFile = new File(path);
    if (!isZipFile(modelFile)) {
        return null;
    }

    log.debug("Loading file " + path);

    boolean isGraph = false;
    try (ZipFile archive = new ZipFile(path)) {
        final List<String> entryNames = archive.stream()
                .map(ZipEntry::getName)
                .collect(Collectors.toList());
        log.debug("Entries " + entryNames);

        final boolean hasCoefficients = entryNames.contains(ModelSerializer.COEFFICIENTS_BIN);
        final boolean hasConfiguration = entryNames.contains(ModelSerializer.CONFIGURATION_JSON);
        if (hasCoefficients && hasConfiguration) {
            final ZipEntry configEntry = archive.getEntry(ModelSerializer.CONFIGURATION_JSON);
            log.debug("Loaded configuration");
            try (InputStream in = archive.getInputStream(configEntry)) {
                final JSONObject config = new JSONObject(IOUtils.toString(in, StandardCharsets.UTF_8));
                // Only computation graph configurations carry a "vertexInputs" key.
                isGraph = config.has("vertexInputs");
                if (isGraph) {
                    log.debug("Loading computation graph.");
                } else {
                    log.debug("Loading multi layer network.");
                }
            }
        }
    }

    if (isGraph) {
        return ModelSerializer.restoreComputationGraph(modelFile);
    }
    return ModelSerializer.restoreMultiLayerNetwork(modelFile);
}
Example #19
Source File: SystemInfoPrintListener.java From deeplearning4j with Apache License 2.0 | 5 votes |
/**
 * Logs the system-info header followed by a pretty-printed JSON snapshot of the
 * system information, if forward-pass printing is enabled.
 *
 * @param model       unused
 * @param activations unused
 */
@Override
public void onForwardPass(Model model, Map<String, INDArray> activations) {
    if (printOnForwardPass) {
        final SystemInfo snapshot = new SystemInfo();
        log.info(SYSTEM_INFO);
        log.info(snapshot.toPrettyJSON());
    }
}
Example #20
Source File: BaseOptimizer.java From deeplearning4j with Apache License 2.0 | 5 votes |
/**
 * Creates the optimizer and its backtracking line search.
 *
 * @param conf              configuration; also supplies the maximum number of line search iterations
 * @param stepFunction      step function to use, or null to use the default for this optimizer class
 * @param trainingListeners listeners to keep, or null for a new empty list
 * @param model             model being optimized; also handed to the line search
 */
public BaseOptimizer(NeuralNetConfiguration conf, StepFunction stepFunction,
                Collection<TrainingListener> trainingListeners, Model model) {
    this.conf = conf;

    if (stepFunction == null) {
        this.stepFunction = getDefaultStepFunctionForOptimizer(this.getClass());
    } else {
        this.stepFunction = stepFunction;
    }

    if (trainingListeners == null) {
        this.trainingListeners = new ArrayList<TrainingListener>();
    } else {
        this.trainingListeners = trainingListeners;
    }

    this.model = model;

    // The line search is given the resolved step function, so it is built last.
    lineMaximizer = new BackTrackLineSearch(model, this.stepFunction, this);
    lineMaximizer.setStepMax(stepMax);
    lineMaximizer.setMaxIterations(conf.getMaxNumLineSearchIterations());
}
Example #21
Source File: SleepyTrainingListener.java From deeplearning4j with Apache License 2.0 | 5 votes |
/**
 * Calls {@code sleep} with the previously recorded backward-pass timestamp and the
 * backward-pass timer, then records the current time as the new timestamp.
 *
 * @param model unused
 */
@Override
public void onBackwardPass(Model model) {
    sleep(lastBP.get(), timerBP);

    final AtomicLong recorded = lastBP.get();
    if (recorded != null) {
        recorded.set(System.currentTimeMillis());
    } else {
        // First call seen through this holder: create the timestamp container.
        lastBP.set(new AtomicLong(System.currentTimeMillis()));
    }
}
Example #22
Source File: CheckpointListener.java From deeplearning4j with Apache License 2.0 | 5 votes |
/**
 * Returns a label for the model's exact runtime class. Subclasses of
 * {@code MultiLayerNetwork} or {@code ComputationGraph} do not match and yield "Model".
 *
 * @param model model to classify
 * @return "MultiLayerNetwork", "ComputationGraph" or "Model"
 */
protected static String getModelType(Model model){
    final Class<? extends Model> type = model.getClass();
    if (type == MultiLayerNetwork.class) {
        return "MultiLayerNetwork";
    }
    if (type == ComputationGraph.class) {
        return "ComputationGraph";
    }
    return "Model";
}
Example #23
Source File: ModelTupleStream.java From deeplearning4j with Apache License 2.0 | 5 votes |
/**
 * Restores a model from the stream using
 * {@link ModelGuesser#loadModelGuess(InputStream, File)}, passing the instance path of
 * {@code solrResourceLoader} as the second argument.
 *
 * @param inputStream stream holding the serialized model
 * @return the restored model
 * @throws IOException wrapping any exception raised by the model guesser
 */
protected Model restoreModel(InputStream inputStream) throws IOException {
    final File instanceDir = solrResourceLoader.getInstancePath().toFile();
    final Model restored;
    try {
        restored = ModelGuesser.loadModelGuess(inputStream, instanceDir);
    } catch (Exception cause) {
        throw new IOException("Failed to restore model from given file (" + serializedModelFileName + ")", cause);
    }
    return restored;
}
Example #24
Source File: SystemInfoFilePrintListener.java From deeplearning4j with Apache License 2.0 | 5 votes |
/**
 * Writes a "forward pass" message to the target file, provided forward-pass printing
 * is enabled and a target file has been configured.
 *
 * @param model       unused
 * @param activations unused
 */
@Override
public void onForwardPass(Model model, Map<String, INDArray> activations) {
    final boolean enabled = printOnForwardPass && printFileTarget != null;
    if (enabled) {
        writeFileWithMessage("forward pass");
    }
}
Example #25
Source File: CheckpointListener.java From deeplearning4j with Apache License 2.0 | 5 votes |
/**
 * Reads the epoch count from the configuration that matches the model's type:
 * the layer-wise configuration of a {@code MultiLayerNetwork}, the configuration of a
 * {@code ComputationGraph}, or {@code model.conf()} for anything else.
 *
 * @param model model to query
 * @return the epoch count stored in that configuration
 */
protected static int getEpoch(Model model) {
    if (model instanceof MultiLayerNetwork) {
        final MultiLayerNetwork network = (MultiLayerNetwork) model;
        return network.getLayerWiseConfigurations().getEpochCount();
    }
    if (model instanceof ComputationGraph) {
        final ComputationGraph graph = (ComputationGraph) model;
        return graph.getConfiguration().getEpochCount();
    }
    return model.conf().getEpochCount();
}
Example #26
Source File: SystemInfoPrintListener.java From deeplearning4j with Apache License 2.0 | 5 votes |
/**
 * Logs the system-info header followed by a pretty-printed JSON snapshot of the
 * system information, if gradient-calculation printing is enabled.
 *
 * @param model unused
 */
@Override
public void onGradientCalculation(Model model) {
    if (printOnGradientCalculation) {
        final SystemInfo snapshot = new SystemInfo();
        log.info(SYSTEM_INFO);
        log.info(snapshot.toPrettyJSON());
    }
}
Example #27
Source File: CheckpointListener.java From deeplearning4j with Apache License 2.0 | 5 votes |
@Override public void onEpochEnd(Model model) { int epochsDone = getEpoch(model) + 1; if(saveEveryNEpochs != null && epochsDone > 0 && epochsDone % saveEveryNEpochs == 0){ //Save: saveCheckpoint(model); } //General saving conditions: don't need to check here - will check in iterationDone }
Example #28
Source File: FailureTestingListener.java From deeplearning4j with Apache License 2.0 | 5 votes |
/**
 * Asks every trigger whether a failure should be raised. All triggers are always
 * consulted, even after one of them has already answered true.
 *
 * @param callType  call type passed on to each trigger
 * @param iteration iteration passed on to each trigger
 * @param epoch     epoch passed on to each trigger
 * @param model     model passed on to each trigger
 * @return true if at least one trigger returned true
 */
@Override
public boolean triggerFailure(CallType callType, int iteration, int epoch, Model model) {
    boolean anyTriggered = false;
    for (FailureTrigger trigger : triggers) {
        final boolean fired = trigger.triggerFailure(callType, iteration, epoch, model);
        if (fired) {
            anyTriggered = true;
        }
    }
    return anyTriggered;
}
Example #29
Source File: InplaceParallelInference.java From deeplearning4j with Apache License 2.0 | 5 votes |
/**
 * Returns the {@code sourceModel} of each holder, in the iteration order of {@code holders}.
 *
 * @return a new array with one entry per holder
 */
@Override
protected synchronized Model[] getCurrentModelsFromWorkers() {
    final Model[] snapshot = new Model[holders.size()];
    int slot = 0;
    for (val holder : holders) {
        snapshot[slot] = holder.sourceModel;
        slot++;
    }
    return snapshot;
}
Example #30
Source File: FailureTestingListener.java From deeplearning4j with Apache License 2.0 | 4 votes |
/**
 * Reports a forward pass by delegating to {@code call} with {@code CallType.FORWARD_PASS}.
 *
 * @param model       model passed on to {@code call}
 * @param activations unused
 */
@Override
public void onForwardPass(Model model, List<INDArray> activations) {
    final CallType type = CallType.FORWARD_PASS;
    call(type, model);
}