Java Code Examples for org.deeplearning4j.nn.graph.ComputationGraph#getConfiguration()
The following examples show how to use
org.deeplearning4j.nn.graph.ComputationGraph#getConfiguration().
You can vote up the ones you like or vote down the ones you don't like,
and go to the original project or source file by following the links above each example. You may check out the related API usage on the sidebar.
Example 1
Source File: TestUtils.java From deeplearning4j with Apache License 2.0 | 6 votes |
public static ComputationGraph testModelSerialization(ComputationGraph net){ ComputationGraph restored; try { ByteArrayOutputStream baos = new ByteArrayOutputStream(); ModelSerializer.writeModel(net, baos, true); byte[] bytes = baos.toByteArray(); ByteArrayInputStream bais = new ByteArrayInputStream(bytes); restored = ModelSerializer.restoreComputationGraph(bais, true); assertEquals(net.getConfiguration(), restored.getConfiguration()); assertEquals(net.params(), restored.params()); } catch (IOException e){ //Should never happen throw new RuntimeException(e); } //Also check the ComputationGraphConfiguration is serializable (required by Spark etc) ComputationGraphConfiguration conf = net.getConfiguration(); serializeDeserializeJava(conf); return restored; }
Example 2
Source File: DTypeTests.java From deeplearning4j with Apache License 2.0 | 6 votes |
/**
 * Records which vertex, layer and preprocessor classes appear in {@code net}'s configuration.
 *
 * <p>Every vertex class is added to {@code seenVertices}. For a {@link LayerVertex}, the class of
 * its layer goes into {@code seenLayers}, and the class of its preprocessor (when one is set)
 * goes into {@code seenPreprocs}. For a {@link PreprocessorVertex}, the class of its preprocessor
 * goes into {@code seenPreprocs}.
 *
 * @param net the graph whose configuration is inspected
 */
public static void logUsedClasses(ComputationGraph net) {
    for (GraphVertex vertex : net.getConfiguration().getVertices().values()) {
        seenVertices.add(vertex.getClass());

        if (vertex instanceof LayerVertex) {
            LayerVertex layerVertex = (LayerVertex) vertex;
            seenLayers.add(layerVertex.getLayerConf().getLayer().getClass());

            InputPreProcessor attachedPreProcessor = layerVertex.getPreProcessor();
            if (attachedPreProcessor != null) {
                seenPreprocs.add(attachedPreProcessor.getClass());
            }
        } else if (vertex instanceof PreprocessorVertex) {
            PreprocessorVertex preprocessorVertex = (PreprocessorVertex) vertex;
            seenPreprocs.add(preprocessorVertex.getPreProcessor().getClass());
        }
    }
}
Example 3
Source File: TestUtils.java From deeplearning4j with Apache License 2.0 | 6 votes |
public static ComputationGraph testModelSerialization(ComputationGraph net){ ComputationGraph restored; try { ByteArrayOutputStream baos = new ByteArrayOutputStream(); ModelSerializer.writeModel(net, baos, true); byte[] bytes = baos.toByteArray(); ByteArrayInputStream bais = new ByteArrayInputStream(bytes); restored = ModelSerializer.restoreComputationGraph(bais, true); assertEquals(net.getConfiguration(), restored.getConfiguration()); assertEquals(net.params(), restored.params()); } catch (IOException e){ //Should never happen throw new RuntimeException(e); } //Also check the ComputationGraphConfiguration is serializable (required by Spark etc) ComputationGraphConfiguration conf = net.getConfiguration(); serializeDeserializeJava(conf); return restored; }
Example 4
Source File: TestUtils.java From deeplearning4j with Apache License 2.0 | 6 votes |
public static ComputationGraph testModelSerialization(ComputationGraph net){ ComputationGraph restored; try { ByteArrayOutputStream baos = new ByteArrayOutputStream(); ModelSerializer.writeModel(net, baos, true); byte[] bytes = baos.toByteArray(); ByteArrayInputStream bais = new ByteArrayInputStream(bytes); restored = ModelSerializer.restoreComputationGraph(bais, true); assertEquals(net.getConfiguration(), restored.getConfiguration()); assertEquals(net.params(), restored.params()); } catch (IOException e){ //Should never happen throw new RuntimeException(e); } //Also check the ComputationGraphConfiguration is serializable (required by Spark etc) ComputationGraphConfiguration conf = net.getConfiguration(); serializeDeserializeJava(conf); return restored; }
Example 5
Source File: TestUtils.java From deeplearning4j with Apache License 2.0 | 6 votes |
public static ComputationGraph testModelSerialization(ComputationGraph net){ ComputationGraph restored; try { ByteArrayOutputStream baos = new ByteArrayOutputStream(); ModelSerializer.writeModel(net, baos, true); byte[] bytes = baos.toByteArray(); ByteArrayInputStream bais = new ByteArrayInputStream(bytes); restored = ModelSerializer.restoreComputationGraph(bais, true); assertEquals(net.getConfiguration(), restored.getConfiguration()); assertEquals(net.params(), restored.params()); } catch (IOException e){ //Should never happen throw new RuntimeException(e); } //Also check the ComputationGraphConfiguration is serializable (required by Spark etc) ComputationGraphConfiguration conf = net.getConfiguration(); serializeDeserializeJava(conf); return restored; }
Example 6
Source File: RegressionTest080.java From deeplearning4j with Apache License 2.0 | 5 votes |
/**
 * Restores the LSTM ComputationGraph stored under {@code regression_testing/080} and checks
 * that its three vertices (GravesLSTM, GravesBidirectionalLSTM, RnnOutputLayer) came back with
 * the expected activation functions, layer sizes, gradient normalization settings and loss.
 * Activations are checked by class.
 */
@Test
public void regressionTestCGLSTM1() throws Exception {
    File modelFile = Resources.asFile("regression_testing/080/080_ModelSerializer_Regression_CG_LSTM_1.zip");
    ComputationGraph restoredNet = ModelSerializer.restoreComputationGraph(modelFile, true);
    ComputationGraphConfiguration graphConf = restoredNet.getConfiguration();
    assertEquals(3, graphConf.getVertices().size());

    // Vertex "0": GravesLSTM, nIn=3, nOut=4, tanh
    LayerVertex vertex0 = (LayerVertex) graphConf.getVertices().get("0");
    GravesLSTM lstmLayer = (GravesLSTM) vertex0.getLayerConf().getLayer();
    assertTrue(lstmLayer.getActivationFn() instanceof ActivationTanH);
    assertEquals(3, lstmLayer.getNIn());
    assertEquals(4, lstmLayer.getNOut());
    assertEquals(GradientNormalization.ClipElementWiseAbsoluteValue, lstmLayer.getGradientNormalization());
    assertEquals(1.5, lstmLayer.getGradientNormalizationThreshold(), 1e-5);

    // Vertex "1": GravesBidirectionalLSTM, nIn=4, nOut=4, softsign
    LayerVertex vertex1 = (LayerVertex) graphConf.getVertices().get("1");
    GravesBidirectionalLSTM biLstmLayer = (GravesBidirectionalLSTM) vertex1.getLayerConf().getLayer();
    assertTrue(biLstmLayer.getActivationFn() instanceof ActivationSoftSign);
    assertEquals(4, biLstmLayer.getNIn());
    assertEquals(4, biLstmLayer.getNOut());
    assertEquals(GradientNormalization.ClipElementWiseAbsoluteValue, biLstmLayer.getGradientNormalization());
    assertEquals(1.5, biLstmLayer.getGradientNormalizationThreshold(), 1e-5);

    // Vertex "2": RnnOutputLayer, nIn=4, nOut=5, softmax with MCXENT loss
    LayerVertex vertex2 = (LayerVertex) graphConf.getVertices().get("2");
    RnnOutputLayer outputLayer = (RnnOutputLayer) vertex2.getLayerConf().getLayer();
    assertEquals(4, outputLayer.getNIn());
    assertEquals(5, outputLayer.getNOut());
    assertTrue(outputLayer.getActivationFn() instanceof ActivationSoftmax);
    assertTrue(outputLayer.getLossFn() instanceof LossMCXENT);
}
Example 7
Source File: RegressionTest071.java From deeplearning4j with Apache License 2.0 | 5 votes |
/**
 * Restores the LSTM ComputationGraph stored under {@code regression_testing/071} and checks
 * that its three vertices (GravesLSTM, GravesBidirectionalLSTM, RnnOutputLayer) came back with
 * the expected activation functions, layer sizes, gradient normalization settings and loss.
 * Activations are checked by their {@code toString()} names.
 */
@Test
public void regressionTestCGLSTM1() throws Exception {
    File modelFile = Resources.asFile("regression_testing/071/071_ModelSerializer_Regression_CG_LSTM_1.zip");
    ComputationGraph restoredNet = ModelSerializer.restoreComputationGraph(modelFile, true);
    ComputationGraphConfiguration graphConf = restoredNet.getConfiguration();
    assertEquals(3, graphConf.getVertices().size());

    // Vertex "0": GravesLSTM, nIn=3, nOut=4, tanh
    LayerVertex vertex0 = (LayerVertex) graphConf.getVertices().get("0");
    GravesLSTM lstmLayer = (GravesLSTM) vertex0.getLayerConf().getLayer();
    assertEquals("tanh", lstmLayer.getActivationFn().toString());
    assertEquals(3, lstmLayer.getNIn());
    assertEquals(4, lstmLayer.getNOut());
    assertEquals(GradientNormalization.ClipElementWiseAbsoluteValue, lstmLayer.getGradientNormalization());
    assertEquals(1.5, lstmLayer.getGradientNormalizationThreshold(), 1e-5);

    // Vertex "1": GravesBidirectionalLSTM, nIn=4, nOut=4, softsign
    LayerVertex vertex1 = (LayerVertex) graphConf.getVertices().get("1");
    GravesBidirectionalLSTM biLstmLayer = (GravesBidirectionalLSTM) vertex1.getLayerConf().getLayer();
    assertEquals("softsign", biLstmLayer.getActivationFn().toString());
    assertEquals(4, biLstmLayer.getNIn());
    assertEquals(4, biLstmLayer.getNOut());
    assertEquals(GradientNormalization.ClipElementWiseAbsoluteValue, biLstmLayer.getGradientNormalization());
    assertEquals(1.5, biLstmLayer.getGradientNormalizationThreshold(), 1e-5);

    // Vertex "2": RnnOutputLayer, nIn=4, nOut=5, softmax with MCXENT loss
    LayerVertex vertex2 = (LayerVertex) graphConf.getVertices().get("2");
    RnnOutputLayer outputLayer = (RnnOutputLayer) vertex2.getLayerConf().getLayer();
    assertEquals(4, outputLayer.getNIn());
    assertEquals(5, outputLayer.getNOut());
    assertEquals("softmax", outputLayer.getActivationFn().toString());
    assertTrue(outputLayer.getLossFn() instanceof LossMCXENT);
}
Example 8
Source File: RegressionTest060.java From deeplearning4j with Apache License 2.0 | 5 votes |
/**
 * Restores the LSTM ComputationGraph stored under {@code regression_testing/060} and checks
 * that its three vertices (GravesLSTM, GravesBidirectionalLSTM, RnnOutputLayer) came back with
 * the expected activation functions, layer sizes, gradient normalization settings and loss.
 * Activations are checked by their {@code toString()} names.
 */
@Test
public void regressionTestCGLSTM1() throws Exception {
    File modelFile = Resources.asFile("regression_testing/060/060_ModelSerializer_Regression_CG_LSTM_1.zip");
    ComputationGraph restoredNet = ModelSerializer.restoreComputationGraph(modelFile, true);
    ComputationGraphConfiguration graphConf = restoredNet.getConfiguration();
    assertEquals(3, graphConf.getVertices().size());

    // Vertex "0": GravesLSTM, nIn=3, nOut=4, tanh
    LayerVertex vertex0 = (LayerVertex) graphConf.getVertices().get("0");
    GravesLSTM lstmLayer = (GravesLSTM) vertex0.getLayerConf().getLayer();
    assertEquals("tanh", lstmLayer.getActivationFn().toString());
    assertEquals(3, lstmLayer.getNIn());
    assertEquals(4, lstmLayer.getNOut());
    assertEquals(GradientNormalization.ClipElementWiseAbsoluteValue, lstmLayer.getGradientNormalization());
    assertEquals(1.5, lstmLayer.getGradientNormalizationThreshold(), 1e-5);

    // Vertex "1": GravesBidirectionalLSTM, nIn=4, nOut=4, softsign
    LayerVertex vertex1 = (LayerVertex) graphConf.getVertices().get("1");
    GravesBidirectionalLSTM biLstmLayer = (GravesBidirectionalLSTM) vertex1.getLayerConf().getLayer();
    assertEquals("softsign", biLstmLayer.getActivationFn().toString());
    assertEquals(4, biLstmLayer.getNIn());
    assertEquals(4, biLstmLayer.getNOut());
    assertEquals(GradientNormalization.ClipElementWiseAbsoluteValue, biLstmLayer.getGradientNormalization());
    assertEquals(1.5, biLstmLayer.getGradientNormalizationThreshold(), 1e-5);

    // Vertex "2": RnnOutputLayer, nIn=4, nOut=5, softmax with MCXENT loss
    LayerVertex vertex2 = (LayerVertex) graphConf.getVertices().get("2");
    RnnOutputLayer outputLayer = (RnnOutputLayer) vertex2.getLayerConf().getLayer();
    assertEquals(4, outputLayer.getNIn());
    assertEquals(5, outputLayer.getNOut());
    assertEquals("softmax", outputLayer.getActivationFn().toString());
    assertTrue(outputLayer.getLossFn() instanceof LossMCXENT);
}