Java Code Examples for org.apache.flink.ml.api.misc.param.Params#contains()
The following examples show how to use
org.apache.flink.ml.api.misc.param.Params#contains().
You can vote up the ones you like or vote down the ones you don't like,
and go to the original project or source file by following the links above each example. You may check out the related API usage on the sidebar.
Example 1
Source File: MySqlDB.java From Alink with Apache License 2.0 | 6 votes |
/**
 * Builds the stream table for {@code tableName}.
 *
 * <p>When {@code params} carries no schema string the request is handed to the superclass.
 * Otherwise the schema string is parsed and the table is read through a JDBC input format
 * whose row type is derived from that schema.
 *
 * @param tableName name of the table to read; it is appended verbatim to the query text
 * @param params    source parameters, optionally containing {@code MySqlSourceParams.SCHEMA_STR}
 * @param sessionId id used to look up the ML environment
 * @return the table backed by the stream source
 * @throws Exception declared by the overridden method
 */
@Override
public Table getStreamTable(String tableName, Params params, Long sessionId) throws Exception {
    // No explicit schema: let the parent implementation handle it.
    if (!params.contains(MySqlSourceParams.SCHEMA_STR)) {
        return super.getStreamTable(tableName, params, sessionId);
    }

    TableSchema tableSchema = CsvUtil.schemaStr2Schema(params.get(MySqlSourceParams.SCHEMA_STR));
    String selectAll = "select * from " + tableName;

    JDBCInputFormat jdbcInput = JDBCInputFormat.buildJDBCInputFormat()
        .setUsername(getUserName())
        .setPassword(getPassword())
        .setDrivername(getDriverName())
        .setDBUrl(getDbUrl())
        .setQuery(selectAll)
        .setRowTypeInfo(new RowTypeInfo(tableSchema.getFieldTypes(), tableSchema.getFieldNames()))
        .finish();

    return DataStreamConversionUtil.toTable(
        sessionId,
        MLEnvironmentFactory.get(sessionId).getStreamExecutionEnvironment().createInput(jdbcInput),
        tableSchema.getFieldNames(),
        tableSchema.getFieldTypes());
}
Example 2
Source File: MySqlDB.java From Alink with Apache License 2.0 | 6 votes |
/**
 * Builds the batch table for {@code tableName}.
 *
 * <p>When {@code params} carries no schema string the request is handed to the superclass.
 * Otherwise the schema string is parsed and the table is read through a JDBC input format
 * whose row type is derived from that schema.
 *
 * @param tableName name of the table to read; it is appended verbatim to the query text
 * @param params    source parameters, optionally containing {@code MySqlSourceParams.SCHEMA_STR}
 * @param sessionId id used to look up the ML environment
 * @return the table backed by the batch source
 * @throws Exception declared by the overridden method
 */
@Override
public Table getBatchTable(String tableName, Params params, Long sessionId) throws Exception {
    // No explicit schema: let the parent implementation handle it.
    if (!params.contains(MySqlSourceParams.SCHEMA_STR)) {
        return super.getBatchTable(tableName, params, sessionId);
    }

    TableSchema tableSchema = CsvUtil.schemaStr2Schema(params.get(MySqlSourceParams.SCHEMA_STR));
    String selectAll = "select * from " + tableName;

    JDBCInputFormat jdbcInput = JDBCInputFormat.buildJDBCInputFormat()
        .setUsername(getUserName())
        .setPassword(getPassword())
        .setDrivername(getDriverName())
        .setDBUrl(getDbUrl())
        .setQuery(selectAll)
        .setRowTypeInfo(new RowTypeInfo(tableSchema.getFieldTypes(), tableSchema.getFieldNames()))
        .finish();

    return DataSetConversionUtil.toTable(
        sessionId,
        MLEnvironmentFactory.get(sessionId).getExecutionEnvironment().createInput(jdbcInput),
        tableSchema.getFieldNames(),
        tableSchema.getFieldTypes());
}
Example 3
Source File: LinearModelDataConverter.java From Alink with Apache License 2.0 | 6 votes |
/**
 * Rebuilds a {@link LinearModelData} from its serialized parts.
 *
 * <p>Label values stored in {@code meta} are restored first; if {@code distinctLabels} is
 * non-null, its elements then replace them. The first element of {@code data} is parsed as
 * JSON into the model body.
 *
 * @param meta           The model meta data.
 * @param data           The model concrete data; only its first element is read.
 * @param distinctLabels All the label values in the data set, may be {@code null}.
 * @return The deserialized model data.
 */
@Override
public LinearModelData deserializeModel(Params meta, Iterable<String> data, Iterable<Object> distinctLabels) {
    LinearModelData result = new LinearModelData();

    if (meta.contains(ModelParamName.LABEL_VALUES)) {
        result.labelValues = FeatureLabelUtil.recoverLabelType(
            meta.get(ModelParamName.LABEL_VALUES), this.labelType);
    }
    setMetaInfo(meta, result);

    // Explicitly supplied labels take precedence over the ones recovered from the meta.
    if (distinctLabels != null) {
        List<Object> collected = new ArrayList<>();
        for (Object label : distinctLabels) {
            collected.add(label);
        }
        result.labelValues = collected.toArray();
    }

    String modelJson = data.iterator().next();
    setModelData(JsonConverter.fromJson(modelJson, ModelData.class), result);
    return result;
}
Example 4
Source File: CorrelationDataConverter.java From Alink with Apache License 2.0 | 6 votes |
/**
 * Deserialize the correlation result from the model meta and the serialized rows.
 *
 * <p>Each string in {@code data} is parsed as a dense vector and copied into one row of a
 * square matrix whose size is the size of the first vector. If {@code data} is empty the
 * matrix passed to the result is {@code null}.
 *
 * @param meta model meta; the selected column names are read from it when present
 * @param data serialized vectors, one per matrix row
 * @return the correlation result built from the matrix and the column names
 */
@Override
public CorrelationResult deserializeModel(Params meta, Iterable<String> data) {
    String[] colNames = meta.contains(CorrelationParams.SELECTED_COLS)
        ? meta.get(CorrelationParams.SELECTED_COLS)
        : null;

    DenseMatrix matrix = null;
    int row = 0;
    for (String serializedRow : data) {
        DenseVector rowVec = (DenseVector) VectorUtil.getVector(serializedRow);
        int dim = rowVec.size();
        // The matrix is allocated lazily from the first row's dimension.
        if (matrix == null) {
            matrix = new DenseMatrix(dim, dim);
        }
        for (int col = 0; col < dim; col++) {
            matrix.set(row, col, rowVec.get(col));
        }
        row++;
    }

    return new CorrelationResult(matrix, colNames);
}
Example 5
Source File: OneVsRestModelMapper.java From Alink with Apache License 2.0 | 6 votes |
/**
 * Creates the mapper and prepares its output columns.
 *
 * <p>The prediction column takes the type of the last column of the model schema. A detail
 * column of type STRING is added only when {@code PREDICTION_DETAIL_COL} is present in
 * {@code params}. A copy of {@code params} is kept in {@code binClsPredParams} with no
 * reserved columns and with fixed prediction/detail column names.
 *
 * @param modelSchema schema of the model table
 * @param dataSchema  schema of the input data
 * @param params      prediction parameters
 */
public OneVsRestModelMapper(TableSchema modelSchema, TableSchema dataSchema, Params params) {
    super(modelSchema, dataSchema, params);

    String predictionCol = params.get(OneVsRestPredictParams.PREDICTION_COL);
    String[] reservedCols = params.get(OneVsRestPredictParams.RESERVED_COLS);
    this.predDetail = params.contains(OneVsRestPredictParams.PREDICTION_DETAIL_COL);

    // The label type is the type of the last model column.
    int lastModelCol = modelSchema.getFieldNames().length - 1;
    TypeInformation labelType = modelSchema.getFieldTypes()[lastModelCol];

    if (!predDetail) {
        outputColsHelper = new OutputColsHelper(dataSchema, predictionCol, labelType, reservedCols);
    } else {
        String detailCol = params.get(OneVsRestPredictParams.PREDICTION_DETAIL_COL);
        outputColsHelper = new OutputColsHelper(
            dataSchema,
            new String[]{predictionCol, detailCol},
            new TypeInformation[]{labelType, Types.STRING},
            reservedCols);
    }

    // Independent parameter copy with fixed output column names.
    this.binClsPredParams = params.clone();
    this.binClsPredParams.set(OneVsRestPredictParams.RESERVED_COLS, new String[0]);
    this.binClsPredParams.set(OneVsRestPredictParams.PREDICTION_COL, "pred_result");
    this.binClsPredParams.set(OneVsRestPredictParams.PREDICTION_DETAIL_COL, "pred_detail");
}
Example 6
Source File: KMeansModelMapper.java From Alink with Apache License 2.0 | 6 votes |
/**
 * Creates the mapper and assembles its output columns.
 *
 * <p>The prediction column (LONG) is always produced. A detail column (STRING) and a distance
 * column (DOUBLE) are appended, in that order, only when the corresponding parameter is
 * present in {@code params}.
 *
 * @param modelSchema schema of the model table
 * @param dataSchema  schema of the input data
 * @param params      prediction parameters
 */
public KMeansModelMapper(TableSchema modelSchema, TableSchema dataSchema, Params params) {
    super(modelSchema, dataSchema, params);

    String[] keptCols = this.params.get(KMeansPredictParams.RESERVED_COLS);
    String predictionCol = this.params.get(KMeansPredictParams.PREDICTION_COL);

    isPredDetail = params.contains(KMeansPredictParams.PREDICTION_DETAIL_COL);
    isPredDistance = params.contains(KMeansPredictParams.PREDICTION_DISTANCE_COL);

    List<String> names = new ArrayList<>();
    List<TypeInformation> types = new ArrayList<>();

    names.add(predictionCol);
    types.add(Types.LONG);

    if (isPredDetail) {
        names.add(params.get(KMeansPredictParams.PREDICTION_DETAIL_COL));
        types.add(Types.STRING);
    }

    if (isPredDistance) {
        names.add(params.get(KMeansPredictParams.PREDICTION_DISTANCE_COL));
        types.add(Types.DOUBLE);
    }

    this.outputColsHelper = new OutputColsHelper(
        dataSchema,
        names.toArray(new String[0]),
        types.toArray(new TypeInformation[0]),
        keptCols);
}
Example 7
Source File: Preprocessing.java From Alink with Apache License 2.0 | 5 votes |
/**
 * Casts the label column of {@code input} to the type the trainer expects.
 *
 * <p>For classification ({@code isRegression == false}) the label column is replaced by the
 * result of {@code findIndexOfLabel} and its column type becomes INT; all other columns keep
 * their type. For regression the label column, if configured in {@code params}, is cast to
 * DOUBLE; without a configured label column the input is returned unchanged.
 *
 * @param input        the operator whose label column is cast
 * @param params       parameters that may contain {@code HasLabelCol.LABEL_COL}
 * @param labels       label values handed to {@code findIndexOfLabel}
 * @param isRegression whether the task is a regression
 * @return the operator with the cast label column
 */
public static BatchOperator<?> castLabel(
    BatchOperator<?> input, Params params, DataSet<Object[]> labels, boolean isRegression) {

    if (isRegression) {
        if (!params.contains(HasLabelCol.LABEL_COL)) {
            return input;
        }
        return new NumericalTypeCastBatchOp()
            .setMLEnvironmentId(input.getMLEnvironmentId())
            .setSelectedCols(params.get(HasLabelCol.LABEL_COL))
            .setTargetType("DOUBLE")
            .linkFrom(input);
    }

    final String[] inputColNames = input.getColNames();
    final TypeInformation<?>[] types = input.getColTypes();
    final String labelColName = params.get(HasLabelCol.LABEL_COL);

    // Resolve the label position once instead of once per column inside the stream below.
    final int labelIndex = TableUtil.findColIndex(inputColNames, labelColName);

    return new DataSetWrapperBatchOp(
        findIndexOfLabel(input.getDataSet(), labels, labelIndex),
        inputColNames,
        IntStream.range(0, types.length)
            .mapToObj(x -> x == labelIndex ? Types.INT : types[x])
            .toArray(TypeInformation[]::new)
    ).setMLEnvironmentId(input.getMLEnvironmentId());
}
Example 8
Source File: Preprocessing.java From Alink with Apache License 2.0 | 5 votes |
public static BatchOperator<?> generateStringIndexerModel(BatchOperator<?> input, Params params) { String[] categoricalColNames = null; if (params.contains(HasCategoricalCols.CATEGORICAL_COLS)) { categoricalColNames = params.get(HasCategoricalCols.CATEGORICAL_COLS); } BatchOperator<?> stringIndexerModel; if (categoricalColNames == null || categoricalColNames.length == 0) { MultiStringIndexerModelDataConverter emptyModel = new MultiStringIndexerModelDataConverter(); stringIndexerModel = new DataSetWrapperBatchOp( MLEnvironmentFactory .get(input.getMLEnvironmentId()) .getExecutionEnvironment() .fromElements(1) .mapPartition(new MapPartitionFunction<Integer, Row>() { @Override public void mapPartition(Iterable<Integer> values, Collector<Row> out) throws Exception { //pass } }), emptyModel.getModelSchema().getFieldNames(), emptyModel.getModelSchema().getFieldTypes() ).setMLEnvironmentId(input.getMLEnvironmentId()); } else { stringIndexerModel = new MultiStringIndexerTrainBatchOp() .setMLEnvironmentId(input.getMLEnvironmentId()) .setSelectedCols(categoricalColNames) .linkFrom(input); } return stringIndexerModel; }
Example 9
Source File: TreeUtil.java From Alink with Apache License 2.0 | 5 votes |
/**
 * Collects the names of the columns used for training.
 *
 * <p>The result lists the feature columns first, then the label column if
 * {@code HasLabelCol.LABEL_COL} is present in {@code params}, then the weight column if its
 * value is non-null.
 *
 * @param params the training parameters
 * @return the training column names in the order described above
 */
public static String[] trainColNames(Params params) {
    ArrayList<String> names = new ArrayList<>();
    for (String featureCol : params.get(HasFeatureCols.FEATURE_COLS)) {
        names.add(featureCol);
    }

    if (params.contains(HasLabelCol.LABEL_COL)) {
        names.add(params.get(HasLabelCol.LABEL_COL));
    }

    String weightCol = params.get(HasWeightColDefaultAsNull.WEIGHT_COL);
    if (weightCol != null) {
        names.add(weightCol);
    }

    return names.toArray(new String[0]);
}
Example 10
Source File: BaseSourceBatchOp.java From Alink with Apache License 2.0 | 5 votes |
/**
 * Creates the batch source operator described by {@code params}.
 *
 * <p>{@code params} must contain an io type equal to {@code IO_TYPE} and an io name. If
 * {@code BaseDB.isDB(params)} holds, a {@code DBSourceBatchOp} is returned; otherwise the
 * operator is created by name through {@code AnnotationUtils.createOp}.
 *
 * @param params the operator parameters
 * @return the source operator
 * @throws RuntimeException if the io type is missing or different from {@code IO_TYPE}, or
 *                          the io name is missing
 * @throws Exception        declared for the operator creation calls
 */
public static BaseSourceBatchOp of(Params params) throws Exception {
    boolean describesThisIoType = params.contains(HasIoType.IO_TYPE)
        && params.get(HasIoType.IO_TYPE).equals(IO_TYPE)
        && params.contains(HasIoName.IO_NAME);
    if (!describesThisIoType) {
        throw new RuntimeException("Parameter Error.");
    }

    if (BaseDB.isDB(params)) {
        return new DBSourceBatchOp(BaseDB.of(params), params);
    }

    // The presence of IO_NAME was already verified above, so no second check is needed.
    String name = params.get(HasIoName.IO_NAME);
    return (BaseSourceBatchOp) AnnotationUtils.createOp(name, IO_TYPE, params);
}
Example 11
Source File: BaseSinkBatchOp.java From Alink with Apache License 2.0 | 5 votes |
/**
 * Creates the batch sink operator described by {@code params}.
 *
 * <p>{@code params} must contain an io type equal to {@code IO_TYPE} and an io name. If
 * {@code BaseDB.isDB(params)} holds, a {@code DBSinkBatchOp} is returned; otherwise the
 * operator is created by name through {@code AnnotationUtils.createOp}.
 *
 * @param params the operator parameters
 * @return the sink operator
 * @throws RuntimeException if the io type is missing or different from {@code IO_TYPE}, or
 *                          the io name is missing
 * @throws Exception        declared for the operator creation calls
 */
public static BaseSinkBatchOp of(Params params) throws Exception {
    boolean describesThisIoType = params.contains(HasIoType.IO_TYPE)
        && params.get(HasIoType.IO_TYPE).equals(IO_TYPE)
        && params.contains(HasIoName.IO_NAME);
    if (!describesThisIoType) {
        throw new RuntimeException("Parameter Error.");
    }

    if (BaseDB.isDB(params)) {
        return new DBSinkBatchOp(BaseDB.of(params), params);
    }

    // The presence of IO_NAME was already verified above, so no second check is needed.
    String name = params.get(HasIoName.IO_NAME);
    return (BaseSinkBatchOp) AnnotationUtils.createOp(name, IO_TYPE, params);
}
Example 12
Source File: BaseLinearModelTrainBatchOp.java From Alink with Apache License 2.0 | 5 votes |
/** * optimize linear problem * * @param params parameters need by optimizer. * @param vectorSize vector size. * @param trainData train Data. * @param modelType linear model type. * @param session machine learning environment * @return coefficient of linear problem. */ public static DataSet<Tuple2<DenseVector, double[]>> optimize(Params params, DataSet<Integer> vectorSize, DataSet<Tuple3<Double, Double, Vector>> trainData, LinearModelType modelType, MLEnvironment session) { boolean hasInterceptItem = params.get(LinearTrainParams.WITH_INTERCEPT); String[] featureColNames = params.get(LinearTrainParams.FEATURE_COLS); String vectorColName = params.get(LinearTrainParams.VECTOR_COL); if ("".equals(vectorColName)) { vectorColName = null; } if (org.apache.commons.lang3.ArrayUtils.isEmpty(featureColNames)) { featureColNames = null; } DataSet<Integer> coefficientDim; if (vectorColName != null && vectorColName.length() != 0) { coefficientDim = session.getExecutionEnvironment().fromElements(0) .map(new DimTrans(hasInterceptItem, modelType)) .withBroadcastSet(vectorSize, VECTOR_SIZE); } else { coefficientDim = session.getExecutionEnvironment().fromElements(featureColNames.length + (hasInterceptItem ? 1 : 0) + (modelType.equals(LinearModelType.AFT) ? 1 : 0)); } // Loss object function DataSet<OptimObjFunc> objFunc = session.getExecutionEnvironment() .fromElements(getObjFunction(modelType, params)); if (params.contains(LinearTrainParams.OPTIM_METHOD)) { LinearTrainParams.OptimMethod method = params.get(LinearTrainParams.OPTIM_METHOD); return OptimizerFactory.create(objFunc, trainData, coefficientDim, params, method).optimize(); } else if (params.get(HasL1.L_1) > 0) { return new Owlqn(objFunc, trainData, coefficientDim, params).optimize(); } else { return new Lbfgs(objFunc, trainData, coefficientDim, params).optimize(); } }
Example 13
Source File: LinearModelData.java From Alink with Apache License 2.0 | 5 votes |
/**
 * Creates the model data from its parts.
 *
 * <p>Label values are recovered from {@code meta} when it contains them, using
 * {@code labelType}; the remaining meta information is applied through
 * {@link #setMetaInfo(Params)}.
 *
 * @param labelType    label type.
 * @param meta         meta information of the model.
 * @param featureNames the feature column names.
 * @param coefVector   the coefficient vector.
 */
public LinearModelData(TypeInformation labelType, Params meta, String[] featureNames,
                       DenseVector coefVector) {
    this.featureNames = featureNames;
    this.coefVector = coefVector;
    // Must be assigned before the label values are recovered below.
    this.labelType = labelType;

    if (meta.contains(ModelParamName.LABEL_VALUES)) {
        this.labelValues = FeatureLabelUtil.recoverLabelType(
            meta.get(ModelParamName.LABEL_VALUES), this.labelType);
    }

    setMetaInfo(meta);
}
Example 14
Source File: LinearModelData.java From Alink with Apache License 2.0 | 5 votes |
/**
 * Copies the meta information of a model into this object.
 *
 * <p>The model name is always read. Every other field is read only when its parameter is
 * present in {@code meta}; otherwise it falls back to {@code null} (model type, vector
 * column), {@code true} (intercept flag) or {@code 0} (vector size).
 *
 * @param meta the model meta information
 */
public void setMetaInfo(Params meta) {
    this.modelName = meta.get(ModelParamName.MODEL_NAME);

    if (meta.contains(ModelParamName.LINEAR_MODEL_TYPE)) {
        this.linearModelType = meta.get(ModelParamName.LINEAR_MODEL_TYPE);
    } else {
        this.linearModelType = null;
    }

    // Defaults: an intercept is assumed and the vector size is 0 when nothing is stored.
    this.hasInterceptItem = meta.contains(ModelParamName.HAS_INTERCEPT_ITEM)
        ? meta.get(ModelParamName.HAS_INTERCEPT_ITEM)
        : true;
    this.vectorSize = meta.contains(ModelParamName.VECTOR_SIZE)
        ? meta.get(ModelParamName.VECTOR_SIZE)
        : 0;

    if (meta.contains(HasVectorCol.VECTOR_COL)) {
        this.vectorColName = meta.get(HasVectorCol.VECTOR_COL);
    } else {
        this.vectorColName = null;
    }
}
Example 15
Source File: LinearModelData.java From Alink with Apache License 2.0 | 5 votes |
/**
 * Restores the label type and the label values from {@code meta}.
 *
 * <p>The label type is always read from {@code LABEL_TYPE_NAME} and stored in
 * {@code labelType}. The label values are returned when {@code meta} contains them;
 * otherwise an empty list is returned.
 *
 * @param meta the model meta information
 * @return the recovered labels, possibly empty
 */
private List <Object> recoverLabelsFromOldFormatModel(Params meta) {
    this.labelType = FlinkTypeConverter.getFlinkType(meta.get(ModelParamName.LABEL_TYPE_NAME));

    if (!meta.contains(ModelParamName.LABEL_VALUES)) {
        return new ArrayList<>();
    }

    Object[] recovered = FeatureLabelUtil.recoverLabelType(
        meta.get(ModelParamName.LABEL_VALUES), labelType);
    return Arrays.asList(recovered);
}
Example 16
Source File: BaseSourceStreamOp.java From Alink with Apache License 2.0 | 5 votes |
/**
 * Creates the stream source operator described by {@code params}.
 *
 * <p>{@code params} must contain an io type equal to {@code IO_TYPE} and an io name. If
 * {@code BaseDB.isDB(params)} holds, a {@code DBSourceStreamOp} is returned; otherwise the
 * operator is created by name through {@code AnnotationUtils.createOp}.
 *
 * @param params the operator parameters
 * @return the source operator
 * @throws RuntimeException if the io type is missing or different from {@code IO_TYPE}, or
 *                          the io name is missing
 * @throws Exception        declared for the operator creation calls
 */
public static BaseSourceStreamOp of(Params params) throws Exception {
    boolean describesThisIoType = params.contains(HasIoType.IO_TYPE)
        && params.get(HasIoType.IO_TYPE).equals(IO_TYPE)
        && params.contains(HasIoName.IO_NAME);
    if (!describesThisIoType) {
        throw new RuntimeException("Parameter Error.");
    }

    if (BaseDB.isDB(params)) {
        return new DBSourceStreamOp(BaseDB.of(params), params);
    }

    // The presence of IO_NAME was already verified above, so no second check is needed.
    String name = params.get(HasIoName.IO_NAME);
    return (BaseSourceStreamOp) AnnotationUtils.createOp(name, IO_TYPE, params);
}
Example 17
Source File: LinearModelDataConverter.java From Alink with Apache License 2.0 | 5 votes |
/**
 * Set the meta information into the linear model data.
 *
 * <p>The model name is always read. Every other field is read only when its parameter is
 * present in {@code meta}; otherwise it falls back to {@code null} (model type, vector
 * column, label column), {@code true} (intercept flag) or {@code 0} (vector size).
 *
 * @param meta the model meta information
 * @param data the model data to fill
 */
private void setMetaInfo(Params meta, LinearModelData data) {
    data.modelName = meta.get(ModelParamName.MODEL_NAME);

    if (meta.contains(ModelParamName.LINEAR_MODEL_TYPE)) {
        data.linearModelType = meta.get(ModelParamName.LINEAR_MODEL_TYPE);
    } else {
        data.linearModelType = null;
    }

    // Defaults: an intercept is assumed and the vector size is 0 when nothing is stored.
    data.hasInterceptItem = meta.contains(ModelParamName.HAS_INTERCEPT_ITEM)
        ? meta.get(ModelParamName.HAS_INTERCEPT_ITEM)
        : true;
    data.vectorSize = meta.contains(ModelParamName.VECTOR_SIZE)
        ? meta.get(ModelParamName.VECTOR_SIZE)
        : 0;

    if (meta.contains(HasVectorCol.VECTOR_COL)) {
        data.vectorColName = meta.get(HasVectorCol.VECTOR_COL);
    } else {
        data.vectorColName = null;
    }

    if (meta.contains(HasLabelCol.LABEL_COL)) {
        data.labelName = meta.get(HasLabelCol.LABEL_COL);
    } else {
        data.labelName = null;
    }
}
Example 18
Source File: BaseSinkStreamOp.java From Alink with Apache License 2.0 | 5 votes |
/**
 * Creates the stream sink operator described by {@code params}.
 *
 * <p>{@code params} must contain an io type equal to {@code IO_TYPE} and an io name. If
 * {@code BaseDB.isDB(params)} holds, a {@code DBSinkStreamOp} is returned; otherwise the
 * operator is created by name through {@code AnnotationUtils.createOp}.
 *
 * @param params the operator parameters
 * @return the sink operator
 * @throws RuntimeException if the io type is missing or different from {@code IO_TYPE}, or
 *                          the io name is missing
 * @throws Exception        declared for the operator creation calls
 */
public static BaseSinkStreamOp of(Params params) throws Exception {
    boolean describesThisIoType = params.contains(HasIoType.IO_TYPE)
        && params.get(HasIoType.IO_TYPE).equals(IO_TYPE)
        && params.contains(HasIoName.IO_NAME);
    if (!describesThisIoType) {
        throw new RuntimeException("Parameter Error.");
    }

    if (BaseDB.isDB(params)) {
        return new DBSinkStreamOp(BaseDB.of(params), params);
    }

    // The presence of IO_NAME was already verified above, so no second check is needed.
    String name = params.get(HasIoName.IO_NAME);
    return (BaseSinkStreamOp) AnnotationUtils.createOp(name, IO_TYPE, params);
}
Example 19
Source File: FormatTransMapper.java From Alink with Apache License 2.0 | 4 votes |
/**
 * Creates the reader for the source format configured in {@code params}.
 *
 * <p>The second tuple field holds the source column names: the CSV schema's field names for
 * CSV, the selected columns (or all data columns when none are selected) for COLUMNS, and
 * {@code null} for KV, VECTOR and JSON.
 *
 * @param dataSchema schema of the input data
 * @param params     format transformation parameters
 * @return the reader and the source column names
 * @throws IllegalArgumentException if the source format is not one of the handled formats
 */
public static Tuple2<FormatReader, String[]> initFormatReader(TableSchema dataSchema, Params params) {
    FormatReader reader;
    String[] sourceColNames;

    FormatType sourceFormat = params.get(FormatTransParams.FROM_FORMAT);
    switch (sourceFormat) {
        case KV: {
            String colName = params.get(FromKvParams.KV_COL);
            int colIndex = TableUtil.findColIndexWithAssertAndHint(dataSchema.getFieldNames(), colName);
            reader = new KvReader(
                colIndex,
                params.get(FromKvParams.KV_COL_DELIMITER),
                params.get(FromKvParams.KV_VAL_DELIMITER)
            );
            sourceColNames = null;
            break;
        }
        case CSV: {
            String colName = params.get(FromCsvParams.CSV_COL);
            int colIndex = TableUtil.findColIndexWithAssertAndHint(dataSchema.getFieldNames(), colName);
            TableSchema csvSchema = CsvUtil.schemaStr2Schema(params.get(FromCsvParams.SCHEMA_STR));
            reader = new CsvReader(
                colIndex,
                csvSchema,
                params.get(FromCsvParams.CSV_FIELD_DELIMITER),
                params.get(FromCsvParams.QUOTE_CHAR)
            );
            sourceColNames = csvSchema.getFieldNames();
            break;
        }
        case VECTOR: {
            String colName = params.get(FromVectorParams.VECTOR_COL);
            int colIndex = TableUtil.findColIndexWithAssertAndHint(dataSchema.getFieldNames(), colName);
            // The schema is optional for vectors; the reader receives null when it is absent.
            if (!params.contains(HasSchemaStr.SCHEMA_STR)) {
                reader = new VectorReader(colIndex, null);
            } else {
                reader = new VectorReader(
                    colIndex,
                    CsvUtil.schemaStr2Schema(params.get(HasSchemaStr.SCHEMA_STR))
                );
            }
            sourceColNames = null;
            break;
        }
        case JSON: {
            String colName = params.get(FromJsonParams.JSON_COL);
            int colIndex = TableUtil.findColIndexWithAssertAndHint(dataSchema.getFieldNames(), colName);
            reader = new JsonReader(colIndex);
            sourceColNames = null;
            break;
        }
        case COLUMNS: {
            sourceColNames = params.get(FromColumnsParams.SELECTED_COLS);
            // Without an explicit selection every data column is used.
            if (sourceColNames == null) {
                sourceColNames = dataSchema.getFieldNames();
            }
            int[] indices = TableUtil.findColIndicesWithAssertAndHint(dataSchema.getFieldNames(), sourceColNames);
            reader = new ColumnsReader(indices, sourceColNames);
            break;
        }
        default:
            throw new IllegalArgumentException("Can not translate this type : " + sourceFormat);
    }

    return new Tuple2<>(reader, sourceColNames);
}
Example 20
Source File: StringIndexerModel.java From Alink with Apache License 2.0 | 4 votes |
/**
 * Creates the model and, when {@code params} contains {@code StringIndexer.MODEL_NAME},
 * registers this instance under that name.
 *
 * @param params the model parameters
 */
public StringIndexerModel(Params params) {
    super(StringIndexerModelMapper::new, params);

    // Unnamed models are not registered.
    if (!params.contains(StringIndexer.MODEL_NAME)) {
        return;
    }
    registerModel(params.get(StringIndexer.MODEL_NAME), this);
}