Java Code Examples for org.apache.flink.ml.api.misc.param.Params#set()
The following examples show how to use org.apache.flink.ml.api.misc.param.Params#set().
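Before the examples, here is a minimal, self-contained sketch of the pattern they all share: define a typed ParamInfo key, then write and read values through Params#set and Params#get. It is distilled from the tests in this section; the "numIterations" key is purely illustrative.

import org.apache.flink.ml.api.misc.param.ParamInfo;
import org.apache.flink.ml.api.misc.param.ParamInfoFactory;
import org.apache.flink.ml.api.misc.param.Params;

public class ParamsSetSketch {
    public static void main(String[] args) {
        // A typed parameter key with a default value.
        ParamInfo<Integer> numIterations = ParamInfoFactory
            .createParamInfo("numIterations", Integer.class)
            .setDescription("number of iterations")
            .setHasDefaultValue(10)
            .build();

        Params params = new Params();
        System.out.println(params.get(numIterations)); // 10: the default, since nothing has been set yet
        params.set(numIterations, 100);                // set() returns this Params, so calls can be chained
        System.out.println(params.get(numIterations)); // 100
    }
}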
Example 1
Source File: ParamsTest.java From flink with Apache License 2.0
@Test
public void getOptionalParam() {
    ParamInfo<String> key = ParamInfoFactory
        .createParamInfo("key", String.class)
        .setHasDefaultValue(null)
        .setDescription("")
        .build();

    Params params = new Params();
    Assert.assertNull(params.get(key));

    String val = "3";
    params.set(key, val);
    Assert.assertEquals(params.get(key), val);

    params.set(key, null);
    Assert.assertNull(params.get(key));
}
Example 2
Source File: BaseLinearModelTrainBatchOp.java From Alink with Apache License 2.0
@Override
public void mapPartition(Iterable<Object> rows, Collector<Params> metas) throws Exception {
    Object[] labels = null;
    if (!this.isRegProc) {
        labels = orderLabels(rows);
    }

    Params meta = new Params();
    meta.set(ModelParamName.MODEL_NAME, this.modelName);
    meta.set(ModelParamName.LINEAR_MODEL_TYPE, this.modelType);
    meta.set(ModelParamName.LABEL_VALUES, labels);
    meta.set(ModelParamName.HAS_INTERCEPT_ITEM, this.hasInterceptItem);
    meta.set(ModelParamName.VECTOR_COL_NAME, vectorColName);
    meta.set(LinearTrainParams.LABEL_COL, labelName);
    metas.collect(meta);
}
Example 3
Source File: ClassificationEvaluationUtilTest.java From Alink with Apache License 2.0
@Test
public void judgeEvaluationTypeTest() {
    Params params = new Params()
        .set(HasPredictionDetailCol.PREDICTION_DETAIL_COL, "detail");
    ClassificationEvaluationUtil.Type type = ClassificationEvaluationUtil.judgeEvaluationType(params);
    Assert.assertEquals(type, ClassificationEvaluationUtil.Type.PRED_DETAIL);

    params.set(HasPredictionCol.PREDICTION_COL, "pred");
    type = ClassificationEvaluationUtil.judgeEvaluationType(params);
    Assert.assertEquals(type, ClassificationEvaluationUtil.Type.PRED_DETAIL);

    params.remove(HasPredictionDetailCol.PREDICTION_DETAIL_COL);
    type = ClassificationEvaluationUtil.judgeEvaluationType(params);
    Assert.assertEquals(type, ClassificationEvaluationUtil.Type.PRED_RESULT);

    params.remove(HasPredictionCol.PREDICTION_COL);
    thrown.expect(RuntimeException.class);
    thrown.expectMessage("Error Input, must give either predictionCol or predictionDetailCol!");
    ClassificationEvaluationUtil.judgeEvaluationType(params);
}
Example 4
Source File: SelectMapperTest.java From Alink with Apache License 2.0
@Test
public void testValueConstructionFunctions() throws Exception {
    TableSchema dataSchema = TableSchema.builder().fields(
        new String[] {"id", "name"},
        new DataType[] {DataTypes.INT(), DataTypes.STRING()}).build();

    Params params = new Params();
    params.set(HasClause.CLAUSE, "ROW(1, 2, 3), ARRAY[1, 2, 3], MAP[1, 2, 3, 4]");

    SelectMapper selectMapper = new SelectMapper(dataSchema, params);
    selectMapper.open();
    Row output = selectMapper.map(Row.of(1, "'abc'"));
    try {
        assertEquals(output.getArity(), 3);
    } finally {
        selectMapper.close();
    }
}
Example 5
Source File: ParamsTest.java From flink with Apache License 2.0
@Test
public void getRequiredParam() {
    ParamInfo<String> labelWithRequired = ParamInfoFactory
        .createParamInfo("label", String.class)
        .setDescription("")
        .setRequired()
        .build();

    Params params = new Params();
    try {
        params.get(labelWithRequired);
        Assert.fail("failure");
    } catch (IllegalArgumentException ex) {
        Assert.assertTrue(ex.getMessage().startsWith("Missing non-optional parameter"));
    }

    params.set(labelWithRequired, null);
    Assert.assertNull(params.get(labelWithRequired));

    String val = "3";
    params.set(labelWithRequired, val);
    Assert.assertEquals(params.get(labelWithRequired), val);
}
Example 6
Source File: RegressionMetricsSummary.java From Alink with Apache License 2.0
@Override
public RegressionMetrics toMetrics() {
    Params params = new Params();
    params.set(RegressionMetrics.SST, ySum2Local - ySumLocal * ySumLocal / total);
    params.set(RegressionMetrics.SSE, sseLocal);
    params.set(RegressionMetrics.SSR,
        predSum2Local - 2 * ySumLocal * predSumLocal / total + ySumLocal * ySumLocal / total);
    params.set(RegressionMetrics.R2,
        1 - params.get(RegressionMetrics.SSE) / params.get(RegressionMetrics.SST));
    params.set(RegressionMetrics.R, Math.sqrt(params.get(RegressionMetrics.R2)));
    params.set(RegressionMetrics.MSE, params.get(RegressionMetrics.SSE) / total);
    params.set(RegressionMetrics.RMSE, Math.sqrt(params.get(RegressionMetrics.MSE)));
    params.set(RegressionMetrics.SAE, maeLocal);
    params.set(RegressionMetrics.MAE, params.get(RegressionMetrics.SAE) / total);
    params.set(RegressionMetrics.COUNT, (double) total);
    params.set(RegressionMetrics.MAPE, mapeLocal * 100 / total);
    params.set(RegressionMetrics.Y_MEAN, ySumLocal / total);
    params.set(RegressionMetrics.PREDICTION_MEAN, predSumLocal / total);
    params.set(RegressionMetrics.EXPLAINED_VARIANCE, params.get(RegressionMetrics.SSR) / total);
    return new RegressionMetrics(params);
}
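The set() calls above encode the standard regression identities; writing n for COUNT, they read directly off the code as:

\[
R^2 = 1 - \frac{\mathrm{SSE}}{\mathrm{SST}}, \quad
R = \sqrt{R^2}, \quad
\mathrm{MSE} = \frac{\mathrm{SSE}}{n}, \quad
\mathrm{RMSE} = \sqrt{\mathrm{MSE}}, \quad
\mathrm{MAE} = \frac{\mathrm{SAE}}{n}
\]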
Example 7
Source File: SelectMapperTest.java From Alink with Apache License 2.0
@Test
public void testCollectionFunctions() throws Exception {
    TableSchema dataSchema = TableSchema.builder().fields(
        new String[] {"id", "name"},
        new DataType[] {DataTypes.INT(), DataTypes.STRING()}).build();

    Params params = new Params();
    params.set(HasClause.CLAUSE,
        "CARDINALITY(ARRAY[1,2,3])"
            + ", ARRAY[1,2,3][2]"
            + ", ELEMENT(ARRAY[2])"
            + ", CARDINALITY(MAP[1, 2, 3, 4])"
            + ", MAP[1, 2, 3, 4][3]");

    SelectMapper selectMapper = new SelectMapper(dataSchema, params);
    selectMapper.open();
    Row expected = Row.of(3, 2, 2, 2, 4);
    Row output = selectMapper.map(Row.of(1, "'abc'"));
    try {
        assertEquals(expected, output);
    } finally {
        selectMapper.close();
    }
}
Example 8
Source File: MultiMetricsSummary.java From Alink with Apache License 2.0
/**
 * Calculate the detail info based on the confusion matrix.
 */
@Override
public MultiClassMetrics toMetrics() {
    Params params = new Params();
    ConfusionMatrix data = new ConfusionMatrix(matrix);

    params.set(MultiClassMetrics.PREDICT_LABEL_FREQUENCY, data.getPredictLabelFrequency());
    params.set(MultiClassMetrics.PREDICT_LABEL_PROPORTION, data.getPredictLabelProportion());

    for (ClassificationEvaluationUtil.Computations c : ClassificationEvaluationUtil.Computations.values()) {
        params.set(c.arrayParamInfo, ClassificationEvaluationUtil.getAllValues(c.computer, data));
    }

    setClassificationCommonParams(params, data, labels);
    setLoglossParams(params, logLoss, total);
    return new MultiClassMetrics(params);
}
Example 9
Source File: ParamsTest.java From flink with Apache License 2.0
@Test
public void testValidator() {
    Params params = new Params();
    ParamInfo<Integer> intParam = ParamInfoFactory
        .createParamInfo("a", Integer.class)
        .setValidator(i -> i > 0)
        .build();

    params.set(intParam, 1);

    thrown.expect(RuntimeException.class);
    thrown.expectMessage("Setting a as a invalid value:0");
    params.set(intParam, 0);
}
Example 10
Source File: BaseLinearModelTrainBatchOp.java From Alink with Apache License 2.0
/**
 * Transform train data to Tuple3 format.
 *
 * @param in train data in row format.
 * @param params train parameters.
 * @param labelValues label values.
 * @param isRegProc whether this is a regression task.
 * @return train data as a Tuple3 of (weight, label, vector).
 */
private DataSet<Tuple3<Double, Double, Vector>> transform(BatchOperator in,
                                                          Params params,
                                                          DataSet<Object> labelValues,
                                                          boolean isRegProc) {
    String[] featureColNames = params.get(LinearTrainParams.FEATURE_COLS);
    String labelName = params.get(LinearTrainParams.LABEL_COL);
    String weightColName = params.get(LinearTrainParams.WEIGHT_COL);
    String vectorColName = params.get(LinearTrainParams.VECTOR_COL);
    TableSchema dataSchema = in.getSchema();

    if (null == featureColNames && null == vectorColName) {
        featureColNames = TableUtil.getNumericCols(dataSchema, new String[] {labelName});
        params.set(LinearTrainParams.FEATURE_COLS, featureColNames);
    }

    int[] featureIndices = null;
    int labelIdx = TableUtil.findColIndexWithAssertAndHint(dataSchema.getFieldNames(), labelName);
    if (featureColNames != null) {
        featureIndices = new int[featureColNames.length];
        for (int i = 0; i < featureColNames.length; ++i) {
            int idx = TableUtil.findColIndexWithAssertAndHint(in.getColNames(), featureColNames[i]);
            featureIndices[i] = idx;
            TypeInformation type = in.getSchema().getFieldTypes()[idx];
            Preconditions.checkState(TableUtil.isNumber(type),
                "linear algorithm only support numerical data type. type is : " + type);
        }
    }

    int weightIdx = weightColName != null
        ? TableUtil.findColIndexWithAssertAndHint(in.getColNames(), weightColName)
        : -1;
    int vecIdx = vectorColName != null
        ? TableUtil.findColIndexWithAssertAndHint(in.getColNames(), vectorColName)
        : -1;

    return in.getDataSet()
        .map(new Transform(isRegProc, weightIdx, vecIdx, featureIndices, labelIdx))
        .withBroadcastSet(labelValues, LABEL_VALUES);
}
Example 11
Source File: FeedForwardTrainer.java From Alink with Apache License 2.0
/**
 * Train the network.
 *
 * @param data Training data, a dataset of tuples of (label, features).
 * @param optimizationParams Parameters for optimization.
 * @return The model weights.
 */
public DataSet<DenseVector> train(DataSet<Tuple2<Double, DenseVector>> data, Params optimizationParams) {
    final Topology topology = this.topology;
    final int inputSize = this.inputSize;
    final int outputSize = this.outputSize;
    final boolean onehotLabel = this.onehotLabel;

    ParamInfo<Integer> NUM_SEARCH_STEP = ParamInfoFactory
        .createParamInfo("numSearchStep", Integer.class)
        .setDescription("num search step")
        .setRequired()
        .build();

    DataSet<DenseVector> initCoef = initModel(data, this.topology);
    DataSet<Tuple3<Double, Double, Vector>> trainData = stack(data, blockSize, inputSize, outputSize, onehotLabel);

    optimizationParams.set(NUM_SEARCH_STEP, 3);
    final AnnObjFunc annObjFunc = new AnnObjFunc(topology, inputSize, outputSize, onehotLabel, optimizationParams);

    // We always use L-BFGS to train the network.
    Optimizer optimizer = new Lbfgs(
        data.getExecutionEnvironment().fromElements(annObjFunc),
        trainData,
        BatchOperator
            .getExecutionEnvironmentFromDataSets(data)
            .fromElements(inputSize),
        optimizationParams
    );
    optimizer.initCoefWith(initCoef);

    return optimizer.optimize().map(new MapFunction<Tuple2<DenseVector, double[]>, DenseVector>() {
        @Override
        public DenseVector map(Tuple2<DenseVector, double[]> value) throws Exception {
            return value.f0;
        }
    });
}
Example 12
Source File: SelectMapperTest.java From Alink with Apache License 2.0
@Test
public void testStringFunctions() throws Exception {
    TableSchema dataSchema = TableSchema.builder().fields(
        new String[] {"id", "name"},
        new DataType[] {DataTypes.INT(), DataTypes.STRING()}).build();

    Params params = new Params();
    params.set(HasClause.CLAUSE,
        "name || name, CHAR_LENGTH(name), CHARACTER_LENGTH(name), UPPER(name), LOWER(name), POSITION(name IN name),"
            + "TRIM('a' FROM name), REPEAT(name, 3)"
            + ", OVERLAY('This is an old string' PLACING ' new' FROM 10 FOR 5)"
            + ", SUBSTRING(name FROM 2)"
            + ", REPLACE('hello world', 'world', 'flink')"
            + ", INITCAP(name)"
            + ", FROM_BASE64('aGVsbG8gd29ybGQ=')"
            + ", TO_BASE64('hello world')"
            + ", LPAD('hi',4,'??')"
            + ", RPAD('hi',4,'??')"
            + ", REGEXP_REPLACE('foobar', 'oo|ar', '')"
            + ", REGEXP_EXTRACT('foothebar', 'foo(.*?)(bar)', 2)"
            + ", LTRIM(' This is a test String.')"
            + ", RTRIM('This is a test String. ')");

    SelectMapper selectMapper = new SelectMapper(dataSchema, params);
    selectMapper.open();
    Row expected = Row.of("'abc''abc'", 5, 5, "'ABC'", "'abc'", 1, "'abc'", "'abc''abc''abc'",
        "This is a new string", "abc'", "hello flink", "'Abc'", "hello world", "aGVsbG8gd29ybGQ=",
        "??hi", "hi??", "fb", "bar", "This is a test String.", "This is a test String.");
    Row output = selectMapper.map(Row.of(1, "'abc'"));
    assertEquals(expected.getArity(), output.getArity());
    try {
        assertEquals(expected, output);
    } finally {
        selectMapper.close();
    }
}
Example 13
Source File: ClusterEvaluationUtil.java From Alink with Apache License 2.0
@Override
public Params map(BaseMetricsSummary t) throws Exception {
    Params params = t.toMetrics().getParams();
    List<Tuple1<Double>> silhouetteCoefficient = getRuntimeContext().getBroadcastVariable(
        EvalClusterBatchOp.SILHOUETTE_COEFFICIENT);
    params.set(ClusterMetrics.SILHOUETTE_COEFFICIENT,
        silhouetteCoefficient.get(0).f0 / params.get(ClusterMetrics.COUNT));
    return params;
}
Example 14
Source File: LinearModelDataConverter.java From Alink with Apache License 2.0
private Params getMetaInfo(LinearModelData data) {
    Params meta = new Params();
    meta.set(ModelParamName.MODEL_NAME, data.modelName);
    meta.set(ModelParamName.HAS_INTERCEPT_ITEM, data.hasInterceptItem);
    meta.set(ModelParamName.LINEAR_MODEL_TYPE, data.linearModelType);
    if (data.vectorColName != null) {
        meta.set(HasVectorCol.VECTOR_COL, data.vectorColName);
        meta.set(ModelParamName.VECTOR_SIZE, data.vectorSize);
    }
    meta.set(HasLabelCol.LABEL_COL, data.labelName);
    return meta;
}
Example 15
Source File: ImputerModelDataConverter.java From Alink with Apache License 2.0
/**
 * Serialize the model to a Tuple3 of (Params, Iterable&lt;String&gt;, Iterable&lt;Row&gt;).
 *
 * @param modelData The model data to serialize.
 * @return The serialization result.
 */
@Override
public Tuple3<Params, Iterable<String>, Iterable<Row>> serializeModel(Tuple3<Strategy, TableSummary, String> modelData) {
    Strategy strategy = modelData.f0;
    TableSummary summary = modelData.f1;
    String fillValue = modelData.f2;
    double[] values = null;

    Params meta = new Params()
        .set(STRATEGY, strategy)
        .set(SELECTED_COLS, selectedColNames);

    switch (strategy) {
        case MIN:
            values = new double[selectedColNames.length];
            for (int i = 0; i < selectedColNames.length; i++) {
                values[i] = summary.min(selectedColNames[i]);
            }
            break;
        case MAX:
            values = new double[selectedColNames.length];
            for (int i = 0; i < selectedColNames.length; i++) {
                values[i] = summary.max(selectedColNames[i]);
            }
            break;
        case MEAN:
            values = new double[selectedColNames.length];
            for (int i = 0; i < selectedColNames.length; i++) {
                values[i] = summary.mean(selectedColNames[i]);
            }
            break;
        default:
            meta.set(FILL_VALUE, fillValue);
    }

    List<String> data = new ArrayList<>();
    data.add(JsonConverter.toJson(values));
    return Tuple3.of(meta, data, new ArrayList<>());
}
Example 16
Source File: BaseLinearModelTrainBatchOp.java From Alink with Apache License 2.0
/**
 * Build model data.
 *
 * @param meta meta info.
 * @param featureNames feature column names.
 * @param labelType label type.
 * @param meanVar mean and variance of vector.
 * @param hasIntercept whether the model has an intercept term.
 * @param standardization whether standardization was applied.
 * @param coefVector coefficient vector.
 * @return linear model data.
 */
public static LinearModelData buildLinearModelData(Params meta,
                                                   String[] featureNames,
                                                   TypeInformation labelType,
                                                   DenseVector[] meanVar,
                                                   boolean hasIntercept,
                                                   boolean standardization,
                                                   Tuple2<DenseVector, double[]> coefVector) {
    if (!(LinearModelType.AFT.equals(meta.get(ModelParamName.LINEAR_MODEL_TYPE)))) {
        modifyMeanVar(standardization, meanVar);
    }
    meta.set(ModelParamName.VECTOR_SIZE,
        coefVector.f0.size()
            - (meta.get(ModelParamName.HAS_INTERCEPT_ITEM) ? 1 : 0)
            - (LinearModelType.AFT.equals(meta.get(ModelParamName.LINEAR_MODEL_TYPE).toString()) ? 1 : 0));

    if (!(LinearModelType.AFT.equals(meta.get(ModelParamName.LINEAR_MODEL_TYPE)))) {
        if (standardization) {
            int n = meanVar[0].size();
            if (hasIntercept) {
                double sum = 0.0;
                for (int i = 0; i < n; ++i) {
                    sum += coefVector.f0.get(i + 1) * meanVar[0].get(i) / meanVar[1].get(i);
                    coefVector.f0.set(i + 1, coefVector.f0.get(i + 1) / meanVar[1].get(i));
                }
                coefVector.f0.set(0, coefVector.f0.get(0) - sum);
            } else {
                for (int i = 0; i < n; ++i) {
                    coefVector.f0.set(i, coefVector.f0.get(i) / meanVar[1].get(i));
                }
            }
        }
    }

    LinearModelData modelData = new LinearModelData(labelType, meta, featureNames, coefVector.f0);
    modelData.lossCurve = coefVector.f1;
    return modelData;
}
Example 17
Source File: OutputModel.java From Alink with Apache License 2.0
@Override
public List<Row> calc(ComContext context) {
    if (context.getTaskId() != 0) {
        return null;
    }

    // Get the coefficients that achieved the minimum loss.
    Tuple2<DenseVector, double[]> minCoef = context.getObj(OptimVariable.minCoef);
    double[] lossCurve = context.getObj(OptimVariable.lossCurve);

    int effectiveSize = lossCurve.length;
    for (int i = 0; i < lossCurve.length; ++i) {
        if (Double.isInfinite(lossCurve[i])) {
            effectiveSize = i;
            break;
        }
    }

    double[] effectiveCurve = new double[effectiveSize];
    System.arraycopy(lossCurve, 0, effectiveCurve, 0, effectiveSize);

    Params params = new Params();
    for (int i = 0; i < minCoef.f0.size(); ++i) {
        if (Double.isNaN(minCoef.f0.get(i)) || Double.isInfinite(minCoef.f0.get(i))) {
            throw new RuntimeException("Optimization result has NAN or infinite value, coefficient is invalid");
        }
    }
    params.set(ModelParamName.COEF, minCoef.f0);
    params.set(ModelParamName.LOSS_CURVE, effectiveCurve);

    List<Row> model = new ArrayList<>(1);
    model.add(Row.of(params.toJson()));
    return model;
}
Example 18
Source File: TreeInitObj.java From Alink with Apache License 2.0
@Override
public void calc(ComContext context) {
    if (context.getStepNo() != 1) {
        return;
    }

    List<Row> dataRows = context.getObj("treeInput");
    List<Row> quantileModel = context.getObj("quantileModel");
    List<Row> stringIndexerModel = context.getObj("stringIndexerModel");
    List<Object[]> labels = context.getObj("labels");

    int nLocalRow = dataRows == null ? 0 : dataRows.size();

    Params localParams = params.clone();
    localParams.set(TASK_ID, context.getTaskId());
    localParams.set(NUM_OF_SUBTASKS, context.getNumTask());
    localParams.set(N_LOCAL_ROW, nLocalRow);

    QuantileDiscretizerModelDataConverter quantileDiscretizerModel = initialMapping(quantileModel);

    List<String> lookUpColNames = new ArrayList<>();
    if (params.get(HasCategoricalCols.CATEGORICAL_COLS) != null) {
        lookUpColNames.addAll(Arrays.asList(params.get(HasCategoricalCols.CATEGORICAL_COLS)));
    }

    Map<String, Integer> categoricalColsSize = TreeUtil.extractCategoricalColsSize(
        stringIndexerModel, lookUpColNames.toArray(new String[0]));

    if (!Criteria.isRegression(params.get(TreeUtil.TREE_TYPE))) {
        categoricalColsSize.put(params.get(HasLabelCol.LABEL_COL), labels.get(0).length);
    }

    FeatureMeta[] featureMetas = TreeUtil.getFeatureMeta(
        params.get(HasFeatureCols.FEATURE_COLS), categoricalColsSize);

    FeatureMeta labelMeta = TreeUtil.getLabelMeta(
        params.get(HasLabelCol.LABEL_COL), params.get(HasFeatureCols.FEATURE_COLS).length, categoricalColsSize);

    TreeObj treeObj;
    if (Criteria.isRegression(params.get(TreeUtil.TREE_TYPE))) {
        treeObj = new RegObj(localParams, quantileDiscretizerModel, featureMetas, labelMeta);
    } else {
        treeObj = new ClassifierObj(localParams, quantileDiscretizerModel, featureMetas, labelMeta);
    }

    int nFeatureCol = localParams.get(RandomForestTrainParams.FEATURE_COLS).length;
    int[] data = new int[nFeatureCol * nLocalRow];
    double[] regLabels = null;
    int[] classifyLabels = null;
    if (Criteria.isRegression(params.get(TreeUtil.TREE_TYPE))) {
        regLabels = new double[nLocalRow];
    } else {
        classifyLabels = new int[nLocalRow];
    }

    int agg = 0;
    for (int iter = 0; iter < nLocalRow; ++iter) {
        for (int i = 0; i < nFeatureCol; ++i) {
            data[i * nLocalRow + agg] = (int) dataRows.get(iter).getField(i);
        }
        if (Criteria.isRegression(params.get(TreeUtil.TREE_TYPE))) {
            regLabels[agg] = (double) dataRows.get(iter).getField(nFeatureCol);
        } else {
            classifyLabels[agg] = (int) dataRows.get(iter).getField(nFeatureCol);
        }
        agg++;
    }

    treeObj.setFeatures(data);
    if (Criteria.isRegression(params.get(TreeUtil.TREE_TYPE))) {
        treeObj.setLabels(regLabels);
    } else {
        treeObj.setLabels(classifyLabels);
    }

    double[] histBuffer = new double[treeObj.getMaxHistBufferSize()];
    context.putObj("allReduce", histBuffer);
    treeObj.setHist(histBuffer);
    treeObj.initialRoot();
    context.putObj("treeObj", treeObj);
}
Example 19
Source File: ClassificationEvaluationUtil.java From Alink with Apache License 2.0
static void setLoglossParams(Params params, double logLoss, long total) {
    if (logLoss >= 0) {
        params.set(BaseSimpleClassifierMetrics.LOG_LOSS, logLoss / total);
    }
}