Java Code Examples for org.apache.flink.ml.api.misc.param.Params#get()
The following examples show how to use org.apache.flink.ml.api.misc.param.Params#get().
You can vote up the ones you like or vote down the ones you don't like,
and go to the original project or source file by following the links above each example. You may check out the related API usage on the sidebar.
Example 1
Source File: GlmModelDataConverter.java From Alink with Apache License 2.0 | 6 votes |
/**
 * Rebuilds a {@code GlmModelData} from its serialized form.
 *
 * <p>The training parameters are copied from {@code meta}. The first three entries of
 * {@code data} are then decoded from JSON, in order, as the coefficients ({@code double[]}),
 * the intercept ({@code double}) and {@code diagInvAtWA} ({@code double[]}).
 *
 * @param meta The model meta data.
 * @param data The model concrete data.
 * @return GlmModelData
 */
@Override
public GlmModelData deserializeModel(Params meta, Iterable<String> data) {
    GlmModelData model = new GlmModelData();

    // Column configuration.
    model.featureColNames = meta.get(GlmTrainParams.FEATURE_COLS);
    model.offsetColName = meta.get(GlmTrainParams.OFFSET_COL);
    model.weightColName = meta.get(GlmTrainParams.WEIGHT_COL);
    model.labelColName = meta.get(GlmTrainParams.LABEL_COL);

    // Family / link configuration.
    model.familyName = meta.get(GlmTrainParams.FAMILY);
    model.variancePower = meta.get(GlmTrainParams.VARIANCE_POWER);
    model.linkName = meta.get(GlmTrainParams.LINK);
    model.linkPower = meta.get(GlmTrainParams.LINK_POWER);

    // Optimisation settings.
    model.fitIntercept = meta.get(GlmTrainParams.FIT_INTERCEPT);
    model.regParam = meta.get(GlmTrainParams.REG_PARAM);
    model.numIter = meta.get(GlmTrainParams.MAX_ITER);
    model.epsilon = meta.get(GlmTrainParams.EPSILON);

    // The concrete data is positional, so one shared iterator walks it in order.
    Iterator<String> rows = data.iterator();
    model.coefficients = JsonConverter.fromJson(rows.next(), double[].class);
    model.intercept = JsonConverter.fromJson(rows.next(), double.class);
    model.diagInvAtWA = JsonConverter.fromJson(rows.next(), double[].class);

    return model;
}
Example 2
Source File: DirectReader.java From Alink with Apache License 2.0 | 6 votes |
/**
 * Create data bridge from batch operator.
 * The type of result DataBridge is the one with matching policy in global configuration.
 *
 * <p>Generators discovered through {@link java.util.ServiceLoader} that do not carry a
 * {@code DataBridgeGeneratorPolicy} annotation are skipped rather than aborting the lookup
 * with a {@code NullPointerException}. A policy value of {@code null} matches no generator
 * and therefore ends in the {@code IllegalArgumentException} below.
 *
 * @param model the operator to collect data.
 * @return the created DataBridge.
 * @throws IllegalArgumentException if no discovered generator declares the configured policy.
 */
public static DataBridge collect(BatchOperator<?> model) {
    final Params globalParams = DirectReader.readProperties();
    final String policy = globalParams.get(POLICY_KEY);

    for (DataBridgeGenerator generator
        : ServiceLoader.load(DataBridgeGenerator.class, DirectReader.class.getClassLoader())) {
        // getAnnotation returns null when the implementation class is not annotated.
        DataBridgeGeneratorPolicy declared =
            generator.getClass().getAnnotation(DataBridgeGeneratorPolicy.class);

        // Annotation members are never null, so comparing from this side is safe for a null policy.
        if (declared != null && declared.policy().equals(policy)) {
            return generator.generate(model, globalParams);
        }
    }

    throw new IllegalArgumentException("Can not find the policy: " + policy);
}
Example 3
Source File: ParamsTest.java From flink with Apache License 2.0 | 6 votes |
/**
 * Checks alias handling when reading a parameter.
 *
 * <p>A JSON document that only contains the primary name must yield the stored value. A
 * document that contains both the primary name and one of its aliases must make the read
 * fail with an {@code IllegalArgumentException} whose message starts with the
 * duplicate-parameter text asserted below.
 */
@Test
public void testGetAliasParam() {
    ParamInfo<String> info = ParamInfoFactory
        .createParamInfo("predResultColName", String.class)
        .setDescription("Column name of predicted result.")
        .setRequired()
        .setAlias(new String[] {"predColName", "outputColName"})
        .build();

    // Only the primary name is present: the value is returned.
    Params single = Params.fromJson("{\"predResultColName\":\"\\\"f0\\\"\"}");
    Assert.assertEquals("f0", single.get(info));

    // Primary name and alias are both present: the read must be rejected.
    Params duplicated =
        Params.fromJson("{\"predResultColName\":\"\\\"f0\\\"\", \"predColName\":\"\\\"f0\\\"\"}");
    try {
        duplicated.get(info);
        Assert.fail("failure");
    } catch (IllegalArgumentException expected) {
        String message = expected.getMessage();
        Assert.assertTrue(message.startsWith("Duplicate parameters of predResultColName and predColName"));
    }
}
Example 4
Source File: StandardScalerModelMapper.java From Alink with Apache License 2.0 | 6 votes |
/**
 * Constructor.
 *
 * <p>Selected column names and types are extracted from the model schema and located in the
 * data schema. If {@code OUTPUT_COLS} is not set in {@code params}, the selected column
 * names are used as output column names.
 *
 * @param modelSchema the model schema.
 * @param dataSchema  the data schema.
 * @param params      the params.
 */
public StandardScalerModelMapper(TableSchema modelSchema, TableSchema dataSchema, Params params) {
    super(modelSchema, dataSchema, params);

    this.selectedColNames = ImputerModelDataConverter.extractSelectedColNames(modelSchema);
    this.selectedColTypes = ImputerModelDataConverter.extractSelectedColTypes(modelSchema);
    this.selectedColIndices = TableUtil.findColIndicesWithAssert(dataSchema, selectedColNames);

    String[] configuredOutputCols = params.get(SrtPredictMapperParams.OUTPUT_COLS);
    String[] effectiveOutputCols =
        configuredOutputCols == null ? selectedColNames : configuredOutputCols;

    this.predResultColsHelper =
        new OutputColsHelper(dataSchema, effectiveOutputCols, this.selectedColTypes, null);
}
Example 5
Source File: ImputerModelMapper.java From Alink with Apache License 2.0 | 6 votes |
/**
 * Constructor.
 *
 * <p>Selected column names and types are extracted from the model schema and located in the
 * data schema. If {@code OUTPUT_COLS} is not set in {@code params}, the selected column
 * names are used as output column names. Each selected column type is then mapped to a
 * {@code Type} constant by upper-casing the simple name of its type class.
 *
 * @param modelSchema the model schema.
 * @param dataSchema the data schema.
 * @param params the params.
 */
public ImputerModelMapper(TableSchema modelSchema, TableSchema dataSchema, Params params) {
    super(modelSchema, dataSchema, params);

    String[] selectedColNames = ImputerModelDataConverter.extractSelectedColNames(modelSchema);
    TypeInformation[] selectedColTypes = ImputerModelDataConverter.extractSelectedColTypes(modelSchema);
    this.selectedColIndices = TableUtil.findColIndicesWithAssert(dataSchema, selectedColNames);

    String[] outputColNames = params.get(SrtPredictMapperParams.OUTPUT_COLS);
    if (outputColNames == null) {
        outputColNames = selectedColNames;
    }
    this.predictResultColsHelper = new OutputColsHelper(dataSchema, outputColNames, selectedColTypes, null);

    int length = selectedColTypes.length;
    this.type = new Type[length];
    for (int i = 0; i < length; i++) {
        // Locale.ROOT keeps the enum lookup stable regardless of the JVM default locale
        // (a Turkish default would otherwise turn "Integer" into "İNTEGER").
        String typeName = selectedColTypes[i].getTypeClass().getSimpleName();
        this.type[i] = Type.valueOf(typeName.toUpperCase(java.util.Locale.ROOT));
    }
}
Example 6
Source File: AFTModelMapper.java From Alink with Apache License 2.0 | 5 votes |
/**
 * Constructor.
 *
 * <p>Both parameter reads are guarded by the same null check on {@code params}. When
 * {@code params} is {@code null}, the quantile probabilities are set to {@code null} and the
 * vector column index is not assigned here.
 *
 * @param modelSchema the model schema.
 * @param dataSchema  the data schema.
 * @param params      the params.
 */
public AFTModelMapper(TableSchema modelSchema, TableSchema dataSchema, Params params) {
    super(modelSchema, dataSchema, params);

    // Read under the null guard; previously params was dereferenced before it was checked.
    this.quantileProbabilities =
        (null == params) ? null : params.get(AftRegPredictParams.QUANTILE_PROBABILITIES);

    if (null != params) {
        String vectorColName = params.get(LinearModelMapperParams.VECTOR_COL);
        if (null != vectorColName && vectorColName.length() != 0) {
            this.vectorColIndex = TableUtil.findColIndexWithAssert(dataSchema.getFieldNames(), vectorColName);
        }
    }
}
Example 7
Source File: SoftmaxModelMapper.java From Alink with Apache License 2.0 | 5 votes |
public SoftmaxModelMapper(TableSchema modelSchema, TableSchema dataSchema, Params params) { super(modelSchema, dataSchema, params); if (null != params) { String vectorColName = params.get(SoftmaxPredictParams.VECTOR_COL); if (null != vectorColName && vectorColName.length() != 0) { this.vectorColIndex = TableUtil.findColIndexWithAssert(dataSchema.getFieldNames(), vectorColName); } } }
Example 8
Source File: ParamsTest.java From Alink with Apache License 2.0 | 5 votes |
/**
 * Stores an enum-typed and a string-typed parameter, reads the enum-typed one back
 * and prints it.
 */
@Test
public void testContain4() {
    Params configured = new Params()
        .set(HasEnumType.ENUM_TYPE, CalcType.aAA)
        .set(HasAppendType.APPEND_TYPE, "Dense");

    CalcType storedType = configured.get(HasEnumType.ENUM_TYPE);
    System.out.println(storedType);
}
Example 9
Source File: VectorStandardScalerModelDataConverter.java From Alink with Apache License 2.0 | 5 votes |
/**
 * Deserialize the model data.
 *
 * <p>The first entry of {@code data} is decoded as the means and the second entry as the
 * standard deviations. A single iterator is shared between the two reads; obtaining a fresh
 * iterator for each read would decode the first entry twice on any re-iterable source.
 *
 * @param meta         The model meta data.
 * @param data         The model concrete data.
 * @param additionData The additional data.
 * @return The model data used by mapper.
 */
@Override
public Tuple4<Boolean, Boolean, double[], double[]> deserializeModel(Params meta, Iterable<String> data,
                                                                      Iterable<Row> additionData) {
    Iterator<String> dataIterator = data.iterator();
    double[] means = JsonConverter.fromJson(dataIterator.next(), double[].class);
    double[] stdDevs = JsonConverter.fromJson(dataIterator.next(), double[].class);

    Boolean withMean = meta.get(VectorStandardTrainParams.WITH_MEAN);
    Boolean withStd = meta.get(VectorStandardTrainParams.WITH_STD);

    return Tuple4.of(withMean, withStd, means, stdDevs);
}
Example 10
Source File: MISOMapper.java From Alink with Apache License 2.0 | 5 votes |
/**
 * Constructor.
 *
 * <p>Resolves the indices of the selected input columns, then builds the output-column
 * helper from the output column name, the type returned by {@code initOutputColType()} and
 * the reserved columns (or {@code null} when none are configured).
 *
 * @param dataSchema input table schema.
 * @param params     input parameters.
 */
public MISOMapper(TableSchema dataSchema, Params params) {
    super(dataSchema, params);

    String[] selectedCols = this.params.get(MISOMapperParams.SELECTED_COLS);
    this.colIndices = TableUtil.findColIndicesWithAssertAndHint(dataSchema.getFieldNames(), selectedCols);

    String outputCol = params.get(MISOMapperParams.OUTPUT_COL);

    // Reserved columns are optional; null is passed on when they are absent.
    String[] reservedCols = this.params.contains(MISOMapperParams.RESERVED_COLS)
        ? this.params.get(MISOMapperParams.RESERVED_COLS)
        : null;

    this.outputColsHelper = new OutputColsHelper(dataSchema, outputCol, initOutputColType(), reservedCols);
}
Example 11
Source File: Preprocessing.java From Alink with Apache License 2.0 | 5 votes |
public static DataSet<Object[]> generateLabels( BatchOperator<?> input, Params params, boolean isRegression) { DataSet<Object[]> labels; if (!isRegression) { final String labelColName = params.get(HasLabelCol.LABEL_COL); DataSet<Row> labelDataSet = select(input, labelColName).getDataSet(); labels = distinctLabels(labelDataSet .map(new MapFunction<Row, Object>() { @Override public Object map(Row value) throws Exception { return value.getField(0); } }) ); } else { labels = MLEnvironmentFactory.get(input.getMLEnvironmentId()).getExecutionEnvironment().fromElements(1) .mapPartition(new MapPartitionFunction<Integer, Object[]>() { @Override public void mapPartition(Iterable<Integer> values, Collector<Object[]> out) throws Exception { //pass } }); } return labels; }
Example 12
Source File: TableBucketingSink.java From Alink with Apache License 2.0 | 5 votes |
/**
 * Constructor.
 *
 * <p>Reads the batch rollover interval and batch size from {@code params}. When exactly one
 * of the two limits is positive and the other is negative, the negative one is replaced by
 * the maximum value of its type.
 *
 * @param tableName the table name prefix.
 * @param params    the sink parameters.
 * @param schema    the table schema providing field names and types.
 * @param db        the target database.
 */
public TableBucketingSink(String tableName, Params params, TableSchema schema, BaseDB db) {
    this.tableNamePrefix = tableName;
    this.types = schema.getFieldTypes();
    this.colNames = schema.getFieldNames();
    this.db = db;

    this.batchRolloverInterval = params.get(TableBucketingSinkParams.BATCH_ROLLOVER_INTERVAL);
    this.batchSize = params.get(TableBucketingSinkParams.BATCH_SIZE);

    // The two cases are mutually exclusive because they need opposite signs of batchSize.
    if (this.batchSize > 0 && this.batchRolloverInterval < 0L) {
        this.batchRolloverInterval = Long.MAX_VALUE;
    } else if (this.batchSize < 0 && this.batchRolloverInterval > 0L) {
        this.batchSize = Integer.MAX_VALUE;
    }
}
Example 13
Source File: Preprocessing.java From Alink with Apache License 2.0 | 5 votes |
/**
 * Casts the weight column of {@code input} to {@code DOUBLE}.
 *
 * @param input  the input operator.
 * @param params parameters holding the weight column name.
 * @return {@code input} itself when no weight column is configured, otherwise the cast
 *     operator linked from {@code input}.
 */
public static BatchOperator<?> castWeightCol(BatchOperator<?> input, Params params) {
    final String weightColName = params.get(HasWeightColDefaultAsNull.WEIGHT_COL);

    if (weightColName != null) {
        return new NumericalTypeCastBatchOp()
            .setMLEnvironmentId(input.getMLEnvironmentId())
            .setSelectedCols(weightColName)
            .setTargetType("DOUBLE")
            .linkFrom(input);
    }

    return input;
}
Example 14
Source File: DocCountVectorizerTrainBatchOp.java From Alink with Apache License 2.0 | 4 votes |
/**
 * Constructor.
 *
 * <p>Stores the name of the configured feature type and the minimum term frequency.
 *
 * @param params the training parameters.
 */
public BuildDocCountModel(Params params) {
    this.featureType = params
        .get(DocHashCountVectorizerTrainParams.FEATURE_TYPE)
        .name();
    this.minTF = params.get(DocHashCountVectorizerTrainParams.MIN_TF);
}
Example 15
Source File: DocHashCountVectorizerTrainBatchOp.java From Alink with Apache License 2.0 | 4 votes |
/**
 * Constructor.
 *
 * <p>Stores the minimum document frequency, the number of features, the name of the
 * configured feature type and the minimum term frequency.
 *
 * @param params the training parameters.
 */
public BuildModel(Params params) {
    this.minDocFrequency = params.get(DocHashCountVectorizerTrainParams.MIN_DF);
    this.numFeatures = params.get(DocHashCountVectorizerTrainParams.NUM_FEATURES);
    this.featureType = params
        .get(DocHashCountVectorizerTrainParams.FEATURE_TYPE)
        .name();
    this.minTF = params.get(DocHashCountVectorizerTrainParams.MIN_TF);
}
Example 16
Source File: FormatTransMapper.java From Alink with Apache License 2.0 | 4 votes |
/**
 * Creates the reader for the source format configured in {@code params}.
 *
 * <p>The second tuple field holds the source column names: the CSV schema's field names for
 * {@code CSV}, the selected columns (or all data-schema fields when none are selected) for
 * {@code COLUMNS}, and {@code null} for {@code KV}, {@code VECTOR} and {@code JSON}.
 *
 * @param dataSchema the schema of the input data.
 * @param params     the format-translation parameters.
 * @return the reader together with the source column names.
 * @throws IllegalArgumentException if the source format is not handled.
 */
public static Tuple2<FormatReader, String[]> initFormatReader(TableSchema dataSchema, Params params) {
    FormatReader reader;
    String[] sourceColNames;

    FormatType sourceFormat = params.get(FormatTransParams.FROM_FORMAT);

    switch (sourceFormat) {
        case KV: {
            String colName = params.get(FromKvParams.KV_COL);
            int colIndex = TableUtil.findColIndexWithAssertAndHint(dataSchema.getFieldNames(), colName);
            reader = new KvReader(
                colIndex,
                params.get(FromKvParams.KV_COL_DELIMITER),
                params.get(FromKvParams.KV_VAL_DELIMITER)
            );
            sourceColNames = null;
            break;
        }
        case CSV: {
            String colName = params.get(FromCsvParams.CSV_COL);
            int colIndex = TableUtil.findColIndexWithAssertAndHint(dataSchema.getFieldNames(), colName);
            TableSchema csvSchema = CsvUtil.schemaStr2Schema(params.get(FromCsvParams.SCHEMA_STR));
            reader = new CsvReader(
                colIndex,
                csvSchema,
                params.get(FromCsvParams.CSV_FIELD_DELIMITER),
                params.get(FromCsvParams.QUOTE_CHAR)
            );
            sourceColNames = csvSchema.getFieldNames();
            break;
        }
        case VECTOR: {
            String colName = params.get(FromVectorParams.VECTOR_COL);
            int colIndex = TableUtil.findColIndexWithAssertAndHint(dataSchema.getFieldNames(), colName);
            // The schema string is optional for vectors.
            if (params.contains(HasSchemaStr.SCHEMA_STR)) {
                reader = new VectorReader(
                    colIndex,
                    CsvUtil.schemaStr2Schema(params.get(HasSchemaStr.SCHEMA_STR))
                );
            } else {
                reader = new VectorReader(colIndex, null);
            }
            sourceColNames = null;
            break;
        }
        case JSON: {
            String colName = params.get(FromJsonParams.JSON_COL);
            int colIndex = TableUtil.findColIndexWithAssertAndHint(dataSchema.getFieldNames(), colName);
            reader = new JsonReader(colIndex);
            sourceColNames = null;
            break;
        }
        case COLUMNS: {
            sourceColNames = params.get(FromColumnsParams.SELECTED_COLS);
            // Fall back to every column of the data schema when none are selected.
            if (null == sourceColNames) {
                sourceColNames = dataSchema.getFieldNames();
            }
            int[] indices = TableUtil.findColIndicesWithAssertAndHint(dataSchema.getFieldNames(), sourceColNames);
            reader = new ColumnsReader(indices, sourceColNames);
            break;
        }
        default:
            throw new IllegalArgumentException("Can not translate this type : " + sourceFormat);
    }

    return new Tuple2<>(reader, sourceColNames);
}
Example 17
Source File: StringToColumnsMappers.java From Alink with Apache License 2.0 | 4 votes |
/**
 * Creates a key-value parser using the column and value delimiters configured in
 * {@code params}.
 *
 * @param fieldNames the names of the fields to parse.
 * @param fieldTypes the types of the fields to parse.
 * @param params     parameters holding the delimiters.
 * @return the key-value parser.
 */
@Override
protected StringParsers.StringParser getParser(String[] fieldNames, TypeInformation[] fieldTypes,
                                               Params params) {
    return new StringParsers.KvParser(
        fieldNames,
        fieldTypes,
        params.get(KvToColumnsParams.COL_DELIMITER),
        params.get(KvToColumnsParams.VAL_DELIMITER)
    );
}
Example 18
Source File: SOSImpl.java From Alink with Apache License 2.0 | 4 votes |
/**
 * Constructor.
 *
 * <p>Stores the perplexity configured in {@code params}.
 *
 * @param params the parameters.
 */
public SOSImpl(Params params) {
    this.perplexity = params.get(SosParams.PERPLEXITY);
}
Example 19
Source File: TreeModelDataConverter.java From Alink with Apache License 2.0 | 4 votes |
@Override protected TreeModelDataConverter deserializeModel(Params meta, Iterable<String> iterable, Iterable<Object> distinctLabels) { // parseDense partition of categorical Partition stringIndexerModelPartition = meta.get( STRING_INDEXER_MODEL_PARTITION ); List<String> data = new ArrayList<>(); iterable.forEach(data::add); if (stringIndexerModelPartition.getF1() != stringIndexerModelPartition.getF0()) { stringIndexerModelSerialized = new ArrayList<>(); for (int i = stringIndexerModelPartition.getF0(); i < stringIndexerModelPartition.getF1(); ++i) { Object[] deserialized = JsonConverter.fromJson(data.get(i), Object[].class); stringIndexerModelSerialized.add( Row.of( ((Integer)deserialized[0]).longValue(), deserialized[1], deserialized[2] ) ); } } else { stringIndexerModelSerialized = null; } // toString partition of trees Partitions treesPartition = meta.get( TREE_PARTITIONS ); roots = treesPartition.getPartitions().stream() .map(x -> deserializeTree(data.subList(x.getF0(), x.getF1()))) .toArray(Node[]::new); this.meta = meta; List<Object> labelList = new ArrayList<>(); distinctLabels.forEach(labelList::add); this.labels = labelList.toArray(); return this; }
Example 20
Source File: DBSinkStreamOp.java From Alink with Apache License 2.0 | 4 votes |
public DBSinkStreamOp(BaseDB db, Params parameter) { super(AnnotationUtils.annotatedName(db.getClass()), db.getParams().clone().merge(parameter)); this.db = db; this.tableName = parameter.get(AnnotationUtils.tableAliasParamKey(db.getClass())); }