org.dmg.pmml.DataDictionary Java Examples

The following examples show how to use org.dmg.pmml.DataDictionary. You can vote up the ones you like or vote down the ones you don't like, and go to the original project or source file by following the links above each example. You may check out the related API usage on the sidebar.
Example #1
Source File: TargetUtilTest.java    From jpmml-evaluator with GNU Affero General Public License v3.0 6 votes vote down vote up
/**
 * Creates a {@link TreeModelEvaluator} for a binary-split tree model whose root
 * is a single leaf node carrying a {@code False} predicate and no score.
 *
 * @param miningFunction the mining function passed to the tree model
 * @param mathContext the math context set on the tree model
 * @param target the only target added to the model's targets
 * @return the evaluator built from a PMML 4.3 document wrapping the tree model
 */
static
private TreeModelEvaluator createTreeModelEvaluator(MiningFunction miningFunction, MathContext mathContext, Target target){
	TreeModel treeModel = new TreeModel(miningFunction, new MiningSchema(), new LeafNode(null, False.INSTANCE))
		.setSplitCharacteristic(TreeModel.SplitCharacteristic.BINARY_SPLIT)
		.setMathContext(mathContext)
		.setTargets(new Targets().addTargets(target));

	// The document only needs an empty header and an empty data dictionary
	PMML pmml = new PMML(Version.PMML_4_3.getVersion(), new Header(), new DataDictionary())
		.addModels(treeModel);

	return (TreeModelEvaluator)new ModelEvaluatorBuilder(pmml).build();
}
 
Example #2
Source File: KMeansPMMLUtils.java    From oryx with Apache License 2.0 6 votes vote down vote up
/**
 * Checks that a PMML-encoded k-means model is consistent with the expected schema.
 * The document must contain exactly one model, that model must be a
 * {@link ClusteringModel} with mining function {@link MiningFunction#CLUSTERING},
 * and the feature names of both the data dictionary and the mining schema must
 * equal the schema's feature names. Every condition is enforced through
 * {@code Preconditions.checkArgument}.
 *
 * @param pmml {@link PMML} encoding of KMeans Clustering
 * @param schema expected schema attributes of KMeans Clustering
 */
public static void validatePMMLVsSchema(PMML pmml, InputSchema schema) {
  List<Model> allModels = pmml.getModels();
  int modelCount = allModels.size();
  Preconditions.checkArgument(modelCount == 1,
      "Should have exactly one model, but had %s", modelCount);

  Model onlyModel = allModels.get(0);
  boolean isClusteringModel = onlyModel instanceof ClusteringModel;
  Preconditions.checkArgument(isClusteringModel);
  Preconditions.checkArgument(MiningFunction.CLUSTERING == onlyModel.getMiningFunction());

  // Data dictionary must list the same features as the schema
  DataDictionary dictionary = pmml.getDataDictionary();
  boolean dictionaryMatches =
      schema.getFeatureNames().equals(AppPMMLUtils.getFeatureNames(dictionary));
  Preconditions.checkArgument(dictionaryMatches,
      "Feature names in schema don't match names in PMML");

  // ... and so must the model's own mining schema
  MiningSchema miningSchema = onlyModel.getMiningSchema();
  boolean miningSchemaMatches =
      schema.getFeatureNames().equals(AppPMMLUtils.getFeatureNames(miningSchema));
  Preconditions.checkArgument(miningSchemaMatches);

}
 
Example #3
Source File: AppPMMLUtilsTest.java    From oryx with Apache License 2.0 6 votes vote down vote up
/**
 * Builds a data dictionary from the test schema with five categorical values
 * registered for feature index 1, then verifies the field count, each field via
 * {@code checkDataField}, and the values of the field at index 1 in order.
 */
@Test
public void testBuildDataDictionary() {
  String[] expectedValues = { "one", "two", "three", "four", "five" };

  Map<Integer,Collection<String>> valuesByFeatureIndex = new HashMap<>();
  valuesByFeatureIndex.put(1, Arrays.asList("one", "two", "three", "four", "five"));
  CategoricalValueEncodings encodings = new CategoricalValueEncodings(valuesByFeatureIndex);

  DataDictionary dictionary = AppPMMLUtils.buildDataDictionary(buildTestSchema(), encodings);

  assertEquals(4, dictionary.getNumberOfFields().intValue());
  checkDataField(dictionary.getDataFields().get(0), "foo", null);
  checkDataField(dictionary.getDataFields().get(1), "bar", true);
  checkDataField(dictionary.getDataFields().get(2), "baz", null);
  checkDataField(dictionary.getDataFields().get(3), "bing", false);

  List<Value> actualValues = dictionary.getDataFields().get(1).getValues();
  assertEquals(5, actualValues.size());
  int position = 0;
  for (String expectedValue : expectedValues) {
    assertEquals(expectedValue, actualValues.get(position).getValue());
    position++;
  }
}
 
Example #4
Source File: MapHolderParser.java    From jpmml-evaluator with GNU Affero General Public License v3.0 6 votes vote down vote up
/**
 * Replaces, in place, every data field that reports having values with a
 * {@link RichDataField} wrapping it, then delegates to the superclass visit.
 *
 * @param dataDictionary the data dictionary being visited
 * @return the action returned by {@code super.visit(dataDictionary)}
 */
@Override
public VisitorAction visit(DataDictionary dataDictionary){

	// Nothing to rewrite when the dictionary declares no fields
	if(!dataDictionary.hasDataFields()){
		return super.visit(dataDictionary);
	}

	ListIterator<DataField> fieldIt = (dataDictionary.getDataFields()).listIterator();

	while(fieldIt.hasNext()){
		DataField candidate = fieldIt.next();

		if(!candidate.hasValues()){
			continue;
		}

		// ListIterator.set swaps the element just returned by next()
		fieldIt.set(new RichDataField(candidate));
	}

	return super.visit(dataDictionary);
}
 
Example #5
Source File: AppPMMLUtilsTest.java    From oryx with Apache License 2.0 6 votes vote down vote up
/**
 * Builds a dictionary with a continuous field "foo" and a categorical field "bar"
 * holding the values "b" and "a" (in that order), derives the categorical value
 * encodings from it, and verifies that feature index 1 maps "b" to 0 and "a" to 1.
 */
@Test
public void testBuildCategoricalEncoding() {
  DataField fooField =
      new DataField(FieldName.create("foo"), OpType.CONTINUOUS, DataType.DOUBLE);
  DataField barField =
      new DataField(FieldName.create("bar"), OpType.CATEGORICAL, DataType.STRING);
  barField.addValues(new Value("b"), new Value("a"));

  List<DataField> fields = new ArrayList<>();
  fields.add(fooField);
  fields.add(barField);

  DataDictionary dictionary = new DataDictionary(fields).setNumberOfFields(fields.size());
  CategoricalValueEncodings encodings = AppPMMLUtils.buildCategoricalValueEncodings(dictionary);

  // Only the field at index 1 is categorical, with two values
  assertEquals(2, encodings.getValueCount(1));
  assertEquals(Collections.singletonMap(1, 2), encodings.getCategoryCounts());

  // value -> encoding
  assertEquals(0, encodings.getValueEncodingMap(1).get("b").intValue());
  assertEquals(1, encodings.getValueEncodingMap(1).get("a").intValue());

  // encoding -> value
  assertEquals("b", encodings.getEncodingValueMap(1).get(0));
  assertEquals("a", encodings.getEncodingValueMap(1).get(1));
}
 
Example #6
Source File: InvalidMarkupInspector.java    From jpmml-evaluator with GNU Affero General Public License v3.0 6 votes vote down vote up
/**
 * Runs a {@code CollectionSize} check that pairs the dictionary's declared
 * number of fields with its actual data field collection, then delegates to
 * the superclass visit.
 *
 * @param dataDictionary the data dictionary being visited
 * @return the action returned by {@code super.visit(dataDictionary)}
 */
@Override
public VisitorAction visit(DataDictionary dataDictionary){
	CollectionSize fieldCountCheck = new CollectionSize(dataDictionary){

		@Override
		public Collection<?> getCollection(){
			return dataDictionary.getDataFields();
		}

		@Override
		public Integer getSize(){
			return dataDictionary.getNumberOfFields();
		}
	};

	check(fieldCountCheck);

	return super.visit(dataDictionary);
}
 
Example #7
Source File: FieldResolver.java    From jpmml-model with BSD 3-Clause "New" or "Revised" License 6 votes vote down vote up
/**
 * Resets the scope registered for the PMML document (if any) and declares the
 * document-level fields in it: the data fields of the data dictionary, and,
 * when requested, the derived fields of the transformation dictionary.
 *
 * @param pmml the document whose global fields are declared
 * @param transformations whether derived fields of the transformation dictionary are declared too
 */
private void declareGlobalFields(PMML pmml, boolean transformations){
	List<Field<?>> globalScope = this.scopes.get(pmml);
	if(globalScope != null){
		globalScope.clear();
	}

	DataDictionary dataDictionary = pmml.getDataDictionary();
	boolean hasDataFields = (dataDictionary != null) && dataDictionary.hasDataFields();
	if(hasDataFields){
		declareFields(pmml, dataDictionary.getDataFields());
	}

	TransformationDictionary transformationDictionary = pmml.getTransformationDictionary();
	if(!transformations || transformationDictionary == null){
		return;
	}

	if(transformationDictionary.hasDerivedFields()){
		declareFields(pmml, transformationDictionary.getDerivedFields());
	}
}
 
Example #8
Source File: MarshallerTest.java    From jpmml-model with BSD 3-Clause "New" or "Revised" License 6 votes vote down vote up
/**
 * Marshals a minimal PMML 4.4 document to a UTF-8 string and checks that the
 * output carries the PMML 4.4 namespace, the version attribute and the closing tag.
 */
@Test
public void marshal() throws Exception {
	PMML pmml = new PMML(Version.PMML_4_4.getVersion(), new Header(), new DataDictionary());

	JAXBContext context = ContextFactory.createContext(new Class[]{org.dmg.pmml.ObjectFactory.class}, null);
	Marshaller marshaller = context.createMarshaller();

	String xml;

	try(ByteArrayOutputStream buffer = new ByteArrayOutputStream()){
		marshaller.marshal(pmml, buffer);

		xml = buffer.toString("UTF-8");
	}

	String[] expectedFragments = {
		"<PMML xmlns=\"http://www.dmg.org/PMML-4_4\"",
		" version=\"4.4\">",
		"</PMML>"
	};

	for(String expectedFragment : expectedFragments){
		assertTrue(xml.contains(expectedFragment));
	}
}
 
Example #9
Source File: VersionInspectorTest.java    From jpmml-model with BSD 3-Clause "New" or "Revised" License 5 votes vote down vote up
/**
 * Creates a PMML 4.4 document with an empty data dictionary and a header
 * whose copyright is "ACME Corporation".
 *
 * @return the newly created document
 */
static
private PMML createPMML(){
	Header header = new Header().setCopyright("ACME Corporation");

	return new PMML(Version.PMML_4_4.getVersion(), header, new DataDictionary());
}
 
Example #10
Source File: UnsupportedMarkupInspectorTest.java    From jpmml-evaluator with GNU Affero General Public License v3.0 5 votes vote down vote up
/**
 * Applies an {@link UnsupportedMarkupInspector} to a clustering model that uses
 * the distribution-based model class and a custom center fields element, and
 * expects it to throw while having collected exactly two exceptions: the thrown
 * one first (about the model class), then one about the center fields.
 */
@Test
public void inspect(){
	ClusteringModel model = new ClusteringModel()
		.setModelClass(ClusteringModel.ModelClass.DISTRIBUTION_BASED)
		.setCenterFields(new CustomCenterFields());

	PMML pmml = new PMML(Version.PMML_4_3.getVersion(), new Header(), new DataDictionary())
		.addModels(model);

	UnsupportedMarkupInspector inspector = new UnsupportedMarkupInspector();

	try {
		inspector.applyTo(pmml);

		fail();
	} catch(UnsupportedMarkupException thrown){
		List<UnsupportedMarkupException> collected = inspector.getExceptions();

		assertEquals(2, collected.size());
		// The exception that propagated must be the first one recorded
		assertEquals(0, collected.indexOf(thrown));

		String firstMessage = (collected.get(0)).getMessage();

		assertTrue(firstMessage.contains("ClusteringModel@modelClass=distributionBased"));

		String secondMessage = (collected.get(1)).getMessage();

		assertTrue(secondMessage.contains("CenterFields"));
		assertTrue(secondMessage.contains(CustomCenterFields.class.getName()));
	}
}
 
Example #11
Source File: NodeResolverTest.java    From jpmml-evaluator with GNU Affero General Public License v3.0 5 votes vote down vote up
/**
 * Verifies that {@link NodeResolver} leaves the root's default child as an id
 * while the tree model has no missing value strategy set, and replaces it with
 * the referenced node instance once the strategy is {@code DEFAULT_CHILD}.
 */
@Test
public void resolve(){
	Node firstLeaf = new LeafNode().setId("1");
	Node secondLeaf = new LeafNode().setId("2");

	Node rootNode = new BranchNode(null, True.INSTANCE)
		.setId("0")
		.setDefaultChild(secondLeaf.getId())
		.addNodes(firstLeaf, secondLeaf);

	TreeModel treeModel = new TreeModel(MiningFunction.REGRESSION, new MiningSchema(), null)
		.setNode(rootNode);

	PMML pmml = new PMML(Version.PMML_4_3.getVersion(), new Header(), new DataDictionary())
		.addModels(treeModel);

	NodeResolver nodeResolver = new NodeResolver();

	// First pass: default child is still the id
	nodeResolver.applyTo(pmml);
	assertEquals(secondLeaf.getId(), rootNode.getDefaultChild());

	// Second pass, with the default-child strategy: the id is replaced by the node itself
	treeModel.setMissingValueStrategy(TreeModel.MissingValueStrategy.DEFAULT_CHILD);
	nodeResolver.applyTo(pmml);
	assertSame(secondLeaf, rootNode.getDefaultChild());
}
 
Example #12
Source File: MarshallerTest.java    From jpmml-model with BSD 3-Clause "New" or "Revised" License 5 votes vote down vote up
/**
 * Marshals a PMML 4.4 document containing one regression model (with a single
 * regression table) to a UTF-8 string and checks for the PMML 4.4 namespace,
 * the version attribute, and the RegressionModel and PMML tags.
 */
@Test
public void marshal() throws Exception {
	RegressionModel regressionModel = new RegressionModel()
		.addRegressionTables(new RegressionTable());

	PMML pmml = new PMML(Version.PMML_4_4.getVersion(), new Header(), new DataDictionary());
	pmml.addModels(regressionModel);

	JAXBContext context = JAXBContextFactory.createContext(new Class[]{org.dmg.pmml.ObjectFactory.class, org.dmg.pmml.regression.ObjectFactory.class}, null);
	Marshaller marshaller = context.createMarshaller();

	String xml;

	try(ByteArrayOutputStream buffer = new ByteArrayOutputStream()){
		marshaller.marshal(pmml, buffer);

		xml = buffer.toString("UTF-8");
	}

	String[] expectedFragments = {
		"<PMML xmlns=\"http://www.dmg.org/PMML-4_4\"",
		" version=\"4.4\">",
		"<RegressionModel>",
		"</RegressionModel>",
		"</PMML>"
	};

	for(String expectedFragment : expectedFragments){
		assertTrue(xml.contains(expectedFragment));
	}
}
 
Example #13
Source File: AppPMMLUtils.java    From oryx with Apache License 2.0 5 votes vote down vote up
/**
 * Builds a {@link DataDictionary} with one {@link DataField} per schema feature,
 * in schema order. Numeric features become continuous doubles, categorical
 * features become categorical strings, and any other feature gets a field with
 * null op type and data type. Categorical fields additionally receive their
 * values, ordered by encoding.
 *
 * @param schema schema supplying the feature names and their kinds
 * @param categoricalValueEncodings encodings of categorical values by feature index;
 *  must be non-null if the schema has any categorical feature
 * @return dictionary of the fields, with its number of fields set to the field count
 */
public static DataDictionary buildDataDictionary(
    InputSchema schema,
    CategoricalValueEncodings categoricalValueEncodings) {
  List<String> featureNames = schema.getFeatureNames();
  List<DataField> dataFields = new ArrayList<>();

  int featureIndex = 0;
  for (String featureName : featureNames) {
    // Stay null when the feature is neither numeric nor categorical
    OpType opType = null;
    DataType dataType = null;
    if (schema.isNumeric(featureName)) {
      opType = OpType.CONTINUOUS;
      dataType = DataType.DOUBLE;
    } else if (schema.isCategorical(featureName)) {
      opType = OpType.CATEGORICAL;
      dataType = DataType.STRING;
    }

    DataField field = new DataField(FieldName.create(featureName), opType, dataType);
    if (schema.isCategorical(featureName)) {
      Objects.requireNonNull(categoricalValueEncodings);
      // Add values in ascending order of their encoding
      categoricalValueEncodings.getEncodingValueMap(featureIndex).entrySet().stream()
          .sorted(Comparator.comparing(Map.Entry::getKey))
          .map(Map.Entry::getValue)
          .forEach(categoryValue -> field.addValues(new Value(categoryValue)));
    }
    dataFields.add(field);
    featureIndex++;
  }

  DataDictionary dictionary = new DataDictionary(dataFields);
  return dictionary.setNumberOfFields(dataFields.size());
}
 
Example #14
Source File: ReflectionUtilTest.java    From jpmml-model with BSD 3-Clause "New" or "Revised" License 5 votes vote down vote up
/**
 * Verifies that {@code ReflectionUtil.copyState} makes the target share the
 * source's version, header, data dictionary and models list (a model added to
 * the source afterwards is visible through the target), and that copying in the
 * opposite direction, from the custom subclass instance, is rejected with an
 * {@link IllegalArgumentException}.
 */
@Test
public void copyState(){
	PMML source = new PMML(Version.PMML_4_4.getVersion(), new Header(), new DataDictionary());

	// Initialize a live list instance
	source.getModels();

	CustomPMML target = new CustomPMML();

	ReflectionUtil.copyState(source, target);

	assertSame(source.getVersion(), target.getVersion());
	assertSame(source.getHeader(), target.getHeader());
	assertSame(source.getDataDictionary(), target.getDataDictionary());

	assertFalse(source.hasModels());
	assertFalse(target.hasModels());

	// A model added to the source afterwards is seen by the target as well
	source.addModels(new RegressionModel());

	assertTrue(source.hasModels());
	assertTrue(target.hasModels());

	assertSame(source.getModels(), target.getModels());

	try {
		ReflectionUtil.copyState(target, source);

		fail();
	} catch(IllegalArgumentException expected){
		// Ignored
	}
}
 
Example #15
Source File: AppPMMLUtilsTest.java    From oryx with Apache License 2.0 5 votes vote down vote up
/**
 * Builds a data dictionary from the test schema and checks that
 * {@code AppPMMLUtils.getFeatureNames} lists "foo", "bar", "baz", "bing" in order.
 */
@Test
public void testListFeaturesDD() {
  Map<Integer,Collection<String>> valuesByFeatureIndex = new HashMap<>();
  valuesByFeatureIndex.put(1, Arrays.asList("one", "two", "three", "four", "five"));
  CategoricalValueEncodings encodings = new CategoricalValueEncodings(valuesByFeatureIndex);

  DataDictionary dictionary = AppPMMLUtils.buildDataDictionary(buildTestSchema(), encodings);

  List<String> expectedNames = Arrays.asList("foo", "bar", "baz", "bing");
  assertEquals(expectedNames, AppPMMLUtils.getFeatureNames(dictionary));
}
 
Example #16
Source File: AppPMMLUtils.java    From oryx with Apache License 2.0 5 votes vote down vote up
/**
 * Lists the names of the data fields of a dictionary, in dictionary order.
 * The dictionary must have at least one data field; this is enforced with
 * {@code Preconditions.checkArgument}.
 *
 * @param dictionary {@link DataDictionary} from model
 * @return names of features in order
 */
public static List<String> getFeatureNames(DataDictionary dictionary) {
  List<DataField> dataFields = dictionary.getDataFields();
  boolean noFields = dataFields == null || dataFields.isEmpty();
  Preconditions.checkArgument(!noFields, "No fields in DataDictionary");
  return dataFields.stream()
      .map(dataField -> dataField.getName().getValue())
      .collect(Collectors.toList());
}
 
Example #17
Source File: AppPMMLUtils.java    From oryx with Apache License 2.0 5 votes vote down vote up
/**
 * Derives categorical value encodings from a data dictionary. Every data field
 * that has at least one value contributes an entry keyed by the field's position
 * in the dictionary, holding the string form of each of its values in order.
 *
 * @param dictionary {@link DataDictionary} whose fields are examined
 * @return encodings built from the per-index value collections
 */
public static CategoricalValueEncodings buildCategoricalValueEncodings(
    DataDictionary dictionary) {
  List<DataField> dataFields = dictionary.getDataFields();
  Map<Integer,Collection<String>> valuesByIndex = new HashMap<>();
  for (int featureIndex = 0; featureIndex < dataFields.size(); featureIndex++) {
    Collection<Value> fieldValues = dataFields.get(featureIndex).getValues();
    // Fields without values are skipped
    if (fieldValues == null || fieldValues.isEmpty()) {
      continue;
    }
    Collection<String> asStrings = fieldValues.stream()
        .map(v -> v.getValue().toString())
        .collect(Collectors.toList());
    valuesByIndex.put(featureIndex, asStrings);
  }
  return new CategoricalValueEncodings(valuesByIndex);
}
 
Example #18
Source File: RDFPMMLUtils.java    From oryx with Apache License 2.0 5 votes vote down vote up
/**
 * Checks that a PMML-encoded decision forest is consistent with the expected schema.
 * The document must contain exactly one model, whose mining function is
 * classification or regression according to the schema; the feature names of the
 * data dictionary and of the mining schema must equal the schema's feature names;
 * and the target index found in the mining schema must agree with the schema
 * (equal when the schema has a target, absent otherwise). Every condition is
 * enforced through {@code Preconditions.checkArgument}.
 *
 * @param pmml {@link PMML} encoding of a decision forest
 * @param schema expected schema attributes of decision forest
 */
public static void validatePMMLVsSchema(PMML pmml, InputSchema schema) {
  List<Model> allModels = pmml.getModels();
  Preconditions.checkArgument(allModels.size() == 1,
                              "Should have exactly one model, but had %s", allModels.size());

  Model onlyModel = allModels.get(0);
  MiningFunction function = onlyModel.getMiningFunction();
  boolean classification = schema.isClassification();
  MiningFunction expectedFunction =
      classification ? MiningFunction.CLASSIFICATION : MiningFunction.REGRESSION;
  String functionMessage = classification
      ? "Expected classification function type but got %s"
      : "Expected regression function type but got %s";
  Preconditions.checkArgument(function == expectedFunction, functionMessage, function);

  DataDictionary dictionary = pmml.getDataDictionary();
  boolean dictionaryMatches =
      schema.getFeatureNames().equals(AppPMMLUtils.getFeatureNames(dictionary));
  Preconditions.checkArgument(dictionaryMatches,
                              "Feature names in schema don't match names in PMML");

  MiningSchema miningSchema = onlyModel.getMiningSchema();
  boolean miningSchemaMatches =
      schema.getFeatureNames().equals(AppPMMLUtils.getFeatureNames(miningSchema));
  Preconditions.checkArgument(miningSchemaMatches);

  Integer pmmlIndex = AppPMMLUtils.findTargetIndex(miningSchema);
  if (!schema.hasTarget()) {
    // No target configured, so PMML must not declare one either
    Preconditions.checkArgument(pmmlIndex == null);
    return;
  }

  int schemaIndex = schema.getTargetFeatureIndex();
  Preconditions.checkArgument(
      pmmlIndex != null && schemaIndex == pmmlIndex,
      "Configured schema expects target at index %s, but PMML has target at index %s",
      schemaIndex, pmmlIndex);
}
 
Example #19
Source File: AbstractAppMLlibIT.java    From oryx with Apache License 2.0 5 votes vote down vote up
/**
 * Asserts that a data dictionary matches the schema: it is non-null, both its
 * declared number of fields and its actual field count equal the schema's number
 * of features, and each field's op type and data type agree with the kind of the
 * feature of the same name (continuous double for numeric, categorical string
 * for categorical, both null otherwise).
 *
 * @param schema schema the dictionary is checked against
 * @param dataDictionary dictionary under test
 */
protected static void checkDataDictionary(InputSchema schema, DataDictionary dataDictionary) {
  assertNotNull(dataDictionary);
  assertEquals("Wrong number of features",
               schema.getNumFeatures(),
               dataDictionary.getNumberOfFields().intValue());
  List<DataField> dataFields = dataDictionary.getDataFields();
  assertEquals(schema.getNumFeatures(), dataFields.size());

  for (DataField dataField : dataFields) {
    String featureName = dataField.getName().getValue();
    boolean numeric = schema.isNumeric(featureName);

    if (!numeric && !schema.isCategorical(featureName)) {
      // Neither numeric nor categorical: no types are expected at all
      assertNull(dataField.getOpType());
      assertNull(dataField.getDataType());
      continue;
    }

    OpType expectedOpType = numeric ? OpType.CONTINUOUS : OpType.CATEGORICAL;
    DataType expectedDataType = numeric ? DataType.DOUBLE : DataType.STRING;
    assertEquals("Wrong op type for feature " + featureName,
                 expectedOpType,
                 dataField.getOpType());
    assertEquals("Wrong data type for feature " + featureName,
                 expectedDataType,
                 dataField.getDataType());
  }
}
 
Example #20
Source File: RDFPMMLUtilsTest.java    From oryx with Apache License 2.0 4 votes vote down vote up
/**
 * Builds a small PMML classification document on top of the skeleton from
 * {@code PMMLUtils.buildSkeletonPMML()}. The data dictionary has a categorical
 * predictor "color" and a categorical target "fruit"; the tree has a root with
 * two children scoring "banana" and "apple". With {@code numTrees > 1} the same
 * tree model is placed in that many weighted segments of a {@link MiningModel}
 * using weighted majority vote; otherwise the tree model is added directly.
 *
 * @param numTrees number of trees; values greater than 1 produce a segmented mining model
 * @return the assembled document
 */
private static PMML buildDummyClassificationModel(int numTrees) {
  PMML pmml = PMMLUtils.buildSkeletonPMML();

  // Data dictionary: predictor "color" and target "fruit"
  DataField colorField =
      new DataField(FieldName.create("color"), OpType.CATEGORICAL, DataType.STRING);
  colorField.addValues(new Value("yellow"), new Value("red"));
  DataField fruitField =
      new DataField(FieldName.create("fruit"), OpType.CATEGORICAL, DataType.STRING);
  fruitField.addValues(new Value("banana"), new Value("apple"));
  List<DataField> dataFields = new ArrayList<>();
  dataFields.add(colorField);
  dataFields.add(fruitField);
  pmml.setDataDictionary(new DataDictionary(dataFields).setNumberOfFields(dataFields.size()));

  // Mining schema mirroring the dictionary
  List<MiningField> miningFields = new ArrayList<>();
  miningFields.add(new MiningField(FieldName.create("color"))
      .setOpType(OpType.CATEGORICAL)
      .setUsageType(MiningField.UsageType.ACTIVE)
      .setImportance(0.5));
  miningFields.add(new MiningField(FieldName.create("fruit"))
      .setOpType(OpType.CATEGORICAL)
      .setUsageType(MiningField.UsageType.PREDICTED));
  MiningSchema miningSchema = new MiningSchema(miningFields);

  // Root with two children, each holding half of the records
  double totalCount = 2.0;
  double perChildCount = totalCount / 2;

  Node rootNode =
      new ComplexNode().setId("r").setRecordCount(totalCount).setPredicate(new True());

  Node negativeChild =
      new ComplexNode().setId("r-").setRecordCount(perChildCount).setPredicate(new True());
  negativeChild.addScoreDistributions(new ScoreDistribution("apple", perChildCount));

  SimpleSetPredicate notRed =
      new SimpleSetPredicate(FieldName.create("color"),
                             SimpleSetPredicate.BooleanOperator.IS_NOT_IN,
                             new Array(Array.Type.STRING, "red"));
  Node positiveChild =
      new ComplexNode().setId("r+").setRecordCount(perChildCount).setPredicate(notRed);
  positiveChild.addScoreDistributions(new ScoreDistribution("banana", perChildCount));

  rootNode.addNodes(positiveChild, negativeChild);

  TreeModel treeModel = new TreeModel(MiningFunction.CLASSIFICATION, miningSchema, rootNode)
      .setSplitCharacteristic(TreeModel.SplitCharacteristic.BINARY_SPLIT)
      .setMissingValueStrategy(TreeModel.MissingValueStrategy.DEFAULT_CHILD);

  if (numTrees <= 1) {
    pmml.addModels(treeModel);
    return pmml;
  }

  // Several trees: wrap the same tree model in numTrees equally weighted segments
  List<Segment> segments = new ArrayList<>();
  for (int treeIndex = 0; treeIndex < numTrees; treeIndex++) {
    Segment segment = new Segment()
        .setId(Integer.toString(treeIndex))
        .setPredicate(new True())
        .setModel(treeModel)
        .setWeight(1.0);
    segments.add(segment);
  }
  MiningModel miningModel = new MiningModel(MiningFunction.CLASSIFICATION, miningSchema);
  miningModel.setSegmentation(
      new Segmentation(Segmentation.MultipleModelMethod.WEIGHTED_MAJORITY_VOTE, segments));
  pmml.addModels(miningModel);

  return pmml;
}
 
Example #21
Source File: ModelManager.java    From jpmml-evaluator with GNU Affero General Public License v3.0 4 votes vote down vote up
/**
 * Initializes the manager for a model of a PMML document. Mandatory parts are
 * validated, and for every optional part that is present the corresponding
 * value is obtained through {@code CacheUtil.getValue} and stored in a field.
 *
 * @param pmml the PMML document; must not be null
 * @param model the model to manage; must not be null
 * @throws NullPointerException if {@code pmml} or {@code model} is null
 * @throws MissingElementException if the document has no data dictionary,
 *  or the model has no mining schema
 * @throws MissingAttributeException if the model has no mining function,
 *  or a mining field has no name
 */
protected ModelManager(PMML pmml, M model){
	setPMML(Objects.requireNonNull(pmml));
	setModel(Objects.requireNonNull(model));

	// The data dictionary element is mandatory, its data fields are not
	DataDictionary dataDictionary = pmml.getDataDictionary();
	if(dataDictionary == null){
		throw new MissingElementException(pmml, PMMLElements.PMML_DATADICTIONARY);
	} // End if

	if(dataDictionary.hasDataFields()){
		this.dataFields = CacheUtil.getValue(dataDictionary, ModelManager.dataFieldCache);
	}

	// The transformation dictionary is optional; it may contribute derived fields and define functions
	TransformationDictionary transformationDictionary = pmml.getTransformationDictionary();
	if(transformationDictionary != null && transformationDictionary.hasDerivedFields()){
		this.derivedFields = CacheUtil.getValue(transformationDictionary, ModelManager.derivedFieldCache);
	} // End if

	if(transformationDictionary != null && transformationDictionary.hasDefineFunctions()){
		this.defineFunctions = CacheUtil.getValue(transformationDictionary, ModelManager.defineFunctionCache);
	}

	// The mining function attribute is mandatory
	MiningFunction miningFunction = model.getMiningFunction();
	if(miningFunction == null){
		throw new MissingAttributeException(MissingAttributeException.formatMessage(XPathUtil.formatElement(model.getClass()) + "@miningFunction"), model);
	}

	// The mining schema element is mandatory
	MiningSchema miningSchema = model.getMiningSchema();
	if(miningSchema == null){
		throw new MissingElementException(MissingElementException.formatMessage(XPathUtil.formatElement(model.getClass()) + "/" + XPathUtil.formatElement(MiningSchema.class)), model);
	} // End if

	if(miningSchema.hasMiningFields()){
		List<MiningField> miningFields = miningSchema.getMiningFields();

		// Every mining field must be named; checked before the cache lookup below
		for(MiningField miningField : miningFields){
			FieldName name = miningField.getName();
			if(name == null){
				throw new MissingAttributeException(miningField, PMMLAttributes.MININGFIELD_NAME);
			}
		}

		this.miningFields = CacheUtil.getValue(miningSchema, ModelManager.miningFieldCache);
	}

	// The remaining model-level elements are all optional
	LocalTransformations localTransformations = model.getLocalTransformations();
	if(localTransformations != null && localTransformations.hasDerivedFields()){
		this.localDerivedFields = CacheUtil.getValue(localTransformations, ModelManager.localDerivedFieldCache);
	}

	Targets targets = model.getTargets();
	if(targets != null && targets.hasTargets()){
		this.targets = CacheUtil.getValue(targets, ModelManager.targetCache);
	}

	// The output element feeds two fields: the output fields and the result features
	Output output = model.getOutput();
	if(output != null && output.hasOutputFields()){
		this.outputFields = CacheUtil.getValue(output, ModelManager.outputFieldCache);
		this.resultFeatures = CacheUtil.getValue(output, ModelManager.resultFeaturesCache);
	}
}
 
Example #22
Source File: ModelManager.java    From jpmml-evaluator with GNU Affero General Public License v3.0 4 votes vote down vote up
/**
 * Builds a map over the data fields of a data dictionary using
 * {@code IndexableUtil.buildMap}.
 *
 * @param dataDictionary the dictionary whose data fields are mapped
 * @return the map produced by {@code IndexableUtil.buildMap}
 */
@Override
public Map<FieldName, DataField> load(DataDictionary dataDictionary){
	List<DataField> dataFields = dataDictionary.getDataFields();

	return IndexableUtil.buildMap(dataFields);
}
 
Example #23
Source File: KMeansPMMLUtilsTest.java    From oryx with Apache License 2.0 4 votes vote down vote up
/**
 * Builds a small PMML clustering document on top of the skeleton from
 * {@code PMMLUtils.buildSkeletonPMML()}: two continuous double fields "x" and "y",
 * both active mining fields and center fields, and three clusters with ids
 * "0", "1", "2" in a center-based model compared by squared Euclidean distance.
 *
 * @return the assembled document
 */
public static PMML buildDummyClusteringModel() {
  PMML pmml = PMMLUtils.buildSkeletonPMML();

  // "x" and "y" are declared identically in all three field lists
  List<DataField> dataFields = new ArrayList<>();
  List<MiningField> miningFields = new ArrayList<>();
  List<ClusteringField> clusteringFields = new ArrayList<>();
  for (String coordinate : new String[] { "x", "y" }) {
    dataFields.add(
        new DataField(FieldName.create(coordinate), OpType.CONTINUOUS, DataType.DOUBLE));
    miningFields.add(new MiningField(FieldName.create(coordinate))
        .setOpType(OpType.CONTINUOUS).setUsageType(MiningField.UsageType.ACTIVE));
    clusteringFields.add(new ClusteringField(
        FieldName.create(coordinate)).setCenterField(ClusteringField.CenterField.TRUE));
  }

  pmml.setDataDictionary(new DataDictionary(dataFields).setNumberOfFields(dataFields.size()));
  MiningSchema miningSchema = new MiningSchema(miningFields);

  List<Cluster> clusters = new ArrayList<>();
  clusters.add(new Cluster().setId("0").setSize(1).setArray(AppPMMLUtils.toArray(1.0, 0.0)));
  clusters.add(new Cluster().setId("1").setSize(2).setArray(AppPMMLUtils.toArray(2.0, -1.0)));
  clusters.add(new Cluster().setId("2").setSize(3).setArray(AppPMMLUtils.toArray(-1.0, 0.0)));

  ComparisonMeasure distanceMeasure =
      new ComparisonMeasure(ComparisonMeasure.Kind.DISTANCE, new SquaredEuclidean());
  ClusteringModel clusteringModel = new ClusteringModel(
      MiningFunction.CLUSTERING,
      ClusteringModel.ModelClass.CENTER_BASED,
      clusters.size(),
      miningSchema,
      distanceMeasure,
      clusteringFields,
      clusters);
  pmml.addModels(clusteringModel);

  return pmml;
}
 
Example #24
Source File: HasNodeRegistryTest.java    From jpmml-evaluator with GNU Affero General Public License v3.0 4 votes vote down vote up
/**
 * Builds a three-level tree (a root, two branch children, and two children under
 * each of those) and verifies {@code getPath} and {@code getPathBetween} of the
 * tree model evaluator: paths run from ancestor to descendant, and a null path
 * is expected when the second node is not in the subtree of the first.
 */
@Test
public void getPath(){
	// Level 1
	Node root = new BranchNode();

	// Level 2
	Node left = new BranchNode();
	Node right = new BranchNode();

	root.addNodes(left, right);

	// Level 3, under the left branch
	Node leftLeft = new BranchNode();
	Node leftRight = new BranchNode();

	left.addNodes(leftLeft, leftRight);

	// Level 3, under the right branch
	Node rightLeft = new LeafNode();
	Node rightRight = new LeafNode();

	right.addNodes(rightLeft, rightRight);

	TreeModel treeModel = new TreeModel(MiningFunction.REGRESSION, new MiningSchema(), root);

	PMML pmml = new PMML(Version.PMML_4_3.getVersion(), new Header(), new DataDictionary())
		.addModels(treeModel);

	HasNodeRegistry hasNodeRegistry = new TreeModelEvaluator(pmml);

	// Look up ids by node
	BiMap<Node, String> idsByNode = (hasNodeRegistry.getEntityRegistry()).inverse();

	String rootId = idsByNode.get(root);

	String leftId = idsByNode.get(left);
	String rightId = idsByNode.get(right);

	String leftLeftId = idsByNode.get(leftLeft);
	String leftRightId = idsByNode.get(leftRight);
	String rightLeftId = idsByNode.get(rightLeft);
	String rightRightId = idsByNode.get(rightRight);

	// Paths from the root
	assertEquals(Arrays.asList(root), hasNodeRegistry.getPath(rootId));
	assertEquals(Arrays.asList(root, left), hasNodeRegistry.getPath(leftId));
	assertEquals(Arrays.asList(root, left, leftLeft), hasNodeRegistry.getPath(leftLeftId));

	// Paths between an ancestor and a descendant; the reverse direction yields null
	assertEquals(Arrays.asList(root), hasNodeRegistry.getPathBetween(rootId, rootId));
	assertEquals(Arrays.asList(root, left), hasNodeRegistry.getPathBetween(rootId, leftId));
	assertNull(hasNodeRegistry.getPathBetween(leftId, rootId));
	assertEquals(Arrays.asList(left, leftLeft), hasNodeRegistry.getPathBetween(leftId, leftLeftId));
	assertEquals(Arrays.asList(left, leftRight), hasNodeRegistry.getPathBetween(leftId, leftRightId));

	// Nodes outside the left subtree are unreachable from it
	assertNull(hasNodeRegistry.getPathBetween(leftId, rightId));
	assertNull(hasNodeRegistry.getPathBetween(leftId, rightLeftId));
	assertNull(hasNodeRegistry.getPathBetween(leftId, rightRightId));
}
 
Example #25
Source File: GolfingTreeModelExample.java    From jpmml-model with BSD 3-Clause "New" or "Revised" License 4 votes vote down vote up
/**
 * Produces a PMML 4.4 document holding a classification tree model named
 * "golfing". The data dictionary declares two continuous fields (temperature,
 * humidity) and three categorical string fields (windy, outlook, whatIDo);
 * the mining schema lists all five, with whatIDo as the target.
 *
 * @return the assembled document
 */
@Override
public PMML produce(){
	FieldName temperature = FieldName.create("temperature");
	FieldName humidity = FieldName.create("humidity");
	FieldName windy = FieldName.create("windy");
	FieldName outlook = FieldName.create("outlook");
	FieldName whatIdo = FieldName.create("whatIDo");

	Header header = new Header()
		.setCopyright("www.dmg.org")
		.setDescription("A very small binary tree model to show structure.");

	// Categorical fields enumerate their values
	DataDictionary dataDictionary = new DataDictionary()
		.addDataFields(
			new DataField(temperature, OpType.CONTINUOUS, DataType.DOUBLE),
			new DataField(humidity, OpType.CONTINUOUS, DataType.DOUBLE),
			new DataField(windy, OpType.CATEGORICAL, DataType.STRING)
				.addValues(createValues("true", "false")),
			new DataField(outlook, OpType.CATEGORICAL, DataType.STRING)
				.addValues(createValues("sunny", "overcast", "rain")),
			new DataField(whatIdo, OpType.CATEGORICAL, DataType.STRING)
				.addValues(createValues("will play", "may play", "no play"))
		);

	// Set the declared field count from the actual number of data fields
	dataDictionary.setNumberOfFields((dataDictionary.getDataFields()).size());

	// Only the last mining field gets an explicit usage type
	MiningSchema miningSchema = new MiningSchema()
		.addMiningFields(
			new MiningField(temperature),
			new MiningField(humidity),
			new MiningField(windy),
			new MiningField(outlook),
			new MiningField(whatIdo)
				.setUsageType(MiningField.UsageType.TARGET)
		);

	// The root uses the True predicate and scores "will play"
	Node root = new BranchNode("will play", True.INSTANCE);

	// Upper half of the tree
	root.addNodes(
		new BranchNode("will play", new SimplePredicate(outlook, Operator.EQUAL, "sunny"))
			.addNodes(
				new BranchNode("will play",
					createCompoundPredicate(BooleanOperator.AND,
						new SimplePredicate(temperature, Operator.LESS_THAN, "90"),
						new SimplePredicate(temperature, Operator.GREATER_THAN, "50"))
					)
					.addNodes(
						new LeafNode("will play", new SimplePredicate(humidity, Operator.LESS_THAN, "80")),
						new LeafNode("no play", new SimplePredicate(humidity, Operator.GREATER_OR_EQUAL, "80"))
					),
				new LeafNode("no play",
					createCompoundPredicate(BooleanOperator.OR,
						new SimplePredicate(temperature, Operator.GREATER_OR_EQUAL, "90"),
						new SimplePredicate(temperature, Operator.LESS_OR_EQUAL, "50"))
					)
			)
	);

	// Lower half of the tree
	root.addNodes(
		new BranchNode("may play",
			createCompoundPredicate(BooleanOperator.OR,
				new SimplePredicate(outlook, Operator.EQUAL, "overcast"),
				new SimplePredicate(outlook, Operator.EQUAL, "rain"))
			)
			.addNodes(
				new LeafNode("may play",
					createCompoundPredicate(BooleanOperator.AND,
						new SimplePredicate(temperature, Operator.GREATER_THAN, "60"),
						new SimplePredicate(temperature, Operator.LESS_THAN, "100"),
						new SimplePredicate(outlook, Operator.EQUAL, "overcast"),
						new SimplePredicate(humidity, Operator.LESS_THAN, "70"),
						new SimplePredicate(windy, Operator.EQUAL, "false"))
					),
				new LeafNode("no play",
					createCompoundPredicate(BooleanOperator.AND,
						new SimplePredicate(outlook, Operator.EQUAL, "rain"),
						new SimplePredicate(humidity, Operator.LESS_THAN, "70"))
					)
			)
	);

	TreeModel treeModel = new TreeModel(MiningFunction.CLASSIFICATION, miningSchema, root)
		.setModelName("golfing");

	PMML pmml = new PMML(Version.PMML_4_4.getVersion(), header, dataDictionary)
		.addModels(treeModel);

	return pmml;
}
 
Example #26
Source File: RDFPMMLUtilsTest.java    From oryx with Apache License 2.0 4 votes vote down vote up
/**
 * Builds a small regression PMML document for use in tests.
 *
 * <p>Starting from the PMML returned by {@code PMMLUtils.buildSkeletonPMML()}, this adds:
 * <ul>
 *   <li>a data dictionary with two continuous double fields, {@code foo} and {@code bar};</li>
 *   <li>a mining schema in which {@code foo} is the active predictor (importance 0.5) and
 *       {@code bar} is the predicted field;</li>
 *   <li>a regression {@link TreeModel} whose root {@code r} (record count 2.0) has two
 *       children with record count 1.0 each: {@code r+}, guarded by {@code foo > 3.14} and
 *       scoring {@code 2.0}, and {@code r-}, guarded by an always-true predicate and scoring
 *       {@code -2.0}.</li>
 * </ul>
 *
 * @return the populated PMML document
 */
public static PMML buildDummyRegressionModel() {
  PMML pmml = PMMLUtils.buildSkeletonPMML();

  // Create each field name once and share it between the dictionary, schema and predicate
  FieldName predictorName = FieldName.create("foo");
  FieldName targetName = FieldName.create("bar");

  List<DataField> dataFields = new ArrayList<>();
  dataFields.add(new DataField(predictorName, OpType.CONTINUOUS, DataType.DOUBLE));
  dataFields.add(new DataField(targetName, OpType.CONTINUOUS, DataType.DOUBLE));
  DataDictionary dataDictionary =
      new DataDictionary(dataFields).setNumberOfFields(dataFields.size());
  pmml.setDataDictionary(dataDictionary);

  List<MiningField> miningFields = new ArrayList<>();
  MiningField predictorMF = new MiningField(predictorName)
      .setOpType(OpType.CONTINUOUS)
      .setUsageType(MiningField.UsageType.ACTIVE)
      .setImportance(0.5);
  miningFields.add(predictorMF);
  MiningField targetMF = new MiningField(targetName)
      .setOpType(OpType.CONTINUOUS)
      .setUsageType(MiningField.UsageType.PREDICTED);
  miningFields.add(targetMF);
  MiningSchema miningSchema = new MiningSchema(miningFields);

  double dummyCount = 2.0;
  // The root's record count is split evenly between its two children
  double halfCount = dummyCount / 2;

  Node rootNode =
      new ComplexNode().setId("r").setRecordCount(dummyCount).setPredicate(new True());

  Node left = new ComplexNode()
      .setId("r-")
      .setRecordCount(halfCount)
      .setPredicate(new True())
      .setScore("-2.0");
  Node right = new ComplexNode()
      .setId("r+")
      .setRecordCount(halfCount)
      .setPredicate(new SimplePredicate(predictorName,
                                        SimplePredicate.Operator.GREATER_THAN,
                                        "3.14"))
      .setScore("2.0");

  // The conditional child is added ahead of the always-true child
  rootNode.addNodes(right, left);

  // The constructor already receives the mining schema, so it is not set a second time
  TreeModel treeModel = new TreeModel(MiningFunction.REGRESSION, miningSchema, rootNode)
      .setSplitCharacteristic(TreeModel.SplitCharacteristic.BINARY_SPLIT)
      .setMissingValueStrategy(TreeModel.MissingValueStrategy.DEFAULT_CHILD);

  pmml.addModels(treeModel);

  return pmml;
}
 
Example #27
Source File: ValueParserTest.java    From jpmml-evaluator with GNU Affero General Public License v3.0 4 votes vote down vote up
/**
 * Checks how a strict-mode {@link ValueParser} treats the values of a regression model.
 *
 * <p>While the field is declared as a string, all values are expected to stay strings.
 * After the field's data type is switched to boolean and the parser is applied again,
 * "false"/"true" are expected to become {@link Boolean} instances, whereas the "N/A"
 * values are expected to remain strings.
 */
@Test
public void parseRegressionModel(){
	Value noValue = new Value("false");
	Value yesValue = new Value("true");
	Value unparsableValue = new Value("N/A");

	DataField declaredField = new DataField(FieldName.create("x1"), OpType.CATEGORICAL, DataType.STRING);
	declaredField.addValues(noValue, yesValue, unparsableValue);

	DataDictionary dictionary = new DataDictionary();
	dictionary.addDataFields(declaredField);

	CategoricalPredictor noTerm = new CategoricalPredictor(declaredField.getName(), "false", -1d);
	CategoricalPredictor yesTerm = new CategoricalPredictor(declaredField.getName(), "true", 1d);

	RegressionTable table = new RegressionTable();
	table.addCategoricalPredictors(noTerm, yesTerm);

	MiningField inputField = new MiningField(declaredField.getName());
	inputField.setMissingValueReplacement("false");
	inputField.setInvalidValueReplacement("N/A");

	MiningSchema schema = new MiningSchema();
	schema.addMiningFields(inputField);

	RegressionModel model = new RegressionModel(MiningFunction.REGRESSION, schema, null);
	model.addRegressionTables(table);

	PMML pmml = new PMML(Version.PMML_4_3.getVersion(), new Header(), dictionary);
	pmml.addModels(model);

	// Grab the field list before parsing; the field itself is looked up afterwards
	List<DataField> dictionaryFields = dictionary.getDataFields();

	ValueParser valueParser = new ValueParser(ValueParser.Mode.STRICT);
	valueParser.applyTo(pmml);

	DataField parsedField = dictionaryFields.get(0);

	// First pass: the field is a string, so nothing is converted
	assertEquals("false", noValue.getValue());
	assertEquals("true", yesValue.getValue());
	assertEquals("N/A", unparsableValue.getValue());

	assertEquals("false", noTerm.getValue());
	assertEquals("true", yesTerm.getValue());

	assertEquals("false", inputField.getMissingValueReplacement());
	assertEquals("N/A", inputField.getInvalidValueReplacement());

	// Second pass: the field is now boolean, so parseable values are converted
	parsedField.setDataType(DataType.BOOLEAN);

	valueParser.applyTo(pmml);

	assertEquals(Boolean.FALSE, noValue.getValue());
	assertEquals(Boolean.TRUE, yesValue.getValue());
	assertEquals("N/A", unparsableValue.getValue());

	assertEquals(Boolean.FALSE, noTerm.getValue());
	assertEquals(Boolean.TRUE, yesTerm.getValue());

	assertEquals(Boolean.FALSE, inputField.getMissingValueReplacement());
	assertEquals("N/A", inputField.getInvalidValueReplacement());
}
 
Example #28
Source File: ModelManagerFactoryTest.java    From jpmml-evaluator with GNU Affero General Public License v3.0 4 votes vote down vote up
/**
 * Checks that a {@link ModelManagerFactory} refuses a model without a mining function,
 * and otherwise returns the manager type that corresponds to the model's mining function.
 */
@Test
public void newModelManager(){
	// Factory stub that offers the same two candidate classes for every model class
	ModelManagerFactory<ModelManager<?>> factory = new ModelManagerFactory<ModelManager<?>>(null){

		@Override
		public List<Class<? extends ModelManager<?>>> getServiceProviderClasses(Class<? extends Model> modelClazz){
			return Arrays.asList(RegressorManager.class, ClassifierManager.class);
		}
	};

	TreeModel model = new TreeModel();
	model.setMiningFunction(null);
	model.setMiningSchema(new MiningSchema());

	PMML pmml = new PMML();
	pmml.setHeader(new Header());
	pmml.setDataDictionary(new DataDictionary());
	pmml.addModels(model);

	// No mining function set: the factory call must fail
	try {
		factory.newModelManager(pmml, model);

		fail();
	} catch(InvalidMarkupException expected){
		// Ignored
	}

	model.setMiningFunction(MiningFunction.REGRESSION);

	ModelManager<?> regressionManager = factory.newModelManager(pmml, model);

	assertTrue(regressionManager instanceof RegressorManager);

	model.setMiningFunction(MiningFunction.CLASSIFICATION);

	ModelManager<?> classificationManager = factory.newModelManager(pmml, model);

	assertTrue(classificationManager instanceof ClassifierManager);
}
 
Example #29
Source File: RDFUpdate.java    From oryx with Apache License 2.0 4 votes vote down vote up
/**
 * Converts a random forest model into a PMML document.
 *
 * <p>A forest with exactly one tree is written as a single tree model. Otherwise the trees
 * become the segments of a {@link MiningModel}, each with an always-true predicate and
 * weight 1.0, combined by weighted majority vote for classification or weighted average
 * for regression. The training parameters are attached to the document as extensions.
 *
 * @param rfModel forest to convert; its task type must agree with the input schema
 * @param categoricalValueEncodings encodings passed on to the tree conversion and to the
 *  data dictionary builder
 * @param maxDepth value recorded as the "maxDepth" extension
 * @param maxSplitCandidates value recorded as the "maxSplitCandidates" extension
 * @param impurity value recorded as the "impurity" extension
 * @param nodeIDCounts per-tree counts, indexed in the same order as the forest's trees
 * @param predictorIndexCounts counts that are converted into predictor importances
 * @return PMML document containing the data dictionary, the model and the extensions
 */
private PMML rdfModelToPMML(RandomForestModel rfModel,
                            CategoricalValueEncodings categoricalValueEncodings,
                            int maxDepth,
                            int maxSplitCandidates,
                            String impurity,
                            List<? extends IntLongMap> nodeIDCounts,
                            IntLongMap predictorIndexCounts) {

  boolean isClassification = rfModel.algo().equals(Algo.Classification());
  Preconditions.checkState(isClassification == inputSchema.isClassification());

  DecisionTreeModel[] forest = rfModel.trees();

  Model model;
  if (forest.length != 1) {
    Segmentation.MultipleModelMethod combineMethod;
    if (isClassification) {
      combineMethod = Segmentation.MultipleModelMethod.WEIGHTED_MAJORITY_VOTE;
    } else {
      combineMethod = Segmentation.MultipleModelMethod.WEIGHTED_AVERAGE;
    }

    List<Segment> segments = new ArrayList<>(forest.length);
    for (int i = 0; i < forest.length; i++) {
      TreeModel segmentModel =
          toTreeModel(forest[i], categoricalValueEncodings, nodeIDCounts.get(i));
      Segment segment = new Segment()
          .setId(Integer.toString(i))
          .setPredicate(new True())
          .setModel(segmentModel)
          .setWeight(1.0); // No weights in MLlib impl now
      segments.add(segment);
    }

    MiningModel ensemble = new MiningModel();
    ensemble.setSegmentation(new Segmentation(combineMethod, segments));
    model = ensemble;
  } else {
    // A lone tree needs no ensemble wrapper
    model = toTreeModel(forest[0], categoricalValueEncodings, nodeIDCounts.get(0));
  }

  MiningFunction function;
  if (isClassification) {
    function = MiningFunction.CLASSIFICATION;
  } else {
    function = MiningFunction.REGRESSION;
  }
  model.setMiningFunction(function);

  double[] predictorImportances = countsToImportances(predictorIndexCounts);
  model.setMiningSchema(AppPMMLUtils.buildMiningSchema(inputSchema, predictorImportances));
  DataDictionary dataDictionary =
      AppPMMLUtils.buildDataDictionary(inputSchema, categoricalValueEncodings);

  PMML result = PMMLUtils.buildSkeletonPMML();
  result.setDataDictionary(dataDictionary);
  result.addModels(model);

  // Record the training hyperparameters alongside the model
  AppPMMLUtils.addExtension(result, "maxDepth", maxDepth);
  AppPMMLUtils.addExtension(result, "maxSplitCandidates", maxSplitCandidates);
  AppPMMLUtils.addExtension(result, "impurity", impurity);

  return result;
}
 
Example #30
Source File: InvalidMarkupInspectorTest.java    From jpmml-evaluator with GNU Affero General Public License v3.0 3 votes vote down vote up
/**
 * Checks that {@link InvalidMarkupInspector} collects one exception per invalid feature
 * of a PMML built with a null version, a null header and a data dictionary that declares
 * one field but contains none, and that the thrown exception is the first one collected.
 *
 * @throws Exception if the reflective field access fails
 */
@Test
public void inspect() throws Exception {
	DataDictionary dictionary = new DataDictionary();
	dictionary.setNumberOfFields(1);

	Field backingListField = ReflectionUtil.getField(DataDictionary.class, "dataFields");

	// The backing list is unset until the getter is called for the first time
	assertNull(ReflectionUtil.getFieldValue(backingListField, dictionary));

	List<DataField> declaredFields = dictionary.getDataFields();
	assertEquals(0, declaredFields.size());

	assertNotNull(ReflectionUtil.getFieldValue(backingListField, dictionary));

	PMML pmml = new PMML(null, null, dictionary);

	InvalidMarkupInspector markupInspector = new InvalidMarkupInspector();

	try {
		markupInspector.applyTo(pmml);

		fail();
	} catch(InvalidMarkupException thrown){
		List<InvalidMarkupException> collected = markupInspector.getExceptions();

		// Expected problem locations, in the order they should have been reported
		String[] expectedFeatures = {"PMML@version", "PMML/Header", "DataDictionary", "DataDictionary/DataField"};

		assertEquals(expectedFeatures.length, collected.size());
		assertEquals(0, collected.indexOf(thrown));

		int position = 0;

		for(InvalidMarkupException problem : collected){
			String text = problem.getMessage();

			assertTrue(text.contains(expectedFeatures[position]));

			position++;
		}
	}
}