org.apache.hadoop.hive.ql.udf.generic.Collector Java Examples
The following examples show how to use
org.apache.hadoop.hive.ql.udf.generic.Collector.
You can vote up the ones you like or vote down the ones you don't like,
and go to the original project or source file by following the links above each example. You may check out the related API usage on the sidebar.
Example #1
Source File: PassiveAggressiveRegressionUDTFTest.java From incubator-hivemall with Apache License 2.0 | 6 votes |
@Test public void testPA1() throws HiveException { PassiveAggressiveRegressionUDTF udtf = new PassiveAggressiveRegressionUDTF(); udtf.initialize(new ObjectInspector[] { ObjectInspectorFactory.getStandardListObjectInspector( PrimitiveObjectInspectorFactory.javaStringObjectInspector), PrimitiveObjectInspectorFactory.javaFloatObjectInspector}); udtf.setCollector(new Collector() { public void collect(Object input) throws HiveException { // noop } }); udtf.process(new Object[] {Arrays.asList("1:-2", "2:-1"), 1.1f}); udtf.process(new Object[] {Arrays.asList("3:-2", "1:-1"), -1.3f}); byte[] serialized = TestUtils.serializeObjectByKryo(udtf); TestUtils.deserializeObjectByKryo(serialized, PassiveAggressiveRegressionUDTF.class); udtf.close(); }
Example #2
Source File: GenerateSeriesUDTFTest.java From incubator-hivemall with Apache License 2.0 | 6 votes |
/**
 * generate_series(int start, IntWritable end) with the default step emits 1, 2, 3 for [1, 3].
 */
@Test
public void testTwoIntArgs() throws HiveException {
    final GenerateSeriesUDTF series = new GenerateSeriesUDTF();
    final ObjectInspector startOI = PrimitiveObjectInspectorFactory.javaIntObjectInspector;
    final ObjectInspector endOI = PrimitiveObjectInspectorFactory.writableIntObjectInspector;
    series.initialize(new ObjectInspector[] {startOI, endOI});

    final List<IntWritable> emitted = new ArrayList<>();
    series.setCollector(new Collector() {
        @Override
        public void collect(Object args) throws HiveException {
            // Copy the value in case the UDTF reuses the same Writable across rows.
            final IntWritable first = (IntWritable) ((Object[]) args)[0];
            emitted.add(new IntWritable(first.get()));
        }
    });

    series.process(new Object[] {1, new IntWritable(3)});

    final List<IntWritable> expected =
            Arrays.asList(new IntWritable(1), new IntWritable(2), new IntWritable(3));
    Assert.assertEquals(expected, emitted);
}
Example #3
Source File: LDAUDTFTest.java From incubator-hivemall with Apache License 2.0 | 6 votes |
/**
 * Feeds one three-word document to LDA with two topics and checks how many rows close() emits.
 */
@Test
public void testSingleRow() throws HiveException {
    final LDAUDTF lda = new LDAUDTF();
    final int numTopics = 2;

    final ObjectInspector wordsOI = ObjectInspectorFactory.getStandardListObjectInspector(
        PrimitiveObjectInspectorFactory.javaStringObjectInspector);
    final ObjectInspector optionsOI = ObjectInspectorUtils.getConstantObjectInspector(
        PrimitiveObjectInspectorFactory.javaStringObjectInspector, "-topics " + numTopics);
    lda.initialize(new ObjectInspector[] {wordsOI, optionsOI});

    final String[] document = new String[] {"1", "2", "3"};
    lda.process(new Object[] {Arrays.asList(document)});

    // The collector is attached only after process(), so only rows emitted by close() are counted.
    final MutableInt emittedRows = new MutableInt(0);
    lda.setCollector(new Collector() {
        @Override
        public void collect(Object arg0) throws HiveException {
            emittedRows.addValue(1);
        }
    });
    lda.close();

    // Expect one row per (word, topic) combination.
    Assert.assertEquals(document.length * numTopics, emittedRows.getValue());
}
Example #4
Source File: PLSAUDTFTest.java From incubator-hivemall with Apache License 2.0 | 6 votes |
/**
 * Feeds one three-word document to pLSA with two topics and checks how many rows close() emits.
 */
@Test
public void testSingleRow() throws HiveException {
    final PLSAUDTF plsa = new PLSAUDTF();
    final int numTopics = 2;

    final ObjectInspector wordsOI = ObjectInspectorFactory.getStandardListObjectInspector(
        PrimitiveObjectInspectorFactory.javaStringObjectInspector);
    final ObjectInspector optionsOI = ObjectInspectorUtils.getConstantObjectInspector(
        PrimitiveObjectInspectorFactory.javaStringObjectInspector, "-topics " + numTopics);
    plsa.initialize(new ObjectInspector[] {wordsOI, optionsOI});

    final String[] document = new String[] {"1", "2", "3"};
    plsa.process(new Object[] {Arrays.asList(document)});

    // The collector is attached only after process(), so only rows emitted by close() are counted.
    final MutableInt emittedRows = new MutableInt(0);
    plsa.setCollector(new Collector() {
        @Override
        public void collect(Object arg0) throws HiveException {
            emittedRows.addValue(1);
        }
    });
    plsa.close();

    // Expect one row per (word, topic) combination.
    Assert.assertEquals(document.length * numTopics, emittedRows.getValue());
}
Example #5
Source File: GenerateSeriesUDTFTest.java From incubator-hivemall with Apache License 2.0 | 6 votes |
/**
 * Checks that a GenerateSeriesUDTF that has already processed a row can be Kryo round-tripped.
 */
@Test
public void testSerialization() throws HiveException {
    final GenerateSeriesUDTF series = new GenerateSeriesUDTF();
    final ObjectInspector[] argOIs = {
            PrimitiveObjectInspectorFactory.javaIntObjectInspector,
            PrimitiveObjectInspectorFactory.writableIntObjectInspector};
    series.initialize(argOIs);

    // Output rows are discarded; only serializability matters.
    final Collector discard = new Collector() {
        @Override
        public void collect(Object args) throws HiveException {}
    };
    series.setCollector(discard);

    series.process(new Object[] {1, new IntWritable(3)});

    final byte[] bytes = TestUtils.serializeObjectByKryo(series);
    TestUtils.deserializeObjectByKryo(bytes, GenerateSeriesUDTF.class);
}
Example #6
Source File: GenerateSeriesUDTFTest.java From incubator-hivemall with Apache License 2.0 | 6 votes |
/**
 * A long start with a negative int step counts down: [5, 1] step -2 yields 5, 3, 1 as longs.
 */
@Test
public void testNegativeStepLong() throws HiveException {
    final GenerateSeriesUDTF series = new GenerateSeriesUDTF();
    final ObjectInspector[] argOIs = {
            PrimitiveObjectInspectorFactory.javaLongObjectInspector,
            PrimitiveObjectInspectorFactory.writableIntObjectInspector,
            PrimitiveObjectInspectorFactory.javaIntObjectInspector};
    series.initialize(argOIs);

    final List<LongWritable> emitted = new ArrayList<>();
    series.setCollector(new Collector() {
        @Override
        public void collect(Object args) throws HiveException {
            final Object[] fields = (Object[]) args;
            final long value = ((LongWritable) fields[0]).get();
            emitted.add(new LongWritable(value));
        }
    });

    series.process(new Object[] {5L, new IntWritable(1), -2});

    Assert.assertEquals(
        Arrays.asList(new LongWritable(5), new LongWritable(3), new LongWritable(1)), emitted);
}
Example #7
Source File: GenerateSeriesUDTFTest.java From incubator-hivemall with Apache License 2.0 | 6 votes |
/**
 * An int start with a negative long step counts down: [5, 1] step -2 yields 5, 3, 1 as ints.
 */
@Test
public void testNegativeStepInt() throws HiveException {
    final GenerateSeriesUDTF series = new GenerateSeriesUDTF();
    final ObjectInspector[] argOIs = {
            PrimitiveObjectInspectorFactory.javaIntObjectInspector,
            PrimitiveObjectInspectorFactory.writableIntObjectInspector,
            PrimitiveObjectInspectorFactory.javaLongObjectInspector};
    series.initialize(argOIs);

    final List<IntWritable> emitted = new ArrayList<>();
    series.setCollector(new Collector() {
        @Override
        public void collect(Object args) throws HiveException {
            final Object[] fields = (Object[]) args;
            final int value = ((IntWritable) fields[0]).get();
            emitted.add(new IntWritable(value));
        }
    });

    series.process(new Object[] {5, new IntWritable(1), -2L});

    Assert.assertEquals(
        Arrays.asList(new IntWritable(5), new IntWritable(3), new IntWritable(1)), emitted);
}
Example #8
Source File: GenerateSeriesUDTFTest.java From incubator-hivemall with Apache License 2.0 | 6 votes |
/**
 * All-long arguments: [1, 7] with step 3 yields 1, 4, 7 as longs.
 */
@Test
public void testThreeLongArgs() throws HiveException {
    final GenerateSeriesUDTF series = new GenerateSeriesUDTF();
    final ObjectInspector startOI = PrimitiveObjectInspectorFactory.javaLongObjectInspector;
    final ObjectInspector endOI = PrimitiveObjectInspectorFactory.writableLongObjectInspector;
    final ObjectInspector stepOI = PrimitiveObjectInspectorFactory.javaLongObjectInspector;
    series.initialize(new ObjectInspector[] {startOI, endOI, stepOI});

    final List<LongWritable> emitted = new ArrayList<>();
    series.setCollector(new Collector() {
        @Override
        public void collect(Object args) throws HiveException {
            final LongWritable head = (LongWritable) ((Object[]) args)[0];
            emitted.add(new LongWritable(head.get()));
        }
    });

    series.process(new Object[] {1L, new LongWritable(7), 3L});

    final List<LongWritable> expected =
            Arrays.asList(new LongWritable(1), new LongWritable(4), new LongWritable(7));
    Assert.assertEquals(expected, emitted);
}
Example #9
Source File: GenerateSeriesUDTFTest.java From incubator-hivemall with Apache License 2.0 | 6 votes |
/**
 * Int start/end with a long step: [1, 7] with step 3 yields 1, 4, 7 as ints.
 */
@Test
public void testThreeIntArgs() throws HiveException {
    final GenerateSeriesUDTF series = new GenerateSeriesUDTF();
    final ObjectInspector startOI = PrimitiveObjectInspectorFactory.javaIntObjectInspector;
    final ObjectInspector endOI = PrimitiveObjectInspectorFactory.writableIntObjectInspector;
    final ObjectInspector stepOI = PrimitiveObjectInspectorFactory.javaLongObjectInspector;
    series.initialize(new ObjectInspector[] {startOI, endOI, stepOI});

    final List<IntWritable> emitted = new ArrayList<>();
    series.setCollector(new Collector() {
        @Override
        public void collect(Object args) throws HiveException {
            final IntWritable head = (IntWritable) ((Object[]) args)[0];
            emitted.add(new IntWritable(head.get()));
        }
    });

    series.process(new Object[] {1, new IntWritable(7), 3L});

    final List<IntWritable> expected =
            Arrays.asList(new IntWritable(1), new IntWritable(4), new IntWritable(7));
    Assert.assertEquals(expected, emitted);
}
Example #10
Source File: GenerateSeriesUDTFTest.java From incubator-hivemall with Apache License 2.0 | 6 votes |
/**
 * An int start with a LongWritable end emits longs: [1, 3] yields 1, 2, 3.
 */
@Test
public void testTwoLongArgs() throws HiveException {
    final GenerateSeriesUDTF series = new GenerateSeriesUDTF();
    series.initialize(new ObjectInspector[] {
            PrimitiveObjectInspectorFactory.javaIntObjectInspector,
            PrimitiveObjectInspectorFactory.writableLongObjectInspector});

    final List<LongWritable> emitted = new ArrayList<>();
    final Collector recorder = new Collector() {
        @Override
        public void collect(Object args) throws HiveException {
            final LongWritable head = (LongWritable) ((Object[]) args)[0];
            emitted.add(new LongWritable(head.get()));
        }
    };
    series.setCollector(recorder);

    series.process(new Object[] {1, new LongWritable(3)});

    Assert.assertEquals(
        Arrays.asList(new LongWritable(1), new LongWritable(2), new LongWritable(3)), emitted);
}
Example #11
Source File: RandomForestClassifierUDTFTest.java From incubator-hivemall with Apache License 2.0 | 6 votes |
@Test public void testSparseRandomForestClassifier() throws HiveException { RandomForestClassifierUDTF udtf = new RandomForestClassifierUDTF(); udtf.initialize(new ObjectInspector[] { ObjectInspectorFactory.getStandardListObjectInspector( PrimitiveObjectInspectorFactory.javaStringObjectInspector), PrimitiveObjectInspectorFactory.javaIntObjectInspector}); udtf.process(new Object[] {new String[] {"1:1.0", "4:1.0", "7:1.0", "12:1.0"}, 1}); // 0 udtf.process(new Object[] {new String[] {"2:1.0", "4:1.0", "5:1.0", "11:1.0"}, 1}); // 1 udtf.process(new Object[] { new String[] {"1:1.0", "4:1.0", "7:1.0", "113:1.0", "497:1.0", "635:1.0"}, 0}); // 2 udtf.process(new Object[] { new String[] {"1:1.0", "4:1.0", "5:1.0", "7:1.0", "10:1.0", "14:1.0"}, 1}); // 3 udtf.process(new Object[] {new String[] {"1:1.0", "2:1.0", "4:1.0", "7:1.0", "8:1.0"}, 1}); // 4 udtf.process(new Object[] {new String[] {"13:1.0", "18:1.0", "25:1.0", "27:1.0", "65:1.0", "116:1.0", "200:1.0", "468:1.0", "585:1.0", "715:1.0"}, 0}); udtf.setCollector(new Collector() { @Override public void collect(Object input) throws HiveException {} }); udtf.close(); }
Example #12
Source File: TestUtils.java From incubator-hivemall with Apache License 2.0 | 5 votes |
@SuppressWarnings("deprecation") public static <T extends GenericUDTF> void testGenericUDTFSerialization(@Nonnull Class<T> clazz, @Nonnull ObjectInspector[] ois, @Nonnull Object[][] rows) throws HiveException { final T udtf; try { udtf = clazz.newInstance(); } catch (InstantiationException | IllegalAccessException e) { throw new HiveException(e); } udtf.initialize(ois); // serialization after initialization byte[] serialized = serializeObjectByKryo(udtf); deserializeObjectByKryo(serialized, clazz); udtf.setCollector(new Collector() { public void collect(Object input) throws HiveException { // noop } }); for (Object[] row : rows) { udtf.process(row); } // serialization after processing row serialized = serializeObjectByKryo(udtf); TestUtils.deserializeObjectByKryo(serialized, clazz); udtf.close(); }
Example #13
Source File: ConditionalEmitUDTFTest.java From incubator-hivemall with Apache License 2.0 | 5 votes |
/**
 * conditional_emit forwards exactly the features whose matching condition is true, one
 * single-column row per emitted feature.
 */
@Test
public void test() throws HiveException {
    final ConditionalEmitUDTF emitter = new ConditionalEmitUDTF();
    final ObjectInspector conditionsOI = ObjectInspectorFactory.getStandardListObjectInspector(
        PrimitiveObjectInspectorFactory.javaBooleanObjectInspector);
    final ObjectInspector featuresOI = ObjectInspectorFactory.getStandardListObjectInspector(
        PrimitiveObjectInspectorFactory.javaStringObjectInspector);
    emitter.initialize(new ObjectInspector[] {conditionsOI, featuresOI});

    final List<Object> forwarded = new ArrayList<>();
    emitter.setCollector(new Collector() {
        @Override
        public void collect(Object input) throws HiveException {
            final Object[] row = (Object[]) input;
            Assert.assertEquals(1, row.length);
            forwarded.add(row[0]);
        }
    });

    // Conditions (true, false, true): the first and third features are emitted.
    emitter.process(
        new Object[] {Arrays.asList(true, false, true), Arrays.asList("one", "two", "three")});
    Assert.assertEquals(Arrays.asList("one", "three"), forwarded);

    forwarded.clear();

    // Conditions (true, true, false): the first two features are emitted.
    emitter.process(
        new Object[] {Arrays.asList(true, true, false), Arrays.asList("one", "two", "three")});
    Assert.assertEquals(Arrays.asList("one", "two"), forwarded);

    emitter.close();
}
Example #14
Source File: MovingAverageUDTFTest.java From incubator-hivemall with Apache License 2.0 | 5 votes |
/**
 * A moving average with window 3 over 1..7 yields the running mean of the last (up to) three
 * values: 1, 1.5, 2, 3, 4, 5, 6.
 */
@Test
public void test() throws HiveException {
    final MovingAverageUDTF movingAvg = new MovingAverageUDTF();
    final ObjectInspector valueOI = PrimitiveObjectInspectorFactory.javaFloatObjectInspector;
    final ObjectInspector windowOI = ObjectInspectorUtils.getConstantObjectInspector(
        PrimitiveObjectInspectorFactory.javaIntObjectInspector, 3);

    final List<Double> averages = new ArrayList<>();
    movingAvg.initialize(new ObjectInspector[] {valueOI, windowOI});
    movingAvg.setCollector(new Collector() {
        @Override
        public void collect(Object input) throws HiveException {
            final Object[] row = (Object[]) input;
            Assert.assertEquals(1, row.length);
            Assert.assertTrue(row[0] instanceof DoubleWritable);
            averages.add(((DoubleWritable) row[0]).get());
        }
    });

    // Feed 1.0f .. 7.0f; the window argument is passed as null on every row.
    for (int i = 1; i <= 7; i++) {
        movingAvg.process(new Object[] {(float) i, null});
    }

    Assert.assertEquals(Arrays.asList(1.d, 1.5d, 2.d, 3.d, 4.d, 5.d, 6.d), averages);
}
Example #15
Source File: GenerateSeriesUDTFTest.java From incubator-hivemall with Apache License 2.0 | 5 votes |
/**
 * Constant (writable) start/end inspectors: [1, 3] yields 1, 2, 3 as ints.
 */
@Test
public void testTwoConstArgs() throws HiveException {
    final GenerateSeriesUDTF series = new GenerateSeriesUDTF();
    final ObjectInspector startOI =
            PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector(
                TypeInfoFactory.intTypeInfo, new IntWritable(1));
    final ObjectInspector endOI =
            PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector(
                TypeInfoFactory.intTypeInfo, new IntWritable(3));
    series.initialize(new ObjectInspector[] {startOI, endOI});

    final List<IntWritable> emitted = new ArrayList<>();
    series.setCollector(new Collector() {
        @Override
        public void collect(Object args) throws HiveException {
            final int value = ((IntWritable) ((Object[]) args)[0]).get();
            emitted.add(new IntWritable(value));
        }
    });

    series.process(new Object[] {new IntWritable(1), new IntWritable(3)});

    Assert.assertEquals(
        Arrays.asList(new IntWritable(1), new IntWritable(2), new IntWritable(3)), emitted);
}
Example #16
Source File: RandomForestClassifierUDTFTest.java From incubator-hivemall with Apache License 2.0 | 5 votes |
@Test public void testSparseRandomForestClassifierL2Normalized() throws HiveException { RandomForestClassifierUDTF udtf = new RandomForestClassifierUDTF(); udtf.initialize(new ObjectInspector[] { ObjectInspectorFactory.getStandardListObjectInspector( PrimitiveObjectInspectorFactory.javaStringObjectInspector), PrimitiveObjectInspectorFactory.javaIntObjectInspector}); udtf.process(new Object[] {new String[] {"1:0.5", "4:0.5", "7:0.5", "12:0.5"}, 1}); // 0 udtf.process(new Object[] {new String[] {"2:0.5", "4:0.5", "5:0.5", "11:0.5"}, 1}); // 1 udtf.process(new Object[] {new String[] {"1:0.40824828", "4:0.40824828", "7:0.40824828", "113:0.40824828", "497:0.40824828", "635:0.40824828"}, 0}); // 2 udtf.process(new Object[] {new String[] {"1:0.40824828", "4:0.40824828", "5:0.40824828", "7:0.40824828", "10:0.40824828", "14:0.40824828"}, 1}); // 3 udtf.process(new Object[] {new String[] {"1:0.4472136", "2:0.4472136", "4:0.4472136", "7:0.4472136", "8:0.4472136"}, 1}); // 4 udtf.process(new Object[] {new String[] {"13:0.31622776", "18:0.31622776", "25:0.31622776", "27:0.31622776", "65:0.31622776", "116:0.31622776", "200:0.31622776", "468:0.31622776", "585:0.31622776", "715:0.31622776"}, 0}); // 5 udtf.setCollector(new Collector() { @Override public void collect(Object input) throws HiveException {} }); udtf.close(); }
Example #17
Source File: GeneralClassifierUDTFTest.java From incubator-hivemall with Apache License 2.0 | 5 votes |
/**
 * Trains a GeneralClassifierUDTF on a single example {@code x} with label 0. Asserts that at
 * least one model row is forwarded and that each row's first column has runtime class
 * {@code modelFeatureClass}. {@code featureClass} is not used by this helper.
 */
private <T> void testFeature(@Nonnull List<T> x, @Nonnull ObjectInspector featureOI,
        @Nonnull Class<T> featureClass, @Nonnull Class<?> modelFeatureClass) throws Exception {
    final int label = 0;
    final GeneralClassifierUDTF classifier = new GeneralClassifierUDTF();
    final ObjectInspector labelOI = PrimitiveObjectInspectorFactory.javaIntObjectInspector;
    final ListObjectInspector featuresOI =
            ObjectInspectorFactory.getStandardListObjectInspector(featureOI);
    classifier.initialize(new ObjectInspector[] {featuresOI, labelOI});

    final List<Object> collected = new ArrayList<Object>();
    classifier.setCollector(new Collector() {
        @Override
        public void collect(Object input) throws HiveException {
            collected.add(((Object[]) input)[0]);
        }
    });

    classifier.process(new Object[] {x, label});
    classifier.close();

    Assert.assertFalse(collected.isEmpty());
    for (Object feature : collected) {
        Assert.assertEquals("All model features must have same type", modelFeatureClass,
            feature.getClass());
    }
}
Example #18
Source File: TestUtils.java From incubator-hivemall with Apache License 2.0 | 5 votes |
@SuppressWarnings("deprecation") public static <T extends GenericUDTF> void testGenericUDTFSerialization(@Nonnull Class<T> clazz, @Nonnull ObjectInspector[] ois, @Nonnull Object[][] rows) throws HiveException { final T udtf; try { udtf = clazz.newInstance(); } catch (InstantiationException | IllegalAccessException e) { throw new HiveException(e); } udtf.initialize(ois); // serialization after initialization byte[] serialized = serializeObjectByKryo(udtf); deserializeObjectByKryo(serialized, clazz); udtf.setCollector(new Collector() { public void collect(Object input) throws HiveException { // noop } }); for (Object[] row : rows) { udtf.process(row); } // serialization after processing row serialized = serializeObjectByKryo(udtf); TestUtils.deserializeObjectByKryo(serialized, clazz); udtf.close(); }
Example #19
Source File: GeneralRegressorUDTFTest.java From incubator-hivemall with Apache License 2.0 | 5 votes |
/**
 * Trains a GeneralRegressorUDTF on a single example {@code x} with target 1.0. Asserts that at
 * least one model row is forwarded and that each row's first column has runtime class
 * {@code modelFeatureClass}. {@code featureClass} is not used by this helper.
 */
private <T> void testFeature(@Nonnull List<T> x, @Nonnull ObjectInspector featureOI,
        @Nonnull Class<T> featureClass, @Nonnull Class<?> modelFeatureClass) throws Exception {
    final float target = 1.f;
    final GeneralRegressorUDTF regressor = new GeneralRegressorUDTF();
    final ObjectInspector targetOI = PrimitiveObjectInspectorFactory.javaFloatObjectInspector;
    final ListObjectInspector featuresOI =
            ObjectInspectorFactory.getStandardListObjectInspector(featureOI);
    regressor.initialize(new ObjectInspector[] {featuresOI, targetOI});

    final List<Object> collected = new ArrayList<Object>();
    regressor.setCollector(new Collector() {
        @Override
        public void collect(Object input) throws HiveException {
            collected.add(((Object[]) input)[0]);
        }
    });

    regressor.process(new Object[] {x, target});
    regressor.close();

    Assert.assertFalse(collected.isEmpty());
    for (Object feature : collected) {
        Assert.assertEquals("All model features must have same type", modelFeatureClass,
            feature.getClass());
    }
}
Example #20
Source File: HiveGenericUDTF.java From flink with Apache License 2.0 | 4 votes |
/**
 * Installs {@code collector} on the underlying {@code function}. Exposed for tests.
 *
 * @param collector receiver of the rows the underlying function forwards
 */
@VisibleForTesting
protected final void setCollector(Collector collector) {
    this.function.setCollector(collector);
}
Example #21
Source File: HiveGenericUDTF.java From flink with Apache License 2.0 | 4 votes |
/**
 * Installs {@code collector} on the underlying {@code function}. Exposed for tests.
 *
 * @param collector receiver of the rows the underlying function forwards
 */
@VisibleForTesting
protected final void setCollector(Collector collector) {
    this.function.setCollector(collector);
}
Example #22
Source File: RandomForestRegressionUDTFTest.java From incubator-hivemall with Apache License 2.0 | 4 votes |
/**
 * Trains a single-tree random forest regressor (seed 71) on a small hashed-sparse dataset.
 * Returns the tree decoded from the Base91 model text in column 2 of the forwarded model row.
 */
private static RegressionTree.Node getRegressionTreeFromSparseInput()
        throws IOException, ParseException, HiveException {
    final String[] featureNames = {"f1", "f2", "f3", "f4", "f5", "f6", "f7", "f8", "f9"};
    final double[][] x = {{234.289, 235.6, 159.0, 107.608, 1947, 60.323},
            {259.426, 232.5, 145.6, 108.632, 1948, 61.122},
            {258.054, 368.2, 161.6, 109.773, 1949, 60.171},
            {284.599, 335.1, 165.0, 110.929, 1950, 61.187},
            {328.975, 209.9, 309.9, 112.075, 1951, 63.221},
            {346.999, 193.2, 359.4, 113.270, 1952, 63.639},
            {365.385, 187.0, 354.7, 115.094, 1953, 64.989},
            {363.112, 357.8, 335.0, 116.219, 1954, 63.761},
            {397.469, 290.4, 304.8, 117.388, 1955, 66.019}};
    final double[] y = {83.0, 88.5, 88.2, 89.5, 96.2, 98.1, 99.0, 100.0, 101.2};

    final RandomForestRegressionUDTF forest = new RandomForestRegressionUDTF();
    final ObjectInspector options = ObjectInspectorUtils.getConstantObjectInspector(
        PrimitiveObjectInspectorFactory.javaStringObjectInspector, "-trees 1 -seed 71");
    forest.initialize(new ObjectInspector[] {
            ObjectInspectorFactory.getStandardListObjectInspector(
                PrimitiveObjectInspectorFactory.javaStringObjectInspector),
            PrimitiveObjectInspectorFactory.javaDoubleObjectInspector, options});

    // Each row becomes a list of "<mhash(name)>:<value>" strings; the buffer is reused per row.
    final List<String> features = new ArrayList<String>(x[0].length);
    for (int r = 0; r < x.length; r++) {
        for (int c = 0; c < x[r].length; c++) {
            features.add(mhash(featureNames[c]) + ":" + x[r][c]);
        }
        forest.process(new Object[] {features, y[r]});
        features.clear();
    }

    final Text[] captured = new Text[1];
    forest.setCollector(new Collector() {
        public void collect(Object input) throws HiveException {
            captured[0] = (Text) ((Object[]) input)[2];
        }
    });
    forest.close();

    final Text encoded = captured[0];
    Assert.assertNotNull(encoded);
    final byte[] raw = Base91.decode(encoded.getBytes(), 0, encoded.getLength());
    return RegressionTree.deserialize(raw, raw.length, true);
}
Example #23
Source File: RandomForestRegressionUDTFTest.java From incubator-hivemall with Apache License 2.0 | 4 votes |
/**
 * Trains a single-tree random forest regressor (seed 71) on a small dense dataset. Returns the
 * tree decoded from the Base91 model text in column 2 of the forwarded model row.
 */
private static RegressionTree.Node getRegressionTreeFromDenseInput()
        throws IOException, ParseException, HiveException {
    final double[][] x = {{234.289, 235.6, 159.0, 107.608, 1947, 60.323},
            {259.426, 232.5, 145.6, 108.632, 1948, 61.122},
            {258.054, 368.2, 161.6, 109.773, 1949, 60.171},
            {284.599, 335.1, 165.0, 110.929, 1950, 61.187},
            {328.975, 209.9, 309.9, 112.075, 1951, 63.221},
            {346.999, 193.2, 359.4, 113.270, 1952, 63.639},
            {365.385, 187.0, 354.7, 115.094, 1953, 64.989},
            {363.112, 357.8, 335.0, 116.219, 1954, 63.761},
            {397.469, 290.4, 304.8, 117.388, 1955, 66.019}};
    final double[] y = {83.0, 88.5, 88.2, 89.5, 96.2, 98.1, 99.0, 100.0, 101.2};

    final RandomForestRegressionUDTF forest = new RandomForestRegressionUDTF();
    final ObjectInspector options = ObjectInspectorUtils.getConstantObjectInspector(
        PrimitiveObjectInspectorFactory.javaStringObjectInspector, "-trees 1 -seed 71");
    forest.initialize(new ObjectInspector[] {
            ObjectInspectorFactory.getStandardListObjectInspector(
                PrimitiveObjectInspectorFactory.javaDoubleObjectInspector),
            PrimitiveObjectInspectorFactory.javaDoubleObjectInspector, options});

    // The buffer is emptied after every row, so appending keeps values at their column index.
    final List<Double> features = new ArrayList<Double>(x[0].length);
    for (int r = 0; r < x.length; r++) {
        for (double value : x[r]) {
            features.add(value);
        }
        forest.process(new Object[] {features, y[r]});
        features.clear();
    }

    final Text[] captured = new Text[1];
    forest.setCollector(new Collector() {
        public void collect(Object input) throws HiveException {
            captured[0] = (Text) ((Object[]) input)[2];
        }
    });
    forest.close();

    final Text encoded = captured[0];
    Assert.assertNotNull(encoded);
    final byte[] raw = Base91.decode(encoded.getBytes(), 0, encoded.getLength());
    return RegressionTree.deserialize(raw, raw.length, true);
}
Example #24
Source File: RandomForestRegressionUDTFTest.java From incubator-hivemall with Apache License 2.0 | 4 votes |
/**
 * A 49-tree random forest regressor trained on hashed-sparse input forwards one row per tree.
 */
@Test
public void testSparse() throws IOException, ParseException, HiveException {
    final String[] featureNames = {"f1", "f2", "f3", "f4", "f5", "f6", "f7", "f8", "f9"};
    final double[][] x = {{234.289, 235.6, 159.0, 107.608, 1947, 60.323},
            {259.426, 232.5, 145.6, 108.632, 1948, 61.122},
            {258.054, 368.2, 161.6, 109.773, 1949, 60.171},
            {284.599, 335.1, 165.0, 110.929, 1950, 61.187},
            {328.975, 209.9, 309.9, 112.075, 1951, 63.221},
            {346.999, 193.2, 359.4, 113.270, 1952, 63.639},
            {365.385, 187.0, 354.7, 115.094, 1953, 64.989},
            {363.112, 357.8, 335.0, 116.219, 1954, 63.761},
            {397.469, 290.4, 304.8, 117.388, 1955, 66.019}};
    final double[] y = {83.0, 88.5, 88.2, 89.5, 96.2, 98.1, 99.0, 100.0, 101.2};

    final RandomForestRegressionUDTF forest = new RandomForestRegressionUDTF();
    final ObjectInspector options = ObjectInspectorUtils.getConstantObjectInspector(
        PrimitiveObjectInspectorFactory.javaStringObjectInspector, "-trees 49");
    forest.initialize(new ObjectInspector[] {
            ObjectInspectorFactory.getStandardListObjectInspector(
                PrimitiveObjectInspectorFactory.javaStringObjectInspector),
            PrimitiveObjectInspectorFactory.javaDoubleObjectInspector, options});

    final List<String> features = new ArrayList<String>(x[0].length);
    for (int r = 0; r < x.length; r++) {
        for (int c = 0; c < x[r].length; c++) {
            features.add(mhash(featureNames[c]) + ":" + x[r][c]);
        }
        forest.process(new Object[] {features, y[r]});
        features.clear();
    }

    final MutableInt trees = new MutableInt(0);
    forest.setCollector(new Collector() {
        public void collect(Object input) throws HiveException {
            trees.addValue(1);
        }
    });
    forest.close();

    Assert.assertEquals(49, trees.getValue());
}
Example #25
Source File: RandomForestRegressionUDTFTest.java From incubator-hivemall with Apache License 2.0 | 4 votes |
/**
 * A 49-tree random forest regressor trained on dense input forwards one row per tree.
 */
@Test
public void testDense() throws IOException, ParseException, HiveException {
    final double[][] x = {{234.289, 235.6, 159.0, 107.608, 1947, 60.323},
            {259.426, 232.5, 145.6, 108.632, 1948, 61.122},
            {258.054, 368.2, 161.6, 109.773, 1949, 60.171},
            {284.599, 335.1, 165.0, 110.929, 1950, 61.187},
            {328.975, 209.9, 309.9, 112.075, 1951, 63.221},
            {346.999, 193.2, 359.4, 113.270, 1952, 63.639},
            {365.385, 187.0, 354.7, 115.094, 1953, 64.989},
            {363.112, 357.8, 335.0, 116.219, 1954, 63.761},
            {397.469, 290.4, 304.8, 117.388, 1955, 66.019}};
    final double[] y = {83.0, 88.5, 88.2, 89.5, 96.2, 98.1, 99.0, 100.0, 101.2};

    final RandomForestRegressionUDTF forest = new RandomForestRegressionUDTF();
    final ObjectInspector options = ObjectInspectorUtils.getConstantObjectInspector(
        PrimitiveObjectInspectorFactory.javaStringObjectInspector, "-trees 49");
    forest.initialize(new ObjectInspector[] {
            ObjectInspectorFactory.getStandardListObjectInspector(
                PrimitiveObjectInspectorFactory.javaDoubleObjectInspector),
            PrimitiveObjectInspectorFactory.javaDoubleObjectInspector, options});

    // The buffer is emptied after every row, so appending keeps values at their column index.
    final List<Double> features = new ArrayList<Double>(x[0].length);
    for (int r = 0; r < x.length; r++) {
        for (double value : x[r]) {
            features.add(value);
        }
        forest.process(new Object[] {features, y[r]});
        features.clear();
    }

    final MutableInt trees = new MutableInt(0);
    forest.setCollector(new Collector() {
        public void collect(Object input) throws HiveException {
            trees.addValue(1);
        }
    });
    forest.close();

    Assert.assertEquals(49, trees.getValue());
}
Example #26
Source File: GradientTreeBoostingClassifierUDTFTest.java From incubator-hivemall with Apache License 2.0 | 4 votes |
/**
 * Trains a 490-tree gradient-boosting classifier on the Iris dataset (fetched over HTTP, class
 * label in column 4), encoding features as "index:value" strings. Asserts one row per tree.
 */
@Test
public void testIrisSparse() throws IOException, ParseException, HiveException {
    URL url = new URL(
        "https://gist.githubusercontent.com/myui/143fa9d05bd6e7db0114/raw/500f178316b802f1cade6e3bf8dc814a96e84b1e/iris.arff");
    final AttributeDataset iris;
    // Close the network stream as soon as the dataset is parsed, even if parsing fails.
    try (InputStream is = new BufferedInputStream(url.openStream())) {
        ArffParser arffParser = new ArffParser();
        arffParser.setResponseIndex(4);
        iris = arffParser.parse(is);
    }
    int size = iris.size();
    double[][] x = iris.toArray(new double[size][]);
    int[] y = iris.toArray(new int[size]);

    GradientTreeBoostingClassifierUDTF udtf = new GradientTreeBoostingClassifierUDTF();
    ObjectInspector param = ObjectInspectorUtils.getConstantObjectInspector(
        PrimitiveObjectInspectorFactory.javaStringObjectInspector, "-trees 490");
    udtf.initialize(new ObjectInspector[] {
            ObjectInspectorFactory.getStandardListObjectInspector(
                PrimitiveObjectInspectorFactory.javaStringObjectInspector),
            PrimitiveObjectInspectorFactory.javaIntObjectInspector, param});

    final List<String> xi = new ArrayList<String>(x[0].length);
    for (int i = 0; i < size; i++) {
        double[] row = x[i];
        for (int j = 0; j < row.length; j++) {
            xi.add(j + ":" + row[j]);
        }
        udtf.process(new Object[] {xi, y[i]});
        xi.clear();
    }

    final MutableInt count = new MutableInt(0);
    Collector collector = new Collector() {
        public void collect(Object input) throws HiveException {
            count.addValue(1);
        }
    };
    udtf.setCollector(collector);
    udtf.close();

    Assert.assertEquals(490, count.getValue());
}
Example #27
Source File: GradientTreeBoostingClassifierUDTFTest.java From incubator-hivemall with Apache License 2.0 | 4 votes |
/**
 * Trains a 490-tree gradient-boosting classifier on the Iris dataset (fetched over HTTP, class
 * label in column 4) using dense double features. Asserts one row per tree.
 */
@Test
public void testIrisDense() throws IOException, ParseException, HiveException {
    URL url = new URL(
        "https://gist.githubusercontent.com/myui/143fa9d05bd6e7db0114/raw/500f178316b802f1cade6e3bf8dc814a96e84b1e/iris.arff");
    final AttributeDataset iris;
    // Close the network stream as soon as the dataset is parsed, even if parsing fails.
    try (InputStream is = new BufferedInputStream(url.openStream())) {
        ArffParser arffParser = new ArffParser();
        arffParser.setResponseIndex(4);
        iris = arffParser.parse(is);
    }
    int size = iris.size();
    double[][] x = iris.toArray(new double[size][]);
    int[] y = iris.toArray(new int[size]);

    GradientTreeBoostingClassifierUDTF udtf = new GradientTreeBoostingClassifierUDTF();
    ObjectInspector param = ObjectInspectorUtils.getConstantObjectInspector(
        PrimitiveObjectInspectorFactory.javaStringObjectInspector, "-trees 490");
    udtf.initialize(new ObjectInspector[] {
            ObjectInspectorFactory.getStandardListObjectInspector(
                PrimitiveObjectInspectorFactory.javaDoubleObjectInspector),
            PrimitiveObjectInspectorFactory.javaIntObjectInspector, param});

    final List<Double> xi = new ArrayList<Double>(x[0].length);
    for (int i = 0; i < size; i++) {
        for (int j = 0; j < x[i].length; j++) {
            xi.add(j, x[i][j]);
        }
        udtf.process(new Object[] {xi, y[i]});
        xi.clear();
    }

    final MutableInt count = new MutableInt(0);
    Collector collector = new Collector() {
        public void collect(Object input) throws HiveException {
            count.addValue(1);
        }
    };
    udtf.setCollector(collector);
    udtf.close();

    Assert.assertEquals(490, count.getValue());
}
Example #28
Source File: RandomForestClassifierUDTFTest.java From incubator-hivemall with Apache License 2.0 | 4 votes |
/**
 * Trains a 10-tree random forest (seed 71) on news20-small binary data. Labels of -1 are mapped
 * to 0, and lines without features are skipped. Asserts one row per tree and an out-of-bag
 * error rate below 0.3.
 */
@Test
public void testNews20BinarySparse() throws IOException, ParseException, HiveException {
    final int numTrees = 10;
    RandomForestClassifierUDTF udtf = new RandomForestClassifierUDTF();
    ObjectInspector param = ObjectInspectorUtils.getConstantObjectInspector(
        PrimitiveObjectInspectorFactory.javaStringObjectInspector, "-seed 71 -trees " + numTrees);
    udtf.initialize(new ObjectInspector[] {
            ObjectInspectorFactory.getStandardListObjectInspector(
                PrimitiveObjectInspectorFactory.javaStringObjectInspector),
            PrimitiveObjectInspectorFactory.javaIntObjectInspector, param});

    // try-with-resources closes the reader even if parsing or process() throws.
    try (BufferedReader news20 = readFile("news20-small.binary.gz")) {
        ArrayList<String> features = new ArrayList<String>();
        String line;
        while ((line = news20.readLine()) != null) {
            StringTokenizer tokens = new StringTokenizer(line, " ");
            int label = Integer.parseInt(tokens.nextToken());
            if (label == -1) {
                label = 0;
            }
            while (tokens.hasMoreTokens()) {
                features.add(tokens.nextToken());
            }
            if (!features.isEmpty()) {
                udtf.process(new Object[] {features, label});
                features.clear();
            }
        }
    }

    final MutableInt count = new MutableInt(0);
    final MutableInt oobErrors = new MutableInt(0);
    final MutableInt oobTests = new MutableInt(0);
    Collector collector = new Collector() {
        public synchronized void collect(Object input) throws HiveException {
            Object[] forward = (Object[]) input;
            oobErrors.addValue(((IntWritable) forward[4]).get());
            oobTests.addValue(((IntWritable) forward[5]).get());
            count.addValue(1);
        }
    };
    udtf.setCollector(collector);
    udtf.close();

    Assert.assertEquals(numTrees, count.getValue());
    // Guard the division below: zero OOB tests would yield NaN and an unhelpful failure.
    Assert.assertTrue("no out-of-bag samples were evaluated", oobTests.getValue() > 0);
    float oobErrorRate = ((float) oobErrors.getValue()) / oobTests.getValue();
    Assert.assertTrue("oob error rate is too high: " + oobErrorRate, oobErrorRate < 0.3);
}
Example #29
Source File: RandomForestClassifierUDTFTest.java From incubator-hivemall with Apache License 2.0 | 4 votes |
@Test public void testNews20MultiClassSparse() throws IOException, ParseException, HiveException { final int numTrees = 10; RandomForestClassifierUDTF udtf = new RandomForestClassifierUDTF(); ObjectInspector param = ObjectInspectorUtils.getConstantObjectInspector( PrimitiveObjectInspectorFactory.javaStringObjectInspector, "-stratified_sampling -seed 71 -trees " + numTrees); udtf.initialize(new ObjectInspector[] { ObjectInspectorFactory.getStandardListObjectInspector( PrimitiveObjectInspectorFactory.javaStringObjectInspector), PrimitiveObjectInspectorFactory.javaIntObjectInspector, param}); BufferedReader news20 = readFile("news20-multiclass.gz"); ArrayList<String> features = new ArrayList<String>(); String line = news20.readLine(); while (line != null) { StringTokenizer tokens = new StringTokenizer(line, " "); int label = Integer.parseInt(tokens.nextToken()); while (tokens.hasMoreTokens()) { features.add(tokens.nextToken()); } Assert.assertFalse(features.isEmpty()); udtf.process(new Object[] {features, label}); features.clear(); line = news20.readLine(); } news20.close(); final MutableInt count = new MutableInt(0); final MutableInt oobErrors = new MutableInt(0); final MutableInt oobTests = new MutableInt(0); Collector collector = new Collector() { public synchronized void collect(Object input) throws HiveException { Object[] forward = (Object[]) input; oobErrors.addValue(((IntWritable) forward[4]).get()); oobTests.addValue(((IntWritable) forward[5]).get()); count.addValue(1); } }; udtf.setCollector(collector); udtf.close(); Assert.assertEquals(numTrees, count.getValue()); float oobErrorRate = ((float) oobErrors.getValue()) / oobTests.getValue(); // TODO why multi-class classification so bad?? Assert.assertTrue("oob error rate is too high: " + oobErrorRate, oobErrorRate < 0.8); }
Example #30
Source File: RandomForestClassifierUDTFTest.java From incubator-hivemall with Apache License 2.0 | 4 votes |
/**
 * Fetches an ARFF dataset from {@code urlString} (class label in column 4) and trains a
 * single-tree random forest (seed 71) on "index:value" sparse features. Returns the tree decoded
 * from the Base91 model text in column 2 of the forwarded model row.
 *
 * @param urlString location of the ARFF file to download
 */
private static DecisionTree.Node getDecisionTreeFromSparseInput(String urlString)
        throws IOException, ParseException, HiveException {
    URL url = new URL(urlString);
    final AttributeDataset iris;
    // Close the network stream as soon as the dataset is parsed, even if parsing fails.
    try (InputStream is = new BufferedInputStream(url.openStream())) {
        ArffParser arffParser = new ArffParser();
        arffParser.setResponseIndex(4);
        iris = arffParser.parse(is);
    }
    int size = iris.size();
    double[][] x = iris.toArray(new double[size][]);
    int[] y = iris.toArray(new int[size]);

    RandomForestClassifierUDTF udtf = new RandomForestClassifierUDTF();
    ObjectInspector param = ObjectInspectorUtils.getConstantObjectInspector(
        PrimitiveObjectInspectorFactory.javaStringObjectInspector, "-trees 1 -seed 71");
    udtf.initialize(new ObjectInspector[] {
            ObjectInspectorFactory.getStandardListObjectInspector(
                PrimitiveObjectInspectorFactory.javaStringObjectInspector),
            PrimitiveObjectInspectorFactory.javaIntObjectInspector, param});

    final List<String> xi = new ArrayList<String>(x[0].length);
    for (int i = 0; i < size; i++) {
        final double[] row = x[i];
        for (int j = 0; j < row.length; j++) {
            xi.add(j + ":" + row[j]);
        }
        udtf.process(new Object[] {xi, y[i]});
        xi.clear();
    }

    final Text[] placeholder = new Text[1];
    Collector collector = new Collector() {
        public void collect(Object input) throws HiveException {
            Object[] forward = (Object[]) input;
            placeholder[0] = (Text) forward[2];
        }
    };
    udtf.setCollector(collector);
    udtf.close();

    Text modelTxt = placeholder[0];
    Assert.assertNotNull(modelTxt);

    byte[] b = Base91.decode(modelTxt.getBytes(), 0, modelTxt.getLength());
    DecisionTree.Node node = DecisionTree.deserialize(b, b.length, true);
    return node;
}