org.apache.spark.ml.feature.CountVectorizerModel Java Examples
The following examples show how to use
org.apache.spark.ml.feature.CountVectorizerModel.
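CountVectorizerModel is the fitted form of CountVectorizer: it maps an array-of-strings column to a sparse term-frequency vector over a learned (or user-supplied) vocabulary. As a quick orientation before the examples, here is a minimal sketch of the two ways to obtain a model; the column names and the corpus variable are made up for illustration:

  // Fit a vocabulary from a corpus (vocabulary is ordered by descending term frequency)
  CountVectorizerModel fitted = new CountVectorizer()
      .setInputCol("tokens")   // hypothetical column of Array<String>
      .setOutputCol("tf")
      .fit(corpus);            // 'corpus' is a hypothetical Dataset<Row>
  String[] vocab = fitted.vocabulary();

  // Or construct the model directly from an a-priori vocabulary
  CountVectorizerModel prebuilt = new CountVectorizerModel(new String[]{"a", "b", "c"})
      .setInputCol("tokens")
      .setOutputCol("tf");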
Example #1
Source File: CountVectorizerModelInfoAdapter.java From spark-transformers with Apache License 2.0
@Override
public CountVectorizerModelInfo getModelInfo(final CountVectorizerModel from) {
  final CountVectorizerModelInfo modelInfo = new CountVectorizerModelInfo();
  modelInfo.setMinTF(from.getMinTF());
  modelInfo.setVocabulary(from.vocabulary());

  Set<String> inputKeys = new LinkedHashSet<String>();
  inputKeys.add(from.getInputCol());
  modelInfo.setInputKeys(inputKeys);

  Set<String> outputKeys = new LinkedHashSet<String>();
  outputKeys.add(from.getOutputCol());
  modelInfo.setOutputKeys(outputKeys);

  return modelInfo;
}
Example #2
Source File: CountVectorizerModelInfoAdapter.java From spark-transformers with Apache License 2.0
@Override
public CountVectorizerModelInfo getModelInfo(final CountVectorizerModel from, final DataFrame df) {
  final CountVectorizerModelInfo modelInfo = new CountVectorizerModelInfo();
  modelInfo.setMinTF(from.getMinTF());
  modelInfo.setVocabulary(from.vocabulary());

  Set<String> inputKeys = new LinkedHashSet<String>();
  inputKeys.add(from.getInputCol());
  modelInfo.setInputKeys(inputKeys);

  Set<String> outputKeys = new LinkedHashSet<String>();
  outputKeys.add(from.getOutputCol());
  modelInfo.setOutputKeys(outputKeys);

  return modelInfo;
}
Example #3
Source File: JavaCountVectorizerExample.java From SparkDemo with MIT License
public static void main(String[] args) {
  SparkSession spark = SparkSession
    .builder()
    .appName("JavaCountVectorizerExample")
    .getOrCreate();

  // $example on$
  // Input data: Each row is a bag of words from a sentence or document.
  List<Row> data = Arrays.asList(
    RowFactory.create(Arrays.asList("a", "b", "c")),
    RowFactory.create(Arrays.asList("a", "b", "b", "c", "a"))
  );
  StructType schema = new StructType(new StructField[]{
    new StructField("text", new ArrayType(DataTypes.StringType, true), false, Metadata.empty())
  });
  Dataset<Row> df = spark.createDataFrame(data, schema);

  // fit a CountVectorizerModel from the corpus
  CountVectorizerModel cvModel = new CountVectorizer()
    .setInputCol("text")
    .setOutputCol("feature")
    .setVocabSize(3)
    .setMinDF(2)
    .fit(df);

  // alternatively, define CountVectorizerModel with a-priori vocabulary
  CountVectorizerModel cvm = new CountVectorizerModel(new String[]{"a", "b", "c"})
    .setInputCol("text")
    .setOutputCol("feature");

  cvModel.transform(df).show(false);
  // $example off$

  spark.stop();
}
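For reference, with setVocabSize(3) and setMinDF(2) all three terms survive (each appears in both documents), and the show(false) call prints the sparse term-frequency vectors roughly as follows (the index order of the tied terms "a" and "b" may vary between runs):

  +---------------+-------------------------+
  |text           |feature                  |
  +---------------+-------------------------+
  |[a, b, c]      |(3,[0,1,2],[1.0,1.0,1.0])|
  |[a, b, b, c, a]|(3,[0,1,2],[2.0,2.0,1.0])|
  +---------------+-------------------------+

Note that cvm, the model built from the a-priori vocabulary, is defined but never applied; only cvModel is transformed and shown.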
Example #4
Source File: CountVectorizerModelConverter.java From jpmml-sparkml with GNU Affero General Public License v3.0
@Override
public List<Feature> encodeFeatures(SparkMLEncoder encoder){
  CountVectorizerModel transformer = getTransformer();

  DocumentFeature documentFeature = (DocumentFeature)encoder.getOnlyFeature(transformer.getInputCol());

  ParameterField documentField = new ParameterField(FieldName.create("document"));
  ParameterField termField = new ParameterField(FieldName.create("term"));

  TextIndex textIndex = new TextIndex(documentField.getName(), new FieldRef(termField.getName()))
    .setTokenize(Boolean.TRUE)
    .setWordSeparatorCharacterRE(documentFeature.getWordSeparatorRE())
    .setLocalTermWeights(transformer.getBinary() ? TextIndex.LocalTermWeights.BINARY : null);

  Set<DocumentFeature.StopWordSet> stopWordSets = documentFeature.getStopWordSets();
  for(DocumentFeature.StopWordSet stopWordSet : stopWordSets){

    if(stopWordSet.isEmpty()){
      continue;
    }

    String tokenRE;

    String wordSeparatorRE = documentFeature.getWordSeparatorRE();
    switch(wordSeparatorRE){
      case "\\s+":
        tokenRE = "(^|\\s+)\\p{Punct}*(" + JOINER.join(stopWordSet) + ")\\p{Punct}*(\\s+|$)";
        break;
      case "\\W+":
        tokenRE = "(\\W+)(" + JOINER.join(stopWordSet) + ")(\\W+)";
        break;
      default:
        throw new IllegalArgumentException("Expected \"\\s+\" or \"\\W+\" as splitter regex pattern, got \"" + wordSeparatorRE + "\"");
    }

    Map<String, List<String>> data = new LinkedHashMap<>();
    data.put("string", Collections.singletonList(tokenRE));
    data.put("stem", Collections.singletonList(" "));
    data.put("regex", Collections.singletonList("true"));

    TextIndexNormalization textIndexNormalization = new TextIndexNormalization(null, PMMLUtil.createInlineTable(data))
      .setCaseSensitive(stopWordSet.isCaseSensitive())
      .setRecursive(Boolean.TRUE); // Handles consecutive matches. See http://stackoverflow.com/a/25085385

    textIndex.addTextIndexNormalizations(textIndexNormalization);
  }

  DefineFunction defineFunction = new DefineFunction("tf" + "@" + String.valueOf(CountVectorizerModelConverter.SEQUENCE.getAndIncrement()), OpType.CONTINUOUS, DataType.INTEGER, null, textIndex)
    .addParameterFields(documentField, termField);

  encoder.addDefineFunction(defineFunction);

  List<Feature> result = new ArrayList<>();

  String[] vocabulary = transformer.vocabulary();
  for(int i = 0; i < vocabulary.length; i++){
    String term = vocabulary[i];

    if(TermUtil.hasPunctuation(term)){
      throw new IllegalArgumentException("Punctuated vocabulary terms (" + term + ") are not supported");
    }

    result.add(new TermFeature(encoder, defineFunction, documentFeature, term));
  }

  return result;
}
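To make the stop-word handling concrete, here is a hypothetical expansion of the generated regex, assuming JOINER joins the set with "|" (an assumption about this project's helper, not confirmed by the snippet):

  // For stop words {"a", "the"} under the "\\s+" separator, the switch above builds:
  String tokenRE = "(^|\\s+)\\p{Punct}*(a|the)\\p{Punct}*(\\s+|$)";
  // Each match is replaced by a single space (the "stem" column of the inline table),
  // and because the normalization is recursive, runs of consecutive stop words
  // collapse until no further match remains.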
Example #5
Source File: CMMModel.java From vn.vitk with GNU General Public License v3.0
/**
 * Creates a conditional Markov model.
 * @param pipelineModel
 * @param weights
 * @param markovOrder
 * @param tagDictionary
 */
public CMMModel(PipelineModel pipelineModel, Vector weights, MarkovOrder markovOrder,
    Map<String, Set<Integer>> tagDictionary) {
  this.pipelineModel = pipelineModel;
  this.contextExtractor = new ContextExtractor(markovOrder, Constants.REGEXP_FILE);
  this.weights = weights;
  this.tags = ((StringIndexerModel)(pipelineModel.stages()[2])).labels();
  // map each vocabulary term to its index in the fitted CountVectorizerModel
  String[] features = ((CountVectorizerModel)(pipelineModel.stages()[1])).vocabulary();
  featureMap = new HashMap<String, Integer>();
  for (int j = 0; j < features.length; j++) {
    featureMap.put(features[j], j);
  }
  this.tagDictionary = tagDictionary;
}
Example #6
Source File: TransitionBasedParserMLP.java From vn.vitk with GNU General Public License v3.0
/**
 * Creates a transition-based parser using an MLP transition classifier.
 * @param jsc
 * @param classifierFileName
 * @param featureFrame
 */
public TransitionBasedParserMLP(JavaSparkContext jsc, String classifierFileName, FeatureFrame featureFrame) {
  this.featureFrame = featureFrame;
  this.classifier = TransitionClassifier.load(jsc, new Path(classifierFileName, "data").toString());
  this.pipelineModel = PipelineModel.load(new Path(classifierFileName, "pipelineModel").toString());
  this.transitionName = ((StringIndexerModel)pipelineModel.stages()[2]).labels();
  // map each vocabulary term to its index in the fitted CountVectorizerModel
  String[] features = ((CountVectorizerModel)(pipelineModel.stages()[1])).vocabulary();
  this.featureMap = new HashMap<String, Integer>();
  for (int j = 0; j < features.length; j++) {
    this.featureMap.put(features[j], j);
  }
}
Example #7
Source File: CountVectorizerModelConverter.java From jpmml-sparkml with GNU Affero General Public License v3.0
public CountVectorizerModelConverter(CountVectorizerModel transformer){
  super(transformer);
}
Example #8
Source File: TransitionClassifier.java From vn.vitk with GNU General Public License v3.0
/**
 * Trains a transition classifier on the data frame.
 * @param jsc
 * @param graphs
 * @param featureFrame
 * @param classifierFileName
 * @param numHiddenUnits
 * @return a transition classifier.
 */
public Transformer trainMLP(JavaSparkContext jsc, List<DependencyGraph> graphs,
    FeatureFrame featureFrame, String classifierFileName, int numHiddenUnits) {
  // create a SQLContext
  this.sqlContext = new SQLContext(jsc);
  // extract a data frame from these graphs
  DataFrame dataset = toDataFrame(jsc, graphs, featureFrame);

  // create a processing pipeline and fit it to the data frame
  Pipeline pipeline = createPipeline();
  PipelineModel pipelineModel = pipeline.fit(dataset);
  DataFrame trainingData = pipelineModel.transform(dataset);

  // cache the training data for better performance
  trainingData.cache();
  if (verbose) {
    trainingData.show(false);
  }

  // compute the number of different labels, which is the maximum element
  // in the 'label' column
  trainingData.registerTempTable("dfTable");
  Row row = sqlContext.sql("SELECT MAX(label) as maxValue from dfTable").first();
  int numLabels = (int)row.getDouble(0);
  numLabels++;

  int vocabSize = ((CountVectorizerModel)(pipelineModel.stages()[1])).getVocabSize();

  // default is a two-layer MLP (no hidden layer)
  int[] layers = {vocabSize, numLabels};
  // if the user specifies hidden units, use a 3-layer MLP:
  if (numHiddenUnits > 0) {
    layers = new int[3];
    layers[0] = vocabSize;
    layers[1] = numHiddenUnits;
    layers[2] = numLabels;
  }
  MultilayerPerceptronClassifier classifier = new MultilayerPerceptronClassifier()
    .setLayers(layers)
    .setBlockSize(128)
    .setSeed(1234L)
    .setTol((Double)params.getOrDefault(params.getTolerance()))
    .setMaxIter((Integer)params.getOrDefault(params.getMaxIter()));
  MultilayerPerceptronClassificationModel model = classifier.fit(trainingData);

  // compute precision on the training data
  DataFrame result = model.transform(trainingData);
  DataFrame predictionAndLabel = result.select("prediction", "label");
  MulticlassClassificationEvaluator evaluator = new MulticlassClassificationEvaluator().setMetricName("precision");
  if (verbose) {
    System.out.println("N = " + trainingData.count());
    System.out.println("D = " + vocabSize);
    System.out.println("K = " + numLabels);
    System.out.println("H = " + numHiddenUnits);
    System.out.println("training precision = " + evaluator.evaluate(predictionAndLabel));
  }

  // save the trained MLP to a file
  String classifierPath = new Path(classifierFileName, "data").toString();
  jsc.parallelize(Arrays.asList(model), 1).saveAsObjectFile(classifierPath);

  // save the pipeline model to sub-directory "pipelineModel"
  try {
    String pipelinePath = new Path(classifierFileName, "pipelineModel").toString();
    pipelineModel.write().overwrite().save(pipelinePath);
  } catch (IOException e) {
    e.printStackTrace();
  }
  return model;
}
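In Spark's MultilayerPerceptronClassifier, the layers array lists layer sizes from input to output, so {vocabSize, numLabels} is a network with no hidden layer, while {vocabSize, numHiddenUnits, numLabels} adds one hidden layer sized by the caller. A more compact sketch of the same branch (equivalent to the if-block above, not the author's code):

  // Hypothetical rewrite of the layer selection above
  int[] layers = (numHiddenUnits > 0)
      ? new int[]{vocabSize, numHiddenUnits, numLabels}
      : new int[]{vocabSize, numLabels};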
Example #9
Source File: CountVectorizerModelInfoAdapter.java From spark-transformers with Apache License 2.0
@Override
public Class<CountVectorizerModel> getSource() {
  return CountVectorizerModel.class;
}
Example #10
Source File: CountVectorizerBridgeTest.java From spark-transformers with Apache License 2.0
@Test
public void testCountVectorizer() {
  final List<String[]> input = new ArrayList<>();
  input.add(new String[]{"a", "b", "c"});
  input.add(new String[]{"a", "b", "b", "c", "a"});

  // prepare data
  JavaRDD<Row> jrdd = jsc.parallelize(Arrays.asList(
    RowFactory.create(1, input.get(0)),
    RowFactory.create(2, input.get(1))
  ));
  StructType schema = new StructType(new StructField[]{
    new StructField("id", DataTypes.IntegerType, false, Metadata.empty()),
    new StructField("text", new ArrayType(DataTypes.StringType, true), false, Metadata.empty())
  });
  Dataset<Row> df = spark.createDataFrame(jrdd, schema);

  // train model in spark
  CountVectorizerModel sparkModel = new CountVectorizer()
    .setInputCol("text")
    .setOutputCol("feature")
    .setVocabSize(3)
    .setMinDF(2)
    .fit(df);

  // export this model
  byte[] exportedModel = ModelExporter.export(sparkModel);

  // import and get Transformer
  Transformer transformer = ModelImporter.importAndGetTransformer(exportedModel);

  // compare predictions
  List<Row> sparkOutput = sparkModel.transform(df).orderBy("id").select("feature").collectAsList();
  for (int i = 0; i < 2; i++) {
    String[] words = input.get(i);

    Map<String, Object> data = new HashMap<String, Object>();
    data.put(sparkModel.getInputCol(), words);
    transformer.transform(data);

    double[] transformedOp = (double[]) data.get(sparkModel.getOutputCol());
    double[] sparkOp = ((Vector) sparkOutput.get(i).get(0)).toArray();
    assertArrayEquals(transformedOp, sparkOp, 0.01);
  }
}
Example #11
Source File: CountVectorizerModelInfoAdapter.java From spark-transformers with Apache License 2.0
@Override
public Class<CountVectorizerModel> getSource() {
  return CountVectorizerModel.class;
}
Example #12
Source File: CountVectorizerBridgeTest.java From spark-transformers with Apache License 2.0
@Test
public void testCountVectorizer() {
  final List<List<String>> input = new ArrayList<>();
  input.add(Arrays.<String>asList("a", "b", "c"));
  input.add(Arrays.<String>asList("a", "b", "b", "c", "a"));

  // prepare data
  JavaRDD<Row> jrdd = sc.parallelize(Arrays.asList(
    RowFactory.create(1, input.get(0)),
    RowFactory.create(2, input.get(1))
  ));
  StructType schema = new StructType(new StructField[]{
    new StructField("id", DataTypes.IntegerType, false, Metadata.empty()),
    new StructField("text", new ArrayType(DataTypes.StringType, true), false, Metadata.empty())
  });
  DataFrame df = sqlContext.createDataFrame(jrdd, schema);

  // train model in spark
  CountVectorizerModel sparkModel = new CountVectorizer()
    .setInputCol("text")
    .setOutputCol("feature")
    .setVocabSize(3)
    .setMinDF(2)
    .fit(df);

  // export this model
  byte[] exportedModel = ModelExporter.export(sparkModel, df);

  // import and get Transformer
  Transformer transformer = ModelImporter.importAndGetTransformer(exportedModel);

  // compare predictions
  Row[] sparkOutput = sparkModel.transform(df).orderBy("id").select("feature").collect();
  for (int i = 0; i < 2; i++) {
    Object[] words = input.get(i).toArray();

    Map<String, Object> data = new HashMap<String, Object>();
    data.put(sparkModel.getInputCol(), words);
    transformer.transform(data);

    double[] transformedOp = (double[]) data.get(sparkModel.getOutputCol());
    double[] sparkOp = ((Vector) sparkOutput[i].get(0)).toArray();
    assertArrayEquals(transformedOp, sparkOp, EPSILON);
  }
}