Java Code Examples for org.deeplearning4j.nn.multilayer.MultiLayerNetwork#setParameters()
The following examples show how to use
org.deeplearning4j.nn.multilayer.MultiLayerNetwork#setParameters() .
You can vote up the ones you like or vote down the ones you don't like,
and go to the original project or source file by following the links above each example. You may also check out the related API usage in the sidebar.
Example 1
Source File: VaeReconstructionProbWithKeyFunction.java From deeplearning4j with Apache License 2.0 | 6 votes |
/**
 * Builds a network from the broadcast JSON configuration, loads the broadcast
 * parameters into it, and returns layer 0 as a {@link VariationalAutoencoder}.
 *
 * @return the VAE layer (layer 0) of the freshly constructed network
 * @throws IllegalStateException if the broadcast parameter vector length does not
 *         match the network's parameter count
 * @throws RuntimeException if layer 0 is not a VariationalAutoencoder
 */
@Override
public VariationalAutoencoder getVaeLayer() {
    String confJson = (String) jsonConfig.getValue();
    MultiLayerNetwork net = new MultiLayerNetwork(MultiLayerConfiguration.fromJson(confJson));
    net.init();

    // Duplicate the broadcast parameter vector before handing it to the network:
    // the broadcast value is shared by all tasks on this JVM
    INDArray paramsCopy = ((INDArray) params.value()).unsafeDuplication();
    if (paramsCopy.length() != net.numParams(false)) {
        throw new IllegalStateException(
                "Network did not have same number of parameters as the broadcast set parameters");
    }
    net.setParameters(paramsCopy);

    Layer first = net.getLayer(0);
    if (first instanceof VariationalAutoencoder) {
        return (VariationalAutoencoder) first;
    }
    throw new RuntimeException("Cannot use VaeReconstructionProbWithKeyFunction on network that doesn't have a VAE "
            + "layer as layer 0. Layer type: " + first.getClass());
}
Example 2
Source File: VaeReconstructionErrorWithKeyFunction.java From deeplearning4j with Apache License 2.0 | 6 votes |
/**
 * Instantiates the network described by the broadcast JSON configuration, sets the
 * broadcast parameters on it, and returns its first layer as a
 * {@link VariationalAutoencoder}.
 *
 * @return layer 0 of the network, as a VariationalAutoencoder
 * @throws IllegalStateException if the broadcast parameters do not match the
 *         network's expected parameter count
 * @throws RuntimeException if layer 0 is not a VariationalAutoencoder
 */
@Override
public VariationalAutoencoder getVaeLayer() {
    MultiLayerNetwork net =
            new MultiLayerNetwork(MultiLayerConfiguration.fromJson((String) jsonConfig.getValue()));
    net.init();

    // Work on a duplicate of the broadcast parameters; the broadcast value itself
    // is shared across tasks on this JVM
    INDArray netParams = ((INDArray) params.value()).unsafeDuplication();
    if (netParams.length() != net.numParams(false)) {
        throw new IllegalStateException(
                "Network did not have same number of parameters as the broadcast set parameters");
    }
    net.setParameters(netParams);

    Layer layerZero = net.getLayer(0);
    if (!(layerZero instanceof VariationalAutoencoder)) {
        throw new RuntimeException("Cannot use VaeReconstructionErrorWithKeyFunction on network that doesn't have a VAE "
                + "layer as layer 0. Layer type: " + layerZero.getClass());
    }
    return (VariationalAutoencoder) layerZero;
}
Example 3
Source File: ScoreFlatMapFunction.java From deeplearning4j with Apache License 2.0 | 5 votes |
@Override public Iterator<Tuple2<Integer, Double>> call(Iterator<DataSet> dataSetIterator) throws Exception { if (!dataSetIterator.hasNext()) { return Collections.singletonList(new Tuple2<>(0, 0.0)).iterator(); } DataSetIterator iter = new IteratorDataSetIterator(dataSetIterator, minibatchSize); //Does batching where appropriate MultiLayerNetwork network = new MultiLayerNetwork(MultiLayerConfiguration.fromJson(json)); network.init(); INDArray val = params.value().unsafeDuplication(); //.value() object will be shared by all executors on each machine -> OK, as params are not modified by score function if (val.length() != network.numParams(false)) throw new IllegalStateException( "Network did not have same number of parameters as the broadcast set parameters"); network.setParameters(val); List<Tuple2<Integer, Double>> out = new ArrayList<>(); while (iter.hasNext()) { DataSet ds = iter.next(); double score = network.score(ds, false); val numExamples = (int) ds.getFeatures().size(0); out.add(new Tuple2<>(numExamples, score * numExamples)); } Nd4j.getExecutioner().commit(); return out.iterator(); }
Example 4
Source File: ScoreExamplesWithKeyFunction.java From deeplearning4j with Apache License 2.0 | 4 votes |
/**
 * Scores keyed, single-example DataSets, batching up to {@code batchSize} examples
 * per network call for efficiency. Each input DataSet must contain exactly one
 * example so that exactly one score can be associated with each key.
 *
 * @param iterator the partition's (key, single-example DataSet) pairs
 * @return iterator of (key, score) pairs, in input order
 * @throws IllegalStateException if the broadcast parameters do not match the
 *         network, or if any DataSet contains more than one example
 */
@Override
public Iterator<Tuple2<K, Double>> call(Iterator<Tuple2<K, DataSet>> iterator) throws Exception {
    if (!iterator.hasNext()) {
        return Collections.emptyIterator();
    }

    MultiLayerNetwork net = new MultiLayerNetwork(MultiLayerConfiguration.fromJson(jsonConfig.getValue()));
    net.init();
    // Duplicate the broadcast parameters before loading them into the local network
    INDArray netParams = params.value().unsafeDuplication();
    if (netParams.length() != net.numParams(false)) {
        throw new IllegalStateException(
                "Network did not have same number of parameters as the broadcast set parameters");
    }
    net.setParameters(netParams);

    List<Tuple2<K, Double>> ret = new ArrayList<>();
    List<DataSet> batch = new ArrayList<>(batchSize);
    List<K> batchKeys = new ArrayList<>(batchSize);
    int totalCount = 0;

    while (iterator.hasNext()) {
        batch.clear();
        batchKeys.clear();

        // Fill the next batch; reject multi-example DataSets, since one score must
        // map back to exactly one key
        int nExamples = 0;
        while (iterator.hasNext() && nExamples < batchSize) {
            Tuple2<K, DataSet> t2 = iterator.next();
            DataSet ds = t2._2();
            int n = ds.numExamples();
            if (n != 1) {
                throw new IllegalStateException("Cannot score examples with one key per data set if "
                        + "data set contains more than 1 example (numExamples: " + n + ")");
            }
            batch.add(ds);
            batchKeys.add(t2._1());
            nExamples += n;
        }
        totalCount += nExamples;

        // Score the merged batch in one call; scoreExamples yields one score per
        // example, in input order, so index i matches batchKeys.get(i)
        DataSet merged = DataSet.merge(batch);
        INDArray scores = net.scoreExamples(merged, addRegularization);
        double[] scoreValues = scores.data().asDouble();
        for (int i = 0; i < scoreValues.length; i++) {
            ret.add(new Tuple2<>(batchKeys.get(i), scoreValues[i]));
        }
    }

    // Ensure asynchronous ND4J operations have completed
    Nd4j.getExecutioner().commit();

    if (log.isDebugEnabled()) {
        log.debug("Scored {} examples ", totalCount);
    }
    return ret.iterator();
}
Example 5
Source File: ScoreExamplesFunction.java From deeplearning4j with Apache License 2.0 | 4 votes |
/**
 * Scores this partition's DataSets, batching up to {@code batchSize} examples per
 * network call, and returns one score per example in input order.
 *
 * @param iterator the partition's DataSets
 * @return iterator of per-example scores
 * @throws IllegalStateException if the broadcast parameters do not match the
 *         network's expected parameter count
 */
@Override
public Iterator<Double> call(Iterator<DataSet> iterator) throws Exception {
    if (!iterator.hasNext()) {
        return Collections.emptyIterator();
    }

    MultiLayerNetwork net = new MultiLayerNetwork(MultiLayerConfiguration.fromJson(jsonConfig.getValue()));
    net.init();
    // Duplicate the broadcast parameters before loading them into the local network
    INDArray netParams = params.value().unsafeDuplication();
    if (netParams.length() != net.numParams(false)) {
        throw new IllegalStateException(
                "Network did not have same number of parameters as the broadcast set parameters");
    }
    net.setParameters(netParams);

    List<Double> ret = new ArrayList<>();
    List<DataSet> batch = new ArrayList<>(batchSize);
    int totalCount = 0;

    while (iterator.hasNext()) {
        batch.clear();

        // Accumulate DataSets until the batch-size budget is reached
        int nExamples = 0;
        while (iterator.hasNext() && nExamples < batchSize) {
            DataSet ds = iterator.next();
            batch.add(ds);
            nExamples += ds.numExamples();
        }
        totalCount += nExamples;

        // One forward pass over the merged batch; scoreExamples yields one score
        // per example, in input order
        DataSet merged = DataSet.merge(batch);
        INDArray scores = net.scoreExamples(merged, addRegularization);
        for (double s : scores.data().asDouble()) {
            ret.add(s);
        }
    }

    // Ensure asynchronous ND4J operations have completed
    Nd4j.getExecutioner().commit();

    if (log.isDebugEnabled()) {
        log.debug("Scored {} examples ", totalCount);
    }
    return ret.iterator();
}