Java Code Examples for org.nd4j.linalg.dataset.DataSet#getFeaturesMaskArray()
The following examples show how to use
`org.nd4j.linalg.dataset.DataSet#getFeaturesMaskArray()`.
You can vote up the ones you like or vote down the ones you don't like,
and go to the original project or source file by following the links above each example. You may check out the related API usage on the sidebar.
Example 1
Source File: TransferLearningHelper.java From deeplearning4j with Apache License 2.0 | 6 votes |
/** * During training frozen vertices/layers can be treated as "featurizing" the input * The forward pass through these frozen layer/vertices can be done in advance and the dataset saved to disk to iterate * quickly on the smaller unfrozen part of the model * Currently does not support datasets with feature masks * * @param input multidataset to feed into the computation graph with frozen layer vertices * @return a multidataset with input features that are the outputs of the frozen layer vertices and the original labels. */ public DataSet featurize(DataSet input) { if (isGraph) { //trying to featurize for a computation graph if (origGraph.getNumInputArrays() > 1 || origGraph.getNumOutputArrays() > 1) { throw new IllegalArgumentException( "Input or output size to a computation graph is greater than one. Requires use of a MultiDataSet."); } else { if (input.getFeaturesMaskArray() != null) { throw new IllegalArgumentException( "Currently cannot support featurizing datasets with feature masks"); } MultiDataSet inbW = new MultiDataSet(new INDArray[] {input.getFeatures()}, new INDArray[] {input.getLabels()}, null, new INDArray[] {input.getLabelsMaskArray()}); MultiDataSet ret = featurize(inbW); return new DataSet(ret.getFeatures()[0], input.getLabels(), ret.getLabelsMaskArrays()[0], input.getLabelsMaskArray()); } } else { if (input.getFeaturesMaskArray() != null) throw new UnsupportedOperationException("Feature masks not supported with featurizing currently"); return new DataSet(origMLN.feedForwardToLayer(frozenInputLayer + 1, input.getFeatures(), false) .get(frozenInputLayer + 1), input.getLabels(), null, input.getLabelsMaskArray()); } }
Example 2
Source File: DataSetDescriptor.java From deeplearning4j with Apache License 2.0 | 6 votes |
/**
 * Builds a descriptor for the given dataset.
 * <p>
 * It records descriptors for the features and labels, descriptors for the features/labels masks when present (or
 * {@code null} when absent), and the dataset's pre-processed flag.
 *
 * @param ds dataset to describe
 * @throws Exception as declared by this constructor (presumably propagated from {@code ArrayDescriptor}
 *         construction — TODO confirm)
 */
public DataSetDescriptor(DataSet ds) throws Exception {
    this.features = new ArrayDescriptor(ds.getFeatures());
    this.labels = new ArrayDescriptor(ds.getLabels());

    // Masks are optional: an absent mask is stored as a null descriptor.
    INDArray fMask = ds.getFeaturesMaskArray();
    this.featuresMask = (fMask != null) ? new ArrayDescriptor(fMask) : null;

    INDArray lMask = ds.getLabelsMaskArray();
    this.labelsMask = (lMask != null) ? new ArrayDescriptor(lMask) : null;

    this.preProcessed = ds.isPreProcessed();
}
Example 3
Source File: DL4JSentimentAnalysisExample.java From Java-for-Data-Science with MIT License | 4 votes |
/**
 * Trains a single-LSTM-layer sentiment classifier on word-vector sequences.
 * <p>
 * After each epoch it evaluates the model on the test set and prints the evaluation statistics.
 *
 * @param args unused
 * @throws Exception propagated from model-data loading, word-vector loading, or iterator construction
 */
public static void main(String[] args) throws Exception {
    getModelData();
    System.out.println("Total memory = " + Runtime.getRuntime().totalMemory());

    final int batchSize = 50;
    final int vectorSize = 300;
    final int nEpochs = 5;
    final int truncateReviewsToLength = 300;

    MultiLayerConfiguration config = new NeuralNetConfiguration.Builder()
            .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
            .iterations(1)
            .updater(Updater.RMSPROP)
            .regularization(true)
            .l2(1e-5)
            .weightInit(WeightInit.XAVIER)
            .gradientNormalization(GradientNormalization.ClipElementWiseAbsoluteValue)
            .gradientNormalizationThreshold(1.0)
            .learningRate(0.0018)
            .list()
            .layer(0, new GravesLSTM.Builder()
                    .nIn(vectorSize)
                    .nOut(200)
                    .activation("softsign")
                    .build())
            .layer(1, new RnnOutputLayer.Builder()
                    .activation("softmax")
                    .lossFunction(LossFunctions.LossFunction.MCXENT)
                    .nIn(200)
                    .nOut(2)
                    .build())
            .pretrain(false)
            .backprop(true)
            .build();

    MultiLayerNetwork model = new MultiLayerNetwork(config);
    model.init();
    model.setListeners(new ScoreIterationListener(1));

    WordVectors vectors = WordVectorSerializer.loadGoogleModel(new File(GNEWS_VECTORS_PATH), true, false);
    DataSetIterator trainIter = new AsyncDataSetIterator(
            new SentimentExampleIterator(EXTRACT_DATA_PATH, vectors, batchSize, truncateReviewsToLength, true), 1);
    DataSetIterator testIter = new AsyncDataSetIterator(
            new SentimentExampleIterator(EXTRACT_DATA_PATH, vectors, 100, truncateReviewsToLength, false), 1);

    for (int epoch = 0; epoch < nEpochs; epoch++) {
        model.fit(trainIter);
        trainIter.reset();
        Evaluation epochEval = evaluateSentimentTestSet(model, testIter);
        System.out.println(epochEval.stats());
    }
}

/** Runs time-series evaluation over every test batch (honoring masks), then resets the test iterator. */
private static Evaluation evaluateSentimentTestSet(MultiLayerNetwork model, DataSetIterator testIter) {
    Evaluation eval = new Evaluation();
    while (testIter.hasNext()) {
        DataSet batch = testIter.next();
        INDArray batchFeatures = batch.getFeatureMatrix();
        INDArray batchLabels = batch.getLabels();
        INDArray featureMask = batch.getFeaturesMaskArray();
        INDArray labelMask = batch.getLabelsMaskArray();
        INDArray output = model.output(batchFeatures, false, featureMask, labelMask);
        eval.evalTimeSeries(batchLabels, output, labelMask);
    }
    testIter.reset();
    return eval;
}