Java Code Examples for org.apache.spark.sql.types.DataTypes#StringType
The following examples show how to use org.apache.spark.sql.types.DataTypes#StringType.
You can vote up the examples you like or vote down the ones you don't like, and go to the original project or source file by following the links above each example. You may check out the related API usage on the sidebar.
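As a quick orientation before the project examples, here is a minimal, self-contained sketch of the pattern they all share: declaring a string column with DataTypes.StringType and building a DataFrame against that schema. The class name, column name, and values below are illustrative only, not taken from any project on this page.

import java.util.Arrays;
import java.util.List;

import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.RowFactory;
import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.Metadata;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;

public class StringTypeExample {
  public static void main(String[] args) {
    SparkSession spark = SparkSession.builder()
        .appName("StringTypeExample")
        .master("local[*]")
        .getOrCreate();

    // DataTypes.StringType marks "name" as a Spark SQL string column.
    StructType schema = new StructType(new StructField[] {
        new StructField("name", DataTypes.StringType, false, Metadata.empty())
    });

    List<Row> rows = Arrays.asList(RowFactory.create("alice"), RowFactory.create("bob"));
    Dataset<Row> df = spark.createDataFrame(rows, schema);
    df.printSchema(); // root |-- name: string (nullable = false)

    spark.stop();
  }
}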
Example 1
Source File: TestRangeRowRule.java From envelope with Apache License 2.0 | 6 votes |
@Test
public void testDontIgnoreNulls() {
  StructType schema = new StructType(new StructField[] {
      new StructField("name", DataTypes.StringType, false, Metadata.empty()),
      new StructField("nickname", DataTypes.StringType, false, Metadata.empty()),
      new StructField("age", DataTypes.IntegerType, false, Metadata.empty()),
      new StructField("candycrushscore", DataTypes.createDecimalType(), false, Metadata.empty())
  });

  Map<String, Object> configMap = new HashMap<>();
  configMap.put(RangeRowRule.FIELDS_CONFIG, Lists.newArrayList("age"));
  configMap.put(RangeRowRule.FIELD_TYPE_CONFIG, "int");
  configMap.put(RangeRowRule.RANGE_CONFIG, Lists.newArrayList(0, 105));
  Config config = ConfigFactory.parseMap(configMap);

  RangeRowRule rule = new RangeRowRule();
  assertNoValidationFailures(rule, config);
  rule.configure(config);
  rule.configureName("agerange");

  Row row1 = new RowWithSchema(schema, "Ian", "Ian", null, new BigDecimal("0.00"));
  assertFalse("Row should not pass rule", rule.check(row1));
}
Example 2
Source File: TestDeletePlanner.java From envelope with Apache License 2.0 | 6 votes |
@Test
public void testPlanner() {
  List<Row> rows = Lists.newArrayList(
      RowFactory.create("a", 1, false),
      RowFactory.create("b", 2, true));
  StructType schema = new StructType(new StructField[] {
      new StructField("field1", DataTypes.StringType, false, null),
      new StructField("field2", DataTypes.IntegerType, false, null),
      new StructField("field3", DataTypes.BooleanType, false, null)
  });
  Dataset<Row> data = Contexts.getSparkSession().createDataFrame(rows, schema);

  BulkPlanner p = new DeletePlanner();
  p.configure(ConfigFactory.empty());

  List<Tuple2<MutationType, Dataset<Row>>> planned = p.planMutationsForSet(data);

  assertEquals(1, planned.size());
  assertEquals(MutationType.DELETE, planned.get(0)._1());
  assertEquals(data, planned.get(0)._2());
}
Example 3
Source File: TypeCastStep.java From bpmn.ai with BSD 3-Clause "New" or "Revised" License | 6 votes |
private DataType mapDataType(List<StructField> datasetFields, String column, String typeConfig) {
  DataType currentDatatype = getCurrentDataType(datasetFields, column);

  // when typeConfig is null (no config for this column), return the current DataType
  if (typeConfig == null) {
    return currentDatatype;
  }

  switch (typeConfig) {
    case "integer":
      return DataTypes.IntegerType;
    case "long":
      return DataTypes.LongType;
    case "double":
      return DataTypes.DoubleType;
    case "boolean":
      return DataTypes.BooleanType;
    case "date":
      return DataTypes.DateType;
    case "timestamp":
      return DataTypes.TimestampType;
    default:
      return DataTypes.StringType;
  }
}
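Note the design choice here: any unrecognized type string falls back to DataTypes.StringType, which is the safest cast target since every value has a string representation. As a hedged illustration of how such a mapped type would typically be applied (the dataset and column name below are hypothetical, not taken from bpmn.ai), the standard Spark cast API looks like this:

// Hypothetical usage: cast a column to the mapped type via the standard Spark API.
Dataset<Row> casted = dataset.withColumn(
    "approved",
    dataset.col("approved").cast(DataTypes.BooleanType));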
Example 4
Source File: TestHBaseOutput.java From envelope with Apache License 2.0 | 6 votes |
@Test
public void testGetPartialKey() throws Exception {
  addEntriesToHBase();
  Table table = connection.getTable(TableName.valueOf(TABLE));
  scanAndCountTable(table, INPUT_ROWS * 4);

  Config config = ConfigUtils.configFromResource("/hbase/hbase-output-test.conf").getConfig("output");
  config = config.withValue("zookeeper",
      ConfigValueFactory.fromAnyRef("localhost:" + utility.getZkCluster().getClientPort()));

  HBaseOutput output = new HBaseOutput();
  output.configure(config);

  StructType partialKeySchema = new StructType(new StructField[] {
      new StructField("symbol", DataTypes.StringType, false, null)
  });
  List<Row> filters = Lists.newArrayList();
  filters.add(new RowWithSchema(partialKeySchema, "AAPL"));
  filters.add(new RowWithSchema(partialKeySchema, "GOOG"));

  Iterable<Row> filtered = output.getExistingForFilters(filters);

  assertEquals(25, Iterables.size(filtered));
}
Example 5
Source File: DataFrames.java From DataVec with Apache License 2.0 | 5 votes |
/**
 * Convert the DataVec sequence schema to a StructType for Spark, for example for use in
 * {@link #toDataFrameSequence(Schema, JavaRDD)}.
 * <b>Note</b>: as per {@link #toDataFrameSequence(Schema, JavaRDD)}, the StructType has two additional columns added to it:<br>
 * - Column 0: Sequence UUID (name: {@link #SEQUENCE_UUID_COLUMN}) - a UUID for the original sequence<br>
 * - Column 1: Sequence index (name: {@link #SEQUENCE_INDEX_COLUMN}) - an index (integer, starting at 0) for the position
 * of this record in the original time series<br>
 * These two columns are required if the data is to be converted back into a sequence at a later point, for example
 * using {@link #toRecordsSequence(DataRowsFacade)}
 *
 * @param schema Schema to convert
 * @return StructType for the schema
 */
public static StructType fromSchemaSequence(Schema schema) {
    StructField[] structFields = new StructField[schema.numColumns() + 2];

    structFields[0] = new StructField(SEQUENCE_UUID_COLUMN, DataTypes.StringType, false, Metadata.empty());
    structFields[1] = new StructField(SEQUENCE_INDEX_COLUMN, DataTypes.IntegerType, false, Metadata.empty());

    for (int i = 0; i < schema.numColumns(); i++) {
        switch (schema.getColumnTypes().get(i)) {
            case Double:
                structFields[i + 2] = new StructField(schema.getName(i), DataTypes.DoubleType, false, Metadata.empty());
                break;
            case Integer:
                structFields[i + 2] = new StructField(schema.getName(i), DataTypes.IntegerType, false, Metadata.empty());
                break;
            case Long:
                structFields[i + 2] = new StructField(schema.getName(i), DataTypes.LongType, false, Metadata.empty());
                break;
            case Float:
                structFields[i + 2] = new StructField(schema.getName(i), DataTypes.FloatType, false, Metadata.empty());
                break;
            default:
                throw new IllegalStateException(
                        "This API should not be used with strings, binary data or NDArrays. This is only for columnar data");
        }
    }
    return new StructType(structFields);
}
Example 6
Source File: DecisionStep.java From envelope with Apache License 2.0 | 5 votes |
private boolean evaluateStepByKeyDecision(Set<Step> steps) {
  Optional<Step> optionalStep = StepUtils.getStepForName(stepByKeyStepName, steps);

  if (!optionalStep.isPresent()) {
    throw new RuntimeException("Unknown decision step's key step: " + stepByKeyStepName);
  }

  if (!(optionalStep.get() instanceof DataStep)) {
    throw new RuntimeException("Decision step's key step is not a data step: " + optionalStep.get().getName());
  }

  Dataset<Row> keyDataset = ((DataStep) optionalStep.get()).getData();

  if (keyDataset.schema().fields().length != 2 ||
      keyDataset.schema().fields()[0].dataType() != DataTypes.StringType ||
      keyDataset.schema().fields()[1].dataType() != DataTypes.BooleanType) {
    throw new RuntimeException("Decision step's key step must contain a string column and then a boolean column");
  }

  String keyColumnName = keyDataset.schema().fieldNames()[0];
  String whereClause = keyColumnName + " = '" + stepByKeyKey + "'";
  Dataset<Row> decisionDataset = keyDataset.where(whereClause);

  if (decisionDataset.count() != 1) {
    throw new RuntimeException("Decision step's key step must contain a single record for the given key");
  }

  boolean decision = decisionDataset.collectAsList().get(0).getBoolean(1);

  return decision;
}
Example 7
Source File: Tagger.java From vn.vitk with GNU General Public License v3.0 | 5 votes |
/**
 * Tags a list of sequences and writes the result to an output file with a
 * desired output format.
 *
 * @param sentences
 * @param outputFileName
 * @param outputFormat
 */
public void tag(List<String> sentences, String outputFileName, OutputFormat outputFormat) {
    List<Row> rows = new LinkedList<Row>();
    for (String sentence : sentences) {
        rows.add(RowFactory.create(sentence));
    }
    StructType schema = new StructType(new StructField[] {
        new StructField("sentence", DataTypes.StringType, false, Metadata.empty())
    });
    SQLContext sqlContext = new SQLContext(jsc);
    DataFrame input = sqlContext.createDataFrame(rows, schema);
    tag(input, outputFileName, outputFormat);
}
Example 8
Source File: TestRangeRowRule.java From envelope with Apache License 2.0 | 5 votes |
@Test
public void testAgeRangeInt() {
  StructType schema = new StructType(new StructField[] {
      new StructField("name", DataTypes.StringType, false, Metadata.empty()),
      new StructField("nickname", DataTypes.StringType, false, Metadata.empty()),
      new StructField("age", DataTypes.IntegerType, false, Metadata.empty()),
      new StructField("candycrushscore", DataTypes.createDecimalType(), false, Metadata.empty())
  });

  Map<String, Object> configMap = new HashMap<>();
  configMap.put(RangeRowRule.FIELDS_CONFIG, Lists.newArrayList("age"));
  configMap.put(RangeRowRule.FIELD_TYPE_CONFIG, "int");
  configMap.put(RangeRowRule.RANGE_CONFIG, Lists.newArrayList(0, 105));
  Config config = ConfigFactory.parseMap(configMap);

  RangeRowRule rule = new RangeRowRule();
  assertNoValidationFailures(rule, config);
  rule.configure(config);
  rule.configureName("agerange");

  Row row1 = new RowWithSchema(schema, "Ian", "Ian", 34, new BigDecimal("0.00"));
  assertTrue("Row should pass rule", rule.check(row1));

  Row row2 = new RowWithSchema(schema, "Webster1", "Websta1", 110, new BigDecimal("450.10"));
  assertFalse("Row should not pass rule", rule.check(row2));

  Row row3 = new RowWithSchema(schema, "", "Ian1", 106, new BigDecimal("450.10"));
  assertFalse("Row should not pass rule", rule.check(row3));

  Row row4 = new RowWithSchema(schema, "First Last", "Ian Last", 105, new BigDecimal("450.10"));
  assertTrue("Row should pass rule", rule.check(row4));
}
Example 9
Source File: TestRangeRowRule.java From envelope with Apache License 2.0 | 5 votes |
@Test
public void testAgeRangeDouble() {
  StructType schema = new StructType(new StructField[] {
      new StructField("name", DataTypes.StringType, false, Metadata.empty()),
      new StructField("nickname", DataTypes.StringType, false, Metadata.empty()),
      new StructField("age", DataTypes.DoubleType, false, Metadata.empty()),
      new StructField("candycrushscore", DataTypes.createDecimalType(), false, Metadata.empty())
  });

  Map<String, Object> configMap = new HashMap<>();
  configMap.put(RangeRowRule.FIELDS_CONFIG, Lists.newArrayList("age"));
  configMap.put(RangeRowRule.FIELD_TYPE_CONFIG, "float");
  configMap.put(RangeRowRule.RANGE_CONFIG, Lists.newArrayList(0.1, 105.0));
  Config config = ConfigFactory.parseMap(configMap);

  RangeRowRule rule = new RangeRowRule();
  assertNoValidationFailures(rule, config);
  rule.configure(config);
  rule.configureName("agerange");

  Row row1 = new RowWithSchema(schema, "Ian", "Ian", 34.0f, new BigDecimal("0.00"));
  assertTrue("Row should pass rule", rule.check(row1));

  Row row2 = new RowWithSchema(schema, "Webster1", "Websta1", 110.0f, new BigDecimal("450.10"));
  assertFalse("Row should not pass rule", rule.check(row2));

  Row row3 = new RowWithSchema(schema, "", "Ian1", 110.0f, new BigDecimal("450.10"));
  assertFalse("Row should not pass rule", rule.check(row3));

  Row row4 = new RowWithSchema(schema, "First Last", "Ian Last", 100.0f, new BigDecimal("450.10"));
  assertTrue("Row should pass rule", rule.check(row4));
}
Example 10
Source File: TestSparkSchema.java From iceberg with Apache License 2.0 | 5 votes |
@Test
public void testFailSparkReadSchemaCombinedWithProjectionWhenSchemaDoesNotContainProjection() throws IOException {
  String tableLocation = temp.newFolder("iceberg-table").toString();

  HadoopTables tables = new HadoopTables(CONF);
  PartitionSpec spec = PartitionSpec.unpartitioned();
  tables.create(SCHEMA, spec, null, tableLocation);

  List<SimpleRecord> expectedRecords = Lists.newArrayList(
      new SimpleRecord(1, "a")
  );
  Dataset<Row> originalDf = spark.createDataFrame(expectedRecords, SimpleRecord.class);
  originalDf.select("id", "data").write()
      .format("iceberg")
      .mode("append")
      .save(tableLocation);

  StructType sparkReadSchema = new StructType(
      new StructField[] {
          new StructField("data", DataTypes.StringType, true, Metadata.empty())
      }
  );

  AssertHelpers.assertThrows("Spark should not allow a projection that is not included in the read schema",
      org.apache.spark.sql.AnalysisException.class,
      "cannot resolve '`id`' given input columns: [data]",
      () -> spark.read()
          .schema(sparkReadSchema)
          .format("iceberg")
          .load(tableLocation)
          .select("id")
  );
}
Example 11
Source File: TestDecisionStep.java From envelope with Apache License 2.0 | 5 votes |
@Test
public void testPruneByStepKeyTrue() {
  StructType schema = new StructType(new StructField[] {
      new StructField("name", DataTypes.StringType, false, Metadata.empty()),
      new StructField("result", DataTypes.BooleanType, false, Metadata.empty())
  });
  List<Row> rows = Lists.newArrayList(
      RowFactory.create("namecheck", false),
      RowFactory.create("agerange", true)
  );
  Dataset<Row> ds = Contexts.getSparkSession().createDataFrame(rows, schema);
  step1.setData(ds);

  Map<String, Object> step2ConfigMap = Maps.newHashMap();
  step2ConfigMap.put(Step.DEPENDENCIES_CONFIG, Lists.newArrayList("step1"));
  step2ConfigMap.put(DecisionStep.IF_TRUE_STEP_NAMES_PROPERTY, Lists.newArrayList("step3", "step7"));
  step2ConfigMap.put(DecisionStep.DECISION_METHOD_PROPERTY, DecisionStep.STEP_BY_KEY_DECISION_METHOD);
  step2ConfigMap.put(DecisionStep.STEP_BY_KEY_STEP_PROPERTY, "step1");
  step2ConfigMap.put(DecisionStep.STEP_BY_KEY_KEY_PROPERTY, "agerange");
  Config step2Config = ConfigFactory.parseMap(step2ConfigMap);
  RefactorStep step2 = new DecisionStep("step2");
  step2.configure(step2Config);
  steps.add(step2);

  Set<Step> refactored = step2.refactor(steps);

  assertEquals(refactored, Sets.newHashSet(step1, step2, step3, step4, step7, step8));
}
Example 12
Source File: TestDecisionStep.java From envelope with Apache License 2.0 | 5 votes |
@Test
public void testPruneByStepKeyFalse() {
  StructType schema = new StructType(new StructField[] {
      new StructField("name", DataTypes.StringType, false, Metadata.empty()),
      new StructField("result", DataTypes.BooleanType, false, Metadata.empty())
  });
  List<Row> rows = Lists.newArrayList(
      RowFactory.create("namecheck", false),
      RowFactory.create("agerange", true)
  );
  Dataset<Row> ds = Contexts.getSparkSession().createDataFrame(rows, schema);
  step1.setData(ds);

  Map<String, Object> step2ConfigMap = Maps.newHashMap();
  step2ConfigMap.put(Step.DEPENDENCIES_CONFIG, Lists.newArrayList("step1"));
  step2ConfigMap.put(DecisionStep.IF_TRUE_STEP_NAMES_PROPERTY, Lists.newArrayList("step3", "step7"));
  step2ConfigMap.put(DecisionStep.DECISION_METHOD_PROPERTY, DecisionStep.STEP_BY_KEY_DECISION_METHOD);
  step2ConfigMap.put(DecisionStep.STEP_BY_KEY_STEP_PROPERTY, "step1");
  step2ConfigMap.put(DecisionStep.STEP_BY_KEY_KEY_PROPERTY, "namecheck");
  Config step2Config = ConfigFactory.parseMap(step2ConfigMap);
  RefactorStep step2 = new DecisionStep("step2");
  step2.configure(step2Config);
  steps.add(step2);

  Set<Step> refactored = step2.refactor(steps);

  assertEquals(refactored, Sets.newHashSet(step1, step2, step5, step6));
}
Example 13
Source File: TestRangeRowRule.java From envelope with Apache License 2.0 | 5 votes |
@Test
public void testAgeRangeDecimal() {
  StructType schema = new StructType(new StructField[] {
      new StructField("name", DataTypes.StringType, false, Metadata.empty()),
      new StructField("nickname", DataTypes.StringType, false, Metadata.empty()),
      new StructField("age", DataTypes.DoubleType, false, Metadata.empty()),
      new StructField("candycrushscore", DataTypes.createDecimalType(), false, Metadata.empty())
  });

  Map<String, Object> configMap = new HashMap<>();
  configMap.put(RangeRowRule.FIELDS_CONFIG, Lists.newArrayList("candycrushscore"));
  configMap.put(RangeRowRule.FIELD_TYPE_CONFIG, "decimal");
  configMap.put(RangeRowRule.RANGE_CONFIG, Lists.newArrayList("-1.56", "400.45"));
  Config config = ConfigFactory.parseMap(configMap);

  RangeRowRule rule = new RangeRowRule();
  assertNoValidationFailures(rule, config);
  rule.configure(config);
  rule.configureName("agerange");

  Row row1 = new RowWithSchema(schema, "Ian", "Ian", 34.0, new BigDecimal("-1.00"));
  assertTrue("Row should pass rule", rule.check(row1));

  Row row2 = new RowWithSchema(schema, "Webster1", "Websta1", 110.0, new BigDecimal("-1.57"));
  assertFalse("Row should not pass rule", rule.check(row2));

  Row row3 = new RowWithSchema(schema, "", "Ian1", 110.0, new BigDecimal("450.10"));
  assertFalse("Row should not pass rule", rule.check(row3));

  Row row4 = new RowWithSchema(schema, "First Last", "Ian Last", 100.0, new BigDecimal("400.45"));
  assertTrue("Row should pass rule", rule.check(row4));
}
Example 14
Source File: JavaOneHotEncoderExample.java From SparkDemo with MIT License | 5 votes |
public static void main(String[] args) {
  SparkSession spark = SparkSession
    .builder()
    .appName("JavaOneHotEncoderExample")
    .getOrCreate();

  // $example on$
  List<Row> data = Arrays.asList(
    RowFactory.create(0, "a"),
    RowFactory.create(1, "b"),
    RowFactory.create(2, "c"),
    RowFactory.create(3, "a"),
    RowFactory.create(4, "a"),
    RowFactory.create(5, "c")
  );

  StructType schema = new StructType(new StructField[] {
    new StructField("id", DataTypes.IntegerType, false, Metadata.empty()),
    new StructField("category", DataTypes.StringType, false, Metadata.empty())
  });

  Dataset<Row> df = spark.createDataFrame(data, schema);

  StringIndexerModel indexer = new StringIndexer()
    .setInputCol("category")
    .setOutputCol("categoryIndex")
    .fit(df);
  Dataset<Row> indexed = indexer.transform(df);

  OneHotEncoder encoder = new OneHotEncoder()
    .setInputCol("categoryIndex")
    .setOutputCol("categoryVec");

  Dataset<Row> encoded = encoder.transform(indexed);
  encoded.show();
  // $example off$

  spark.stop();
}
Example 15
Source File: IndexRUtil.java From indexr with Apache License 2.0 | 5 votes |
public static List<StructField> indexrSchemaToSparkSchema(SegmentSchema schema) {
    List<StructField> fields = new ArrayList<>();
    for (ColumnSchema cs : schema.getColumns()) {
        DataType dataType;
        switch (cs.getSqlType()) {
            case INT:
                dataType = DataTypes.IntegerType;
                break;
            case BIGINT:
                dataType = DataTypes.LongType;
                break;
            case FLOAT:
                dataType = DataTypes.FloatType;
                break;
            case DOUBLE:
                dataType = DataTypes.DoubleType;
                break;
            case VARCHAR:
                dataType = DataTypes.StringType;
                break;
            case DATE:
                dataType = DataTypes.DateType;
                break;
            case DATETIME:
                dataType = DataTypes.TimestampType;
                break;
            default:
                throw new IllegalStateException("Unsupported type: " + cs.getSqlType());
        }
        fields.add(new StructField(cs.getName(), dataType, false, Metadata.empty()));
    }
    return fields;
}
Example 16
Source File: FlightDataSourceReader.java From flight-spark-source with Apache License 2.0 | 4 votes |
private DataType sparkFromArrow(FieldType fieldType) {
  switch (fieldType.getType().getTypeID()) {
    case Null:
      return DataTypes.NullType;
    case Struct:
      throw new UnsupportedOperationException("have not implemented Struct type yet");
    case List:
      throw new UnsupportedOperationException("have not implemented List type yet");
    case FixedSizeList:
      throw new UnsupportedOperationException("have not implemented FixedSizeList type yet");
    case Union:
      throw new UnsupportedOperationException("have not implemented Union type yet");
    case Int:
      ArrowType.Int intType = (ArrowType.Int) fieldType.getType();
      int bitWidth = intType.getBitWidth();
      if (bitWidth == 8) {
        return DataTypes.ByteType;
      } else if (bitWidth == 16) {
        return DataTypes.ShortType;
      } else if (bitWidth == 32) {
        return DataTypes.IntegerType;
      } else if (bitWidth == 64) {
        return DataTypes.LongType;
      }
      throw new UnsupportedOperationException("unknown int type with bitwidth " + bitWidth);
    case FloatingPoint:
      ArrowType.FloatingPoint floatType = (ArrowType.FloatingPoint) fieldType.getType();
      FloatingPointPrecision precision = floatType.getPrecision();
      switch (precision) {
        case HALF:
        case SINGLE:
          return DataTypes.FloatType;
        case DOUBLE:
          return DataTypes.DoubleType;
      }
      // guard: an unhandled precision must not fall through to the Utf8 case below
      throw new UnsupportedOperationException("unknown floating point precision " + precision);
    case Utf8:
      return DataTypes.StringType;
    case Binary:
    case FixedSizeBinary:
      return DataTypes.BinaryType;
    case Bool:
      return DataTypes.BooleanType;
    case Decimal:
      throw new UnsupportedOperationException("have not implemented Decimal type yet");
    case Date:
      return DataTypes.DateType;
    case Time:
      return DataTypes.TimestampType; // note: i don't know what this will do!
    case Timestamp:
      return DataTypes.TimestampType;
    case Interval:
      return DataTypes.CalendarIntervalType;
    case NONE:
      return DataTypes.NullType;
  }
  throw new IllegalStateException("Unexpected value: " + fieldType);
}
Example 17
Source File: JsonSchema.java From sylph with Apache License 2.0 | 4 votes |
public Row deserialize(byte[] messageKey, byte[] message, String topic, int partition, long offset)
        throws IOException {
    @SuppressWarnings("unchecked")
    Map<String, Object> map = MAPPER.readValue(message, Map.class);
    String[] names = rowTypeInfo.names();
    Object[] values = new Object[names.length];
    for (int i = 0; i < names.length; i++) {
        String key = names[i];
        switch (key) {
            case "_topic":
                values[i] = topic;
                continue;
            case "_message":
                values[i] = new String(message, UTF_8);
                continue;
            case "_key":
                values[i] = new String(messageKey, UTF_8);
                continue;
            case "_partition":
                values[i] = partition;
                continue;
            case "_offset":
                values[i] = offset;
                continue;
        }

        Object value = map.get(key);
        if (value == null) {
            continue;
        }
        DataType type = rowTypeInfo.apply(i).dataType();

        if (type instanceof MapType && ((MapType) type).valueType() == DataTypes.StringType) {
            // must be a Scala map, not a java.util.Map
            scala.collection.mutable.Map convertValue = new scala.collection.mutable.HashMap();
            for (Map.Entry entry : ((Map<?, ?>) value).entrySet()) {
                convertValue.put(entry.getKey(), entry.getValue() == null ? null : entry.getValue().toString());
            }
            values[i] = convertValue;
        }
        else if (type instanceof ArrayType) {
            //Class<?> aClass = type.getTypeClass();
            //values[i] = MAPPER.convertValue(value, aClass);
            //todo: Spark List to Array
            values[i] = value;
        }
        else if (type == DataTypes.LongType) {
            values[i] = ((Number) value).longValue();
        }
        else {
            values[i] = value;
        }
    }
    return new GenericRowWithSchema(values, rowTypeInfo);
}
Example 18
Source File: ConfigurationDataTypes.java From envelope with Apache License 2.0 | 4 votes |
public static DataType getSparkDataType(String typeString) {
  DataType type;
  String prec_scale_regex_groups = "\\s*(decimal)\\s*\\(\\s*(\\d+)\\s*,\\s*(\\d+)\\s*\\)\\s*";
  Pattern prec_scale_regex_pattern = Pattern.compile(prec_scale_regex_groups);
  Matcher prec_scale_regex_matcher = prec_scale_regex_pattern.matcher(typeString);

  if (prec_scale_regex_matcher.matches()) {
    int precision = Integer.parseInt(prec_scale_regex_matcher.group(2));
    int scale = Integer.parseInt(prec_scale_regex_matcher.group(3));
    type = DataTypes.createDecimalType(precision, scale);
  } else {
    switch (typeString) {
      case DECIMAL:
        type = DataTypes.createDecimalType();
        break;
      case STRING:
        type = DataTypes.StringType;
        break;
      case FLOAT:
        type = DataTypes.FloatType;
        break;
      case DOUBLE:
        type = DataTypes.DoubleType;
        break;
      case BYTE:
        type = DataTypes.ByteType;
        break;
      case SHORT:
        type = DataTypes.ShortType;
        break;
      case INT:
        type = DataTypes.IntegerType;
        break;
      case LONG:
        type = DataTypes.LongType;
        break;
      case BOOLEAN:
        type = DataTypes.BooleanType;
        break;
      case BINARY:
        type = DataTypes.BinaryType;
        break;
      case DATE:
        type = DataTypes.DateType;
        break;
      case TIMESTAMP:
        type = DataTypes.TimestampType;
        break;
      default:
        throw new RuntimeException("Unsupported or unrecognized field type: " + typeString);
    }
  }
  return type;
}
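The method accepts either a bare type name or a decimal with explicit precision and scale, which the regex branch handles before the switch. A minimal sketch of both paths, assuming the STRING and DECIMAL constants resolve to the lowercase type names used in envelope configurations:

// Hedged illustration: expected results under the naming assumption above.
DataType s = ConfigurationDataTypes.getSparkDataType("string");        // -> DataTypes.StringType
DataType d = ConfigurationDataTypes.getSparkDataType("decimal(10,2)"); // -> DecimalType(10,2) via the regex branch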
Example 19
Source File: AllButEmptyStringAggregationFunction.java From bpmn.ai with BSD 3-Clause "New" or "Revised" License | 4 votes |
@Override
public DataType dataType() {
  return DataTypes.StringType;
}
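This override presumably sits in a Spark UserDefinedAggregateFunction, where dataType() declares the type of the value the aggregation produces; returning DataTypes.StringType here tells Spark that the aggregate result column is a string.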
Example 20
Source File: EntitySalienceAnnotatorAndFeatureExtractorSpark.java From ambiverse-nlu with Apache License 2.0 | 4 votes |
/**
 * Extract a DataFrame ready for training or testing.
 * @param jsc
 * @param documents
 * @param sqlContext
 * @return
 * @throws ResourceInitializationException
 */
public DataFrame extract(JavaSparkContext jsc, JavaRDD<SCAS> documents, SQLContext sqlContext) throws ResourceInitializationException {
    Accumulator<Integer> TOTAL_DOCS = jsc.accumulator(0, "TOTAL_DOCS");
    Accumulator<Integer> SALIENT_ENTITY_INSTANCES = jsc.accumulator(0, "SALIENT_ENTITY_INSTANCES");
    Accumulator<Integer> NON_SALIENT_ENTITY_INSTANCES = jsc.accumulator(0, "NON_SALIENT_ENTITY_INSTANCES");

    TrainingSettings trainingSettings = getTrainingSettings();

    final SparkSerializableAnalysisEngine ae =
        EntitySalienceFactory.createEntitySalienceEntityAnnotator(trainingSettings.getEntitySalienceEntityAnnotator());
    FeatureExtractor fe = new NYTEntitySalienceFeatureExtractor();
    final int featureVectorSize =
        FeatureSetFactory.createFeatureSet(TrainingSettings.FeatureExtractor.ENTITY_SALIENCE).getFeatureVectorSize();

    JavaRDD<TrainingInstance> trainingInstances = documents
        .map(s -> {
            TOTAL_DOCS.add(1);
            Logger tmpLogger = LoggerFactory.getLogger(EntitySalienceFeatureExtractorSpark.class);
            String docId = JCasUtil.selectSingle(s.getJCas(), DocumentMetaData.class).getDocumentId();
            tmpLogger.info("Processing document {}.", docId);
            // Before processing the document through the Disambiguation Pipeline, add the AIDA settings
            // in each document.
            SparkUimaUtils.addSettingsToJCas(s.getJCas(),
                trainingSettings.getDocumentCoherent(),
                trainingSettings.getDocumentConfidenceThreshold());
            return ae.process(s);
        })
        .flatMap(s -> fe.getTrainingInstances(s.getJCas(),
            trainingSettings.getFeatureExtractor(),
            trainingSettings.getPositiveInstanceScalingFactor()));

    StructType schema = new StructType(new StructField[] {
        new StructField("docId", DataTypes.StringType, false, Metadata.empty()),
        new StructField("entity", DataTypes.StringType, false, Metadata.empty()),
        new StructField("label", DataTypes.DoubleType, false, Metadata.empty()),
        new StructField("features", new VectorUDT(), false, Metadata.empty())
    });

    JavaRDD<Row> withFeatures = trainingInstances.map(ti -> {
        if (ti.getLabel() == 1.0) {
            SALIENT_ENTITY_INSTANCES.add(1);
        } else {
            NON_SALIENT_ENTITY_INSTANCES.add(1);
        }
        Vector vei = FeatureValueInstanceUtils.convertToSparkMLVector(ti, featureVectorSize);
        return RowFactory.create(ti.getDocId(), ti.getEntityId(), ti.getLabel(), vei);
    });

    return sqlContext.createDataFrame(withFeatures, schema);
}