biz.k11i.xgboost.Predictor Java Examples
The following examples show how to use
biz.k11i.xgboost.Predictor.
You can vote up the ones you like or vote down the ones you don't like,
and go to the original project or source file by following the links above each example. You may check out the related API usage on the sidebar.
Example #1
Source File: XGBoostOnlinePredictUDTF.java From incubator-hivemall with Apache License 2.0 | 6 votes |
@Override public void process(Object[] args) throws HiveException { if (mapToModel == null) { this.mapToModel = new HashMap<String, Predictor>(); } if (args[1] == null) {// features is null return; } String modelId = PrimitiveObjectInspectorUtils.getString(nonNullArgument(args, 2), modelIdOI); Predictor model = mapToModel.get(modelId); if (model == null) { Text arg3 = modelOI.getPrimitiveWritableObject(nonNullArgument(args, 3)); model = XGBoostUtils.loadPredictor(arg3); mapToModel.put(modelId, model); } Writable rowId = HiveUtils.copyToWritable(nonNullArgument(args, 0), rowIdOI); FVec features = denseFeatures ? parseDenseFeatures(args[1]) : parseSparseFeatures(featureListOI.getList(args[1])); predictAndForward(model, rowId, features); }
Example #2
Source File: XGBoostEvidenceFilterUnitTest.java From gatk with BSD 3-Clause "New" or "Revised" License | 6 votes |
@Test(groups = "sv") protected void testLocalXGBoostClassifierSpark() { final Predictor localPredictor = XGBoostEvidenceFilter.loadPredictor(localClassifierModelFile); // get spark ctx final JavaSparkContext ctx = SparkContextFactory.getTestSparkContext(); // parallelize classifierAccuracyData to RDD JavaRDD<FVec> testFeaturesRdd = ctx.parallelize(Arrays.asList(classifierAccuracyData.features)); // predict in parallel JavaDoubleRDD predictedProbabilityRdd = testFeaturesRdd.mapToDouble(f -> localPredictor.predictSingle(f, false, 0)); // pull back to local array final double[] predictedProbabilitySpark = predictedProbabilityRdd.collect() .stream().mapToDouble(Double::doubleValue).toArray(); // check probabilities from spark are identical to serial assertArrayEquals(predictedProbabilitySpark, predictedProbabilitySerial, 0.0, "Probabilities predicted in spark context differ from serial" ); }
Example #3
Source File: Example.java From xgboost-predictor-java with Apache License 2.0 | 6 votes |
/**
 * Predicts probability and calculates its logarithmic loss using {@link Predictor#predict(FVec)}.
 *
 * <p>For every entry the first element of the prediction is clamped to
 * {@code [1e-15, 1 - 1e-15]} so that {@code Math.log} never receives 0, and the per-sample
 * log-likelihood is accumulated. The printed value is the negated sum divided by the number
 * of entries.
 *
 * @param predictor Predictor
 * @param data      test data; each entry's key is used as the actual label (assumed to be
 *                  0 or 1) and its value as the feature vector
 */
static void predictAndLogLoss(Predictor predictor, List<SimpleEntry<Integer, FVec>> data) {
    double sum = 0;

    for (SimpleEntry<Integer, FVec> pair : data) {
        double[] predicted = predictor.predict(pair.getValue());

        double predValue = Math.min(Math.max(predicted[0], 1e-15), 1 - 1e-15);
        int actual = pair.getKey();

        // Accumulate across samples; plain assignment here would keep only the last sample.
        sum += actual * Math.log(predValue) + (1 - actual) * Math.log(1 - predValue);
    }

    double logLoss = -sum / data.size();

    System.out.println("Logloss: " + logLoss);
}
Example #4
Source File: XGBoostUtils.java From incubator-hivemall with Apache License 2.0 | 5 votes |
/**
 * Builds a {@link Predictor} from a model stored in a Hadoop {@code Text} value.
 *
 * <p>The text bytes are decoded with {@code IOUtils.fromCompressedText} and the result is fed
 * to the predictor through an in-memory stream.
 *
 * @param model the model text
 * @return the predictor built from the decoded bytes
 * @throws HiveException wrapping any {@code Throwable} raised while decoding or parsing
 */
@Nonnull
public static Predictor loadPredictor(@Nonnull final Text model) throws HiveException {
    final Predictor predictor;
    try {
        final byte[] decoded = IOUtils.fromCompressedText(model.getBytes(), model.getLength());
        final FastByteArrayInputStream in = new FastByteArrayInputStream(decoded);
        predictor = new Predictor(in);
    } catch (Throwable cause) {
        throw new HiveException("Failed to create a predictor", cause);
    }
    return predictor;
}
Example #5
Source File: XGBoostEvidenceFilterUnitTest.java From gatk with BSD 3-Clause "New" or "Revised" License | 5 votes |
@Test(groups = "sv") protected void testResourceXGBoostClassifier() { // load classifier from resource final Predictor resourcePredictor = XGBoostEvidenceFilter.loadPredictor(null); final double[] predictedProbabilityResource = predictProbability(resourcePredictor, classifierAccuracyData.features); // check that predictions from resource are identical to local assertArrayEquals(predictedProbabilityResource, predictedProbabilitySerial, 0.0, "Predictions via loading predictor from resource is not identical to local file" ); }
Example #6
Source File: XGBoostEvidenceFilter.java From gatk with BSD 3-Clause "New" or "Revised" License | 5 votes |
/**
 * Loads an XGBoost {@link Predictor} from the given location, or from
 * {@code DEFAULT_PREDICTOR_RESOURCE_PATH} when the location is {@code null}.
 *
 * <p>Before loading, {@code ObjFunction.useFastMathExp} is set from {@code USE_FAST_MATH_EXP}.
 * The input stream is closed before this method returns.
 *
 * @param modelFileLocation path of the classifier model file, or {@code null} to use the
 *                          default resource
 * @return the loaded predictor
 * @throws GATKException if the model cannot be opened or parsed; the original exception is
 *                       attached as the cause
 */
public static Predictor loadPredictor(final String modelFileLocation) {
    ObjFunction.useFastMathExp(USE_FAST_MATH_EXP);
    try (final InputStream inputStream = modelFileLocation == null
            ? resourcePathToInputStream(DEFAULT_PREDICTOR_RESOURCE_PATH)
            : BucketUtils.openFile(modelFileLocation)) {
        return new Predictor(inputStream);
    } catch (Exception e) {
        // Keep the original exception as the cause so its stack trace is not lost.
        throw new GATKException(
                "Unable to load predictor from classifier file "
                        + (modelFileLocation == null ? DEFAULT_PREDICTOR_RESOURCE_PATH : modelFileLocation)
                        + ": " + e.getMessage(),
                e
        );
    }
}
Example #7
Source File: PredictionTestBase.java From xgboost-predictor-java with Apache License 2.0 | 5 votes |
/**
 * Runs {@code predictionTask} over every test vector and compares each result with the
 * expected data.
 *
 * <p>For each row the result length must equal the expected length, and every value must be
 * within {@code 1e-5} of the expected value. A context string naming the model, the data
 * files and the task is printed first and embedded in every assertion message.
 *
 * @param model          model to load the predictor from
 * @param _testData      source of the test feature vectors
 * @param _expectedData  source of the expected prediction arrays
 * @param predictionTask the prediction variant to exercise
 * @throws IOException declared by this method; presumably raised by the load calls
 */
protected void verify(
        PredictionModel model,
        TestHelper.TestData _testData,
        TestHelper.Expectation _expectedData,
        PredictionTask predictionTask) throws IOException {

    final String context = String.format("[model: %s, test: %s, expected: %s, task: %s]",
            model.path, _testData.path(), _expectedData.path(), predictionTask.name);
    System.out.println(context);

    final Predictor predictor = model.load();
    final List<FVec> inputs = _testData.load();
    final List<double[]> expectations = _expectedData.load();

    for (int row = 0; row < inputs.size(); row++) {
        final double[] actual = predictionTask.predict(predictor, inputs.get(row));
        final String lengthReason = String.format("result array length: %s #%d", context, row);
        final double[] expected = expectations.get(row);

        assertThat(lengthReason, actual.length, is(expected.length));

        for (int col = 0; col < actual.length; col++) {
            final String valueReason =
                    String.format("prediction value: %s #%d[%d]", context, row, col);
            assertThat(valueReason, actual[col], closeTo(expected[col], 1e-5));
        }
    }
}
Example #8
Source File: PredictionTestBase.java From xgboost-predictor-java with Apache License 2.0 | 5 votes |
/**
 * Creates the {@code "leaf_ntree"} task, which calls
 * {@code Predictor.predictLeaf(feat, ntree_limit)} and converts the result to a double array.
 *
 * @param ntree_limit tree limit forwarded to {@code predictLeaf}
 * @return the prediction task
 */
public static PredictionTask predictLeafWithNTree(final int ntree_limit) {
    final PredictionTask task = new PredictionTask("leaf_ntree") {
        @Override
        double[] predict(Predictor predictor, FVec feat) {
            final int[] leaves = predictor.predictLeaf(feat, ntree_limit);
            return toDoubleArray(leaves);
        }
    };
    return task;
}
Example #9
Source File: PredictionTestBase.java From xgboost-predictor-java with Apache License 2.0 | 5 votes |
/**
 * Creates the {@code "leaf"} task, which calls {@code Predictor.predictLeaf(feat)} and
 * converts the result to a double array.
 *
 * @return the prediction task
 */
public static PredictionTask predictLeaf() {
    final PredictionTask task = new PredictionTask("leaf") {
        @Override
        double[] predict(Predictor predictor, FVec feat) {
            final int[] leaves = predictor.predictLeaf(feat);
            return toDoubleArray(leaves);
        }
    };
    return task;
}
Example #10
Source File: PredictionTestBase.java From xgboost-predictor-java with Apache License 2.0 | 5 votes |
/**
 * Creates the {@code "predict_single"} task, which wraps the value returned by
 * {@code Predictor.predictSingle(feat)} in a one-element array.
 *
 * @return the prediction task
 */
public static PredictionTask predictSingle() {
    final PredictionTask task = new PredictionTask("predict_single", "predict") {
        @Override
        double[] predict(Predictor predictor, FVec feat) {
            final double value = predictor.predictSingle(feat);
            return new double[] {value};
        }
    };
    return task;
}
Example #11
Source File: PredictionTestBase.java From xgboost-predictor-java with Apache License 2.0 | 5 votes |
/**
 * Creates the {@code "margin"} task, which calls {@code Predictor.predict(feat, true)}.
 *
 * @return the prediction task
 */
public static PredictionTask predictMargin() {
    final PredictionTask task = new PredictionTask("margin") {
        @Override
        double[] predict(Predictor predictor, FVec feat) {
            final double[] margins = predictor.predict(feat, true);
            return margins;
        }
    };
    return task;
}
Example #12
Source File: PredictionTestBase.java From xgboost-predictor-java with Apache License 2.0 | 5 votes |
/**
 * Creates the {@code "predict_excessive_ntree"} task, which calls
 * {@code Predictor.predict(feat, false, 1000)}.
 *
 * @return the prediction task
 */
public static PredictionTask predictWithExcessiveNTreeLimit() {
    final PredictionTask task = new PredictionTask("predict_excessive_ntree", "predict") {
        @Override
        double[] predict(Predictor predictor, FVec feat) {
            final double[] result = predictor.predict(feat, false, 1000);
            return result;
        }
    };
    return task;
}
Example #13
Source File: PredictionTestBase.java From xgboost-predictor-java with Apache License 2.0 | 5 votes |
/**
 * Creates the {@code "predict_ntree"} task, which calls
 * {@code Predictor.predict(feat, false, ntree_limit)}.
 *
 * @param ntree_limit tree limit forwarded to {@code predict}
 * @return the prediction task
 */
public static PredictionTask predictWithNTreeLimit(final int ntree_limit) {
    final PredictionTask task = new PredictionTask("predict_ntree") {
        @Override
        double[] predict(Predictor predictor, FVec feat) {
            final double[] result = predictor.predict(feat, false, ntree_limit);
            return result;
        }
    };
    return task;
}
Example #14
Source File: PredictionTestBase.java From xgboost-predictor-java with Apache License 2.0 | 5 votes |
/**
 * Creates the {@code "predict"} task, which calls {@code Predictor.predict(feat)}.
 *
 * @return the prediction task
 */
public static PredictionTask predict() {
    final PredictionTask task = new PredictionTask("predict") {
        @Override
        double[] predict(Predictor predictor, FVec feat) {
            final double[] result = predictor.predict(feat);
            return result;
        }
    };
    return task;
}
Example #15
Source File: Example.java From xgboost-predictor-java with Apache License 2.0 | 5 votes |
/**
 * Prints, for every entry in {@code data}, the array returned by
 * {@link Predictor#predictLeaf(FVec)} together with the entry's position in the list.
 *
 * @param predictor Predictor
 * @param data      test data; only each entry's value (the feature vector) is used
 */
static void predictLeafIndex(Predictor predictor, List<SimpleEntry<Integer, FVec>> data) {
    int row = 0;
    for (SimpleEntry<Integer, FVec> entry : data) {
        final String rendered = Arrays.toString(predictor.predictLeaf(entry.getValue()));
        System.out.printf("leafIndexes[%d]: %s%s", row, rendered, System.lineSeparator());
        row++;
    }
}
Example #16
Source File: Example.java From xgboost-predictor-java with Apache License 2.0 | 5 votes |
/**
 * Entry point of the example: loads the test data and the binary-logistic model, then prints
 * the log loss and the per-row leaf indexes.
 *
 * @param args command-line arguments (unused)
 * @throws IOException if the model cannot be read or its stream cannot be closed
 */
public static void main(String[] args) throws IOException {
    List<SimpleEntry<Integer, FVec>> data = loadData();

    // Close the model stream once the predictor has been built instead of leaking it.
    // The type is fully qualified so that no additional import is required.
    final Predictor predictor;
    try (java.io.InputStream modelStream =
            TestHelper.getResourceAsStream("model/gbtree/v47/binary-logistic.model")) {
        predictor = new Predictor(modelStream);
    }

    predictAndLogLoss(predictor, data);
    predictLeafIndex(predictor, data);
}
Example #17
Source File: XGBoostOnlinePredictUDTF.java From incubator-hivemall with Apache License 2.0 | 5 votes |
private void predictAndForward(@Nonnull final Predictor model, @Nonnull final Writable rowId, @Nonnull final FVec features) throws HiveException { double[] predicted = model.predict(features); // predicted[0] has // - probability ("binary:logistic") // - class label ("multi:softmax") forwardPredicted(rowId, predicted); }
Example #18
Source File: XGBoostClassifier.java From Myna with Apache License 2.0 | 5 votes |
/**
 * Creates the classifier by loading the {@code "rhar.model"} asset into a {@link Predictor}.
 *
 * <p>Loading is best-effort: any failure is printed and swallowed, in which case
 * {@code predictor} is not assigned by this constructor.
 *
 * @param ctx context whose assets contain the model file
 */
public XGBoostClassifier(Context ctx) {
    // try-with-resources closes the asset stream even when the Predictor constructor throws;
    // the previous explicit close() was skipped on that path and leaked the stream.
    try (InputStream is = ctx.getAssets().open("rhar.model")) {
        predictor = new Predictor(is);
    } catch (Throwable t) {
        t.printStackTrace();
    }
}
Example #19
Source File: PredictionTestBase.java From xgboost-predictor-java with Apache License 2.0 | 4 votes |
/**
 * Opens the resource at {@code path} and builds a {@link Predictor} from it using
 * {@code configuration}. The stream is closed before the predictor is returned.
 *
 * @return the loaded predictor
 * @throws IOException if reading or closing the resource fails
 */
Predictor load() throws IOException {
    try (InputStream modelStream = TestHelper.getResourceAsStream(path)) {
        final Predictor loaded = new Predictor(modelStream, configuration);
        return loaded;
    }
}
Example #20
Source File: XGBoostEvidenceFilterUnitTest.java From gatk with BSD 3-Clause "New" or "Revised" License | 4 votes |
/**
 * Scores every feature vector with {@code predictor.predictSingle(features, false, 0)}.
 *
 * @param predictor    predictor used for scoring
 * @param testFeatures feature vectors to score
 * @return one prediction per input vector, in input order
 */
private static double[] predictProbability(final Predictor predictor, final FVec[] testFeatures) {
    final double[] probabilities = new double[testFeatures.length];
    for (int i = 0; i < testFeatures.length; i++) {
        probabilities[i] = predictor.predictSingle(testFeatures[i], false, 0);
    }
    return probabilities;
}
Example #21
Source File: PredictionTestBase.java From xgboost-predictor-java with Apache License 2.0 | votes |
/**
 * Runs this task's prediction for a single feature vector.
 *
 * @param predictor the predictor to evaluate
 * @param feat      the feature vector to predict on
 * @return the prediction result as a double array
 */
abstract double[] predict(Predictor predictor, FVec feat);