org.datavec.api.writable.IntWritable Java Examples
The following examples show how to use
org.datavec.api.writable.IntWritable.
You can vote up the ones you like or vote down the ones you don't like,
and navigate to the original project or source file by following the links above each example. You may also check out the related API usage examples in the sidebar.
Example #1
Source File: LibSvmRecordWriterTest.java From DataVec with Apache License 2.0 | 6 votes |
@Test(expected = NumberFormatException.class) public void nonBinaryMultilabel() throws Exception { List<Writable> record = Arrays.asList((Writable) new IntWritable(0), new IntWritable(1), new IntWritable(2)); File tempFile = File.createTempFile("LibSvmRecordWriter", ".txt"); tempFile.setWritable(true); tempFile.deleteOnExit(); if (tempFile.exists()) tempFile.delete(); try (LibSvmRecordWriter writer = new LibSvmRecordWriter()) { Configuration configWriter = new Configuration(); configWriter.setInt(LibSvmRecordWriter.FEATURE_FIRST_COLUMN,0); configWriter.setInt(LibSvmRecordWriter.FEATURE_LAST_COLUMN,1); configWriter.setBoolean(LibSvmRecordWriter.MULTILABEL,true); FileSplit outputSplit = new FileSplit(tempFile); writer.initialize(configWriter,outputSplit,new NumberOfRecordsPartitioner()); writer.write(record); } }
Example #2
Source File: TextToCharacterIndexTransform.java From deeplearning4j with Apache License 2.0 | 6 votes |
@Override protected List<List<Writable>> expandTimeStep(List<Writable> currentStepValues) { if(writableMap == null){ Map<Character,List<Writable>> m = new HashMap<>(); for(Map.Entry<Character,Integer> entry : characterIndexMap.entrySet()){ m.put(entry.getKey(), Collections.<Writable>singletonList(new IntWritable(entry.getValue()))); } writableMap = m; } List<List<Writable>> out = new ArrayList<>(); char[] cArr = currentStepValues.get(0).toString().toCharArray(); for( char c : cArr ){ List<Writable> w = writableMap.get(c); if(w == null ){ if(exceptionOnUnknown){ throw new IllegalStateException("Unknown character found in text: \"" + c + "\""); } continue; } out.add(w); } return out; }
Example #3
Source File: ArrowConverterTest.java From DataVec with Apache License 2.0 | 6 votes |
@Test public void testArrowBatchSet() { Schema.Builder schema = new Schema.Builder(); List<String> single = new ArrayList<>(); for(int i = 0; i < 2; i++) { schema.addColumnInteger(String.valueOf(i)); single.add(String.valueOf(i)); } List<List<Writable>> input = Arrays.asList( Arrays.<Writable>asList(new IntWritable(0),new IntWritable(1)), Arrays.<Writable>asList(new IntWritable(2),new IntWritable(3)) ); List<FieldVector> fieldVector = ArrowConverter.toArrowColumns(bufferAllocator,schema.build(),input); ArrowWritableRecordBatch writableRecordBatch = new ArrowWritableRecordBatch(fieldVector,schema.build()); List<Writable> assertion = Arrays.<Writable>asList(new IntWritable(4), new IntWritable(5)); writableRecordBatch.set(1, Arrays.<Writable>asList(new IntWritable(4),new IntWritable(5))); List<Writable> recordTest = writableRecordBatch.get(1); assertEquals(assertion,recordTest); }
Example #4
Source File: TransformProcessRecordReaderTests.java From deeplearning4j with Apache License 2.0 | 6 votes |
@Test public void simpleTransformTestSequence() { List<List<Writable>> sequence = new ArrayList<>(); //First window: sequence.add(Arrays.asList((Writable) new LongWritable(1451606400000L), new IntWritable(0), new IntWritable(0))); sequence.add(Arrays.asList((Writable) new LongWritable(1451606400000L + 100L), new IntWritable(1), new IntWritable(0))); sequence.add(Arrays.asList((Writable) new LongWritable(1451606400000L + 200L), new IntWritable(2), new IntWritable(0))); Schema schema = new SequenceSchema.Builder().addColumnTime("timecolumn", DateTimeZone.UTC) .addColumnInteger("intcolumn").addColumnInteger("intcolumn2").build(); TransformProcess transformProcess = new TransformProcess.Builder(schema).removeColumns("intcolumn2").build(); InMemorySequenceRecordReader inMemorySequenceRecordReader = new InMemorySequenceRecordReader(Arrays.asList(sequence)); TransformProcessSequenceRecordReader transformProcessSequenceRecordReader = new TransformProcessSequenceRecordReader(inMemorySequenceRecordReader, transformProcess); List<List<Writable>> next = transformProcessSequenceRecordReader.sequenceRecord(); assertEquals(2, next.get(0).size()); }
Example #5
Source File: ExcelRecordWriterTest.java From deeplearning4j with Apache License 2.0 | 6 votes |
private Triple<String,Schema,List<List<Writable>>> records() { List<List<Writable>> list = new ArrayList<>(); StringBuilder sb = new StringBuilder(); int numColumns = 3; for (int i = 0; i < 10; i++) { List<Writable> temp = new ArrayList<>(); for (int j = 0; j < numColumns; j++) { int v = 100 * i + j; temp.add(new IntWritable(v)); sb.append(v); if (j < 2) sb.append(","); else if (i != 9) sb.append("\n"); } list.add(temp); } Schema.Builder schemaBuilder = new Schema.Builder(); for(int i = 0; i < numColumns; i++) { schemaBuilder.addColumnInteger(String.valueOf(i)); } return Triple.of(sb.toString(),schemaBuilder.build(),list); }
Example #6
Source File: ExcelRecordWriterTest.java From DataVec with Apache License 2.0 | 6 votes |
private Triple<String,Schema,List<List<Writable>>> records() { List<List<Writable>> list = new ArrayList<>(); StringBuilder sb = new StringBuilder(); int numColumns = 3; for (int i = 0; i < 10; i++) { List<Writable> temp = new ArrayList<>(); for (int j = 0; j < numColumns; j++) { int v = 100 * i + j; temp.add(new IntWritable(v)); sb.append(v); if (j < 2) sb.append(","); else if (i != 9) sb.append("\n"); } list.add(temp); } Schema.Builder schemaBuilder = new Schema.Builder(); for(int i = 0; i < numColumns; i++) { schemaBuilder.addColumnInteger(String.valueOf(i)); } return Triple.of(sb.toString(),schemaBuilder.build(),list); }
Example #7
Source File: CSVRecordReaderTest.java From deeplearning4j with Apache License 2.0 | 6 votes |
@Test(expected = NoSuchElementException.class) public void testCsvSkipAllLines() throws IOException, InterruptedException { final int numLines = 4; final List<Writable> lineList = Arrays.asList((Writable) new IntWritable(numLines - 1), (Writable) new Text("one"), (Writable) new Text("two"), (Writable) new Text("three")); String header = ",one,two,three"; List<String> lines = new ArrayList<>(); for (int i = 0; i < numLines; i++) lines.add(Integer.toString(i) + header); File tempFile = File.createTempFile("csvSkipLines", ".csv"); FileUtils.writeLines(tempFile, lines); CSVRecordReader rr = new CSVRecordReader(numLines, ','); rr.initialize(new FileSplit(tempFile)); rr.reset(); assertTrue(!rr.hasNext()); rr.next(); }
Example #8
Source File: LibSvmRecordWriterTest.java From deeplearning4j with Apache License 2.0 | 6 votes |
@Test(expected = NumberFormatException.class) public void nonIntegerMultilabel() throws Exception { List<Writable> record = Arrays.asList((Writable) new IntWritable(3), new IntWritable(2), new DoubleWritable(1.2)); File tempFile = File.createTempFile("LibSvmRecordWriter", ".txt"); tempFile.setWritable(true); tempFile.deleteOnExit(); if (tempFile.exists()) tempFile.delete(); try (LibSvmRecordWriter writer = new LibSvmRecordWriter()) { Configuration configWriter = new Configuration(); configWriter.setInt(LibSvmRecordWriter.FEATURE_FIRST_COLUMN, 0); configWriter.setInt(LibSvmRecordWriter.FEATURE_LAST_COLUMN, 1); configWriter.setBoolean(LibSvmRecordWriter.MULTILABEL, true); FileSplit outputSplit = new FileSplit(tempFile); writer.initialize(configWriter,outputSplit,new NumberOfRecordsPartitioner()); writer.write(record); } }
Example #9
Source File: RecordMapperTest.java From DataVec with Apache License 2.0 | 6 votes |
private Triple<String,Schema,List<List<Writable>>> records() { List<List<Writable>> list = new ArrayList<>(); StringBuilder sb = new StringBuilder(); int numColumns = 3; for (int i = 0; i < 10; i++) { List<Writable> temp = new ArrayList<>(); for (int j = 0; j < numColumns; j++) { int v = 100 * i + j; temp.add(new IntWritable(v)); sb.append(v); if (j < 2) sb.append(","); else if (i != 9) sb.append("\n"); } list.add(temp); } Schema.Builder schemaBuilder = new Schema.Builder(); for(int i = 0; i < numColumns; i++) { schemaBuilder.addColumnInteger(String.valueOf(i)); } return Triple.of(sb.toString(),schemaBuilder.build(),list); }
Example #10
Source File: TextToTermIndexSequenceTransform.java From deeplearning4j with Apache License 2.0 | 6 votes |
@Override protected List<List<Writable>> expandTimeStep(List<Writable> currentStepValues) { if(writableMap == null){ Map<String,List<Writable>> m = new HashMap<>(); for(Map.Entry<String,Integer> entry : wordIndexMap.entrySet()) { m.put(entry.getKey(), Collections.<Writable>singletonList(new IntWritable(entry.getValue()))); } writableMap = m; } List<List<Writable>> out = new ArrayList<>(); String text = currentStepValues.get(0).toString(); String[] tokens = text.split(this.delimiter); for(String token : tokens ){ List<Writable> w = writableMap.get(token); if(w == null) { if(exceptionOnUnknown) { throw new IllegalStateException("Unknown token found in text: \"" + token + "\""); } continue; } out.add(w); } return out; }
Example #11
Source File: TransformProcessRecordReaderTests.java From DataVec with Apache License 2.0 | 6 votes |
@Test public void simpleTransformTestSequence() { List<List<Writable>> sequence = new ArrayList<>(); //First window: sequence.add(Arrays.asList((Writable) new LongWritable(1451606400000L), new IntWritable(0), new IntWritable(0))); sequence.add(Arrays.asList((Writable) new LongWritable(1451606400000L + 100L), new IntWritable(1), new IntWritable(0))); sequence.add(Arrays.asList((Writable) new LongWritable(1451606400000L + 200L), new IntWritable(2), new IntWritable(0))); Schema schema = new SequenceSchema.Builder().addColumnTime("timecolumn", DateTimeZone.UTC) .addColumnInteger("intcolumn").addColumnInteger("intcolumn2").build(); TransformProcess transformProcess = new TransformProcess.Builder(schema).removeColumns("intcolumn2").build(); InMemorySequenceRecordReader inMemorySequenceRecordReader = new InMemorySequenceRecordReader(Arrays.asList(sequence)); TransformProcessSequenceRecordReader transformProcessSequenceRecordReader = new TransformProcessSequenceRecordReader(inMemorySequenceRecordReader, transformProcess); List<List<Writable>> next = transformProcessSequenceRecordReader.sequenceRecord(); assertEquals(2, next.get(0).size()); }
Example #12
Source File: TestRecordReaders.java From deeplearning4j with Apache License 2.0 | 6 votes |
@Test public void testClassIndexOutsideOfRangeRRMDSI() { Collection<Collection<Collection<Writable>>> c = new ArrayList<>(); Collection<Collection<Writable>> seq1 = new ArrayList<>(); seq1.add(Arrays.<Writable>asList(new DoubleWritable(0.0), new IntWritable(0))); seq1.add(Arrays.<Writable>asList(new DoubleWritable(0.0), new IntWritable(1))); c.add(seq1); Collection<Collection<Writable>> seq2 = new ArrayList<>(); seq2.add(Arrays.<Writable>asList(new DoubleWritable(0.0), new IntWritable(0))); seq2.add(Arrays.<Writable>asList(new DoubleWritable(0.0), new IntWritable(2))); c.add(seq2); CollectionSequenceRecordReader csrr = new CollectionSequenceRecordReader(c); DataSetIterator dsi = new SequenceRecordReaderDataSetIterator(csrr, 2, 2, 1); try { DataSet ds = dsi.next(); fail("Expected exception"); } catch (Exception e) { assertTrue(e.getMessage(), e.getMessage().contains("to one-hot")); } }
Example #13
Source File: LibSvmRecordWriterTest.java From DataVec with Apache License 2.0 | 6 votes |
@Test(expected = NumberFormatException.class) public void nonIntegerMultilabel() throws Exception { List<Writable> record = Arrays.asList((Writable) new IntWritable(3), new IntWritable(2), new DoubleWritable(1.2)); File tempFile = File.createTempFile("LibSvmRecordWriter", ".txt"); tempFile.setWritable(true); tempFile.deleteOnExit(); if (tempFile.exists()) tempFile.delete(); try (LibSvmRecordWriter writer = new LibSvmRecordWriter()) { Configuration configWriter = new Configuration(); configWriter.setInt(LibSvmRecordWriter.FEATURE_FIRST_COLUMN, 0); configWriter.setInt(LibSvmRecordWriter.FEATURE_LAST_COLUMN, 1); configWriter.setBoolean(LibSvmRecordWriter.MULTILABEL, true); FileSplit outputSplit = new FileSplit(tempFile); writer.initialize(configWriter,outputSplit,new NumberOfRecordsPartitioner()); writer.write(record); } }
Example #14
Source File: LibSvmRecordWriterTest.java From DataVec with Apache License 2.0 | 6 votes |
@Test public void testNonIntegerButValidMultilabel() throws Exception { List<Writable> record = Arrays.asList((Writable) new IntWritable(3), new IntWritable(2), new DoubleWritable(1.0)); File tempFile = File.createTempFile("LibSvmRecordWriter", ".txt"); tempFile.setWritable(true); tempFile.deleteOnExit(); if (tempFile.exists()) tempFile.delete(); try (LibSvmRecordWriter writer = new LibSvmRecordWriter()) { Configuration configWriter = new Configuration(); configWriter.setInt(LibSvmRecordWriter.FEATURE_FIRST_COLUMN, 0); configWriter.setInt(LibSvmRecordWriter.FEATURE_LAST_COLUMN, 1); configWriter.setBoolean(LibSvmRecordWriter.MULTILABEL, true); FileSplit outputSplit = new FileSplit(tempFile); writer.initialize(configWriter,outputSplit,new NumberOfRecordsPartitioner()); writer.write(record); } }
Example #15
Source File: LongMetaData.java From deeplearning4j with Apache License 2.0 | 6 votes |
@Override public boolean isValid(Writable writable) { long value; if (writable instanceof IntWritable || writable instanceof LongWritable) { value = writable.toLong(); } else { try { value = Long.parseLong(writable.toString()); } catch (NumberFormatException e) { return false; } } if (minAllowedValue != null && value < minAllowedValue) return false; if (maxAllowedValue != null && value > maxAllowedValue) return false; return true; }
Example #16
Source File: TestFilters.java From deeplearning4j with Apache License 2.0 | 6 votes |
@Test public void testFilterNumColumns() { List<List<Writable>> list = new ArrayList<>(); list.add(Collections.singletonList((Writable) new IntWritable(-1))); list.add(Collections.singletonList((Writable) new IntWritable(0))); list.add(Collections.singletonList((Writable) new IntWritable(2))); Schema schema = new Schema.Builder().addColumnInteger("intCol", 0, 10) //Only values in the range 0 to 10 are ok .addColumnDouble("doubleCol", -100.0, 100.0) //-100 to 100 only; no NaN or infinite .build(); Filter numColumns = new InvalidNumColumns(schema); for (int i = 0; i < list.size(); i++) assertTrue(numColumns.removeExample(list.get(i))); List<Writable> correct = Arrays.<Writable>asList(new IntWritable(0), new DoubleWritable(2)); assertFalse(numColumns.removeExample(correct)); }
Example #17
Source File: RecordReaderDataSetiteratorTest.java From deeplearning4j with Apache License 2.0 | 6 votes |
@Test public void testRecordReaderDataSetIteratorConcat() { //[DoubleWritable, DoubleWritable, NDArrayWritable([1,10]), IntWritable] -> concatenate to a [1,13] feature vector automatically. List<Writable> l = Arrays.<Writable>asList(new DoubleWritable(1), new NDArrayWritable(Nd4j.create(new double[] {2, 3, 4})), new DoubleWritable(5), new NDArrayWritable(Nd4j.create(new double[] {6, 7, 8})), new IntWritable(9), new IntWritable(1)); RecordReader rr = new CollectionRecordReader(Collections.singletonList(l)); DataSetIterator iter = new RecordReaderDataSetIterator(rr, 1, 5, 3); DataSet ds = iter.next(); INDArray expF = Nd4j.create(new float[] {1, 2, 3, 4, 5, 6, 7, 8, 9}, new int[]{1,9}); INDArray expL = Nd4j.create(new float[] {0, 1, 0}, new int[]{1,3}); assertEquals(expF, ds.getFeatures()); assertEquals(expL, ds.getLabels()); }
Example #18
Source File: TextToCharacterIndexTransform.java From DataVec with Apache License 2.0 | 6 votes |
@Override protected List<List<Writable>> expandTimeStep(List<Writable> currentStepValues) { if(writableMap == null){ Map<Character,List<Writable>> m = new HashMap<>(); for(Map.Entry<Character,Integer> entry : characterIndexMap.entrySet()){ m.put(entry.getKey(), Collections.<Writable>singletonList(new IntWritable(entry.getValue()))); } writableMap = m; } List<List<Writable>> out = new ArrayList<>(); char[] cArr = currentStepValues.get(0).toString().toCharArray(); for( char c : cArr ){ List<Writable> w = writableMap.get(c); if(w == null ){ if(exceptionOnUnknown){ throw new IllegalStateException("Unknown character found in text: \"" + c + "\""); } continue; } out.add(w); } return out; }
Example #19
Source File: LongMetaData.java From DataVec with Apache License 2.0 | 6 votes |
@Override public boolean isValid(Writable writable) { long value; if (writable instanceof IntWritable || writable instanceof LongWritable) { value = writable.toLong(); } else { try { value = Long.parseLong(writable.toString()); } catch (NumberFormatException e) { return false; } } if (minAllowedValue != null && value < minAllowedValue) return false; if (maxAllowedValue != null && value > maxAllowedValue) return false; return true; }
Example #20
Source File: TestFilters.java From DataVec with Apache License 2.0 | 6 votes |
@Test public void testFilterNumColumns() { List<List<Writable>> list = new ArrayList<>(); list.add(Collections.singletonList((Writable) new IntWritable(-1))); list.add(Collections.singletonList((Writable) new IntWritable(0))); list.add(Collections.singletonList((Writable) new IntWritable(2))); Schema schema = new Schema.Builder().addColumnInteger("intCol", 0, 10) //Only values in the range 0 to 10 are ok .addColumnDouble("doubleCol", -100.0, 100.0) //-100 to 100 only; no NaN or infinite .build(); Filter numColumns = new InvalidNumColumns(schema); for (int i = 0; i < list.size(); i++) assertTrue(numColumns.removeExample(list.get(i))); List<Writable> correct = Arrays.<Writable>asList(new IntWritable(0), new DoubleWritable(2)); assertFalse(numColumns.removeExample(correct)); }
Example #21
Source File: TestWritablesToNDArrayFunction.java From deeplearning4j with Apache License 2.0 | 6 votes |
@Test public void testWritablesToNDArrayMixed() throws Exception { Nd4j.setDataType(DataType.FLOAT); List<Writable> l = new ArrayList<>(); l.add(new IntWritable(0)); l.add(new IntWritable(1)); INDArray arr = Nd4j.arange(2, 5).reshape(1, 3); l.add(new NDArrayWritable(arr)); l.add(new IntWritable(5)); arr = Nd4j.arange(6, 9).reshape(1, 3); l.add(new NDArrayWritable(arr)); l.add(new IntWritable(9)); INDArray expected = Nd4j.arange(10).castTo(DataType.FLOAT).reshape(1, 10); assertEquals(expected, new WritablesToNDArrayFunction().apply(l)); }
Example #22
Source File: TestRecordReaders.java From deeplearning4j with Apache License 2.0 | 6 votes |
@Test public void testClassIndexOutsideOfRangeRRDSI() { Collection<Collection<Writable>> c = new ArrayList<>(); c.add(Arrays.<Writable>asList(new DoubleWritable(0.5), new IntWritable(0))); c.add(Arrays.<Writable>asList(new DoubleWritable(1.0), new IntWritable(2))); CollectionRecordReader crr = new CollectionRecordReader(c); RecordReaderDataSetIterator iter = new RecordReaderDataSetIterator(crr, 2, 1, 2); try { DataSet ds = iter.next(); fail("Expected exception"); } catch (Exception e) { assertTrue(e.getMessage(), e.getMessage().contains("to one-hot")); } }
Example #23
Source File: LocalTransformProcessRecordReaderTests.java From deeplearning4j with Apache License 2.0 | 6 votes |
@Test public void simpleTransformTestSequence() { List<List<Writable>> sequence = new ArrayList<>(); //First window: sequence.add(Arrays.asList((Writable) new LongWritable(1451606400000L), new IntWritable(0), new IntWritable(0))); sequence.add(Arrays.asList((Writable) new LongWritable(1451606400000L + 100L), new IntWritable(1), new IntWritable(0))); sequence.add(Arrays.asList((Writable) new LongWritable(1451606400000L + 200L), new IntWritable(2), new IntWritable(0))); Schema schema = new SequenceSchema.Builder().addColumnTime("timecolumn", DateTimeZone.UTC) .addColumnInteger("intcolumn").addColumnInteger("intcolumn2").build(); TransformProcess transformProcess = new TransformProcess.Builder(schema).removeColumns("intcolumn2").build(); InMemorySequenceRecordReader inMemorySequenceRecordReader = new InMemorySequenceRecordReader(Arrays.asList(sequence)); LocalTransformProcessSequenceRecordReader transformProcessSequenceRecordReader = new LocalTransformProcessSequenceRecordReader(inMemorySequenceRecordReader, transformProcess); List<List<Writable>> next = transformProcessSequenceRecordReader.sequenceRecord(); assertEquals(2, next.get(0).size()); }
Example #24
Source File: LibSvmRecordWriterTest.java From deeplearning4j with Apache License 2.0 | 6 votes |
@Test public void testNonIntegerButValidMultilabel() throws Exception { List<Writable> record = Arrays.asList((Writable) new IntWritable(3), new IntWritable(2), new DoubleWritable(1.0)); File tempFile = File.createTempFile("LibSvmRecordWriter", ".txt"); tempFile.setWritable(true); tempFile.deleteOnExit(); if (tempFile.exists()) tempFile.delete(); try (LibSvmRecordWriter writer = new LibSvmRecordWriter()) { Configuration configWriter = new Configuration(); configWriter.setInt(LibSvmRecordWriter.FEATURE_FIRST_COLUMN, 0); configWriter.setInt(LibSvmRecordWriter.FEATURE_LAST_COLUMN, 1); configWriter.setBoolean(LibSvmRecordWriter.MULTILABEL, true); FileSplit outputSplit = new FileSplit(tempFile); writer.initialize(configWriter,outputSplit,new NumberOfRecordsPartitioner()); writer.write(record); } }
Example #25
Source File: TestFilters.java From deeplearning4j with Apache License 2.0 | 5 votes |
@Test public void testConditionFilter() { Schema schema = new Schema.Builder().addColumnInteger("column").build(); Condition condition = new IntegerColumnCondition("column", ConditionOp.LessThan, 0); condition.setInputSchema(schema); Filter filter = new ConditionFilter(condition); assertFalse(filter.removeExample(Collections.singletonList((Writable) new IntWritable(10)))); assertFalse(filter.removeExample(Collections.singletonList((Writable) new IntWritable(1)))); assertFalse(filter.removeExample(Collections.singletonList((Writable) new IntWritable(0)))); assertTrue(filter.removeExample(Collections.singletonList((Writable) new IntWritable(-1)))); assertTrue(filter.removeExample(Collections.singletonList((Writable) new IntWritable(-10)))); }
Example #26
Source File: ExecutionTest.java From deeplearning4j with Apache License 2.0 | 5 votes |
@Test(timeout = 60000L) @Ignore("AB 2019/05/21 - Fine locally, timeouts on CI - Issue #7657 and #7771") public void testPythonExecution() throws Exception { Schema schema = new Schema.Builder().addColumnInteger("col0") .addColumnString("col1").addColumnDouble("col2").build(); Schema finalSchema = new Schema.Builder().addColumnInteger("col0") .addColumnInteger("col1").addColumnDouble("col2").build(); String pythonCode = "col1 = ['state0', 'state1', 'state2'].index(col1)\ncol2 += 10.0"; TransformProcess tp = new TransformProcess.Builder(schema).transform( PythonTransform.builder().code( "first = np.sin(first)\nsecond = np.cos(second)") .outputSchema(finalSchema).build() ).build(); List<List<Writable>> inputData = new ArrayList<>(); inputData.add(Arrays.<Writable>asList(new IntWritable(0), new Text("state2"), new DoubleWritable(0.1))); inputData.add(Arrays.<Writable>asList(new IntWritable(1), new Text("state1"), new DoubleWritable(1.1))); inputData.add(Arrays.<Writable>asList(new IntWritable(2), new Text("state0"), new DoubleWritable(2.1))); JavaRDD<List<Writable>> rdd = sc.parallelize(inputData); List<List<Writable>> out = new ArrayList<>(SparkTransformExecutor.execute(rdd, tp).collect()); Collections.sort(out, new Comparator<List<Writable>>() { @Override public int compare(List<Writable> o1, List<Writable> o2) { return Integer.compare(o1.get(0).toInt(), o2.get(0).toInt()); } }); List<List<Writable>> expected = new ArrayList<>(); expected.add(Arrays.<Writable>asList(new IntWritable(0), new IntWritable(2), new DoubleWritable(10.1))); expected.add(Arrays.<Writable>asList(new IntWritable(1), new IntWritable(1), new DoubleWritable(11.1))); expected.add(Arrays.<Writable>asList(new IntWritable(2), new IntWritable(0), new DoubleWritable(12.1))); assertEquals(expected, out); }
Example #27
Source File: ConvertToInteger.java From deeplearning4j with Apache License 2.0 | 5 votes |
@Override public IntWritable map(Writable writable) { if(writable.getType() == WritableType.Int){ return (IntWritable)writable; } return new IntWritable(writable.toInt()); }
Example #28
Source File: SpecialImageRecordReader.java From deeplearning4j with Apache License 2.0 | 5 votes |
@Override public List<Writable> next() { INDArray features = Nd4j.create(channels, height, width); fillNDArray(features, counter.getAndIncrement()); features = features.reshape(1, channels, height, width); List<Writable> ret = RecordConverter.toRecord(features); ret.add(new IntWritable(RandomUtils.nextInt(0, numClasses))); return ret; }
Example #29
Source File: LibSvmRecordReaderTest.java From DataVec with Apache License 2.0 | 5 votes |
@Test public void testBasicRecord() throws IOException, InterruptedException { Map<Integer, List<Writable>> correct = new HashMap<>(); // 7 2:1 4:2 6:3 8:4 10:5 correct.put(0, Arrays.asList(ZERO, ONE, ZERO, new DoubleWritable(2), ZERO, new DoubleWritable(3), ZERO, new DoubleWritable(4), ZERO, new DoubleWritable(5), new IntWritable(7))); // 2 qid:42 1:0.1 2:2 6:6.6 8:80 correct.put(1, Arrays.asList(new DoubleWritable(0.1), new DoubleWritable(2), ZERO, ZERO, ZERO, new DoubleWritable(6.6), ZERO, new DoubleWritable(80), ZERO, ZERO, new IntWritable(2))); // 33 correct.put(2, Arrays.asList(ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, new IntWritable(33))); LibSvmRecordReader rr = new LibSvmRecordReader(); Configuration config = new Configuration(); config.setBoolean(LibSvmRecordReader.ZERO_BASED_INDEXING, false); config.setBoolean(LibSvmRecordReader.APPEND_LABEL, true); config.setInt(LibSvmRecordReader.NUM_FEATURES, 10); rr.initialize(config, new FileSplit(new ClassPathResource("svmlight/basic.txt").getFile())); int i = 0; while (rr.hasNext()) { List<Writable> record = rr.next(); assertEquals(correct.get(i), record); i++; } assertEquals(i, correct.size()); }
Example #30
Source File: ExecutionTest.java From DataVec with Apache License 2.0 | 5 votes |
@Test public void testUniqueMultiCol(){ Schema schema = new Schema.Builder() .addColumnInteger("col0") .addColumnCategorical("col1", "state0", "state1", "state2") .addColumnDouble("col2").build(); List<List<Writable>> inputData = new ArrayList<>(); inputData.add(Arrays.<Writable>asList(new IntWritable(0), new Text("state2"), new DoubleWritable(0.1))); inputData.add(Arrays.<Writable>asList(new IntWritable(1), new Text("state1"), new DoubleWritable(1.1))); inputData.add(Arrays.<Writable>asList(new IntWritable(2), new Text("state0"), new DoubleWritable(2.1))); inputData.add(Arrays.<Writable>asList(new IntWritable(0), new Text("state2"), new DoubleWritable(0.1))); inputData.add(Arrays.<Writable>asList(new IntWritable(1), new Text("state1"), new DoubleWritable(1.1))); inputData.add(Arrays.<Writable>asList(new IntWritable(2), new Text("state0"), new DoubleWritable(2.1))); inputData.add(Arrays.<Writable>asList(new IntWritable(0), new Text("state2"), new DoubleWritable(0.1))); inputData.add(Arrays.<Writable>asList(new IntWritable(1), new Text("state1"), new DoubleWritable(1.1))); inputData.add(Arrays.<Writable>asList(new IntWritable(2), new Text("state0"), new DoubleWritable(2.1))); JavaRDD<List<Writable>> rdd = sc.parallelize(inputData); Map<String,List<Writable>> l = AnalyzeSpark.getUnique(Arrays.asList("col0", "col1"), schema, rdd); assertEquals(2, l.size()); List<Writable> c0 = l.get("col0"); assertEquals(3, c0.size()); assertTrue(c0.contains(new IntWritable(0)) && c0.contains(new IntWritable(1)) && c0.contains(new IntWritable(2))); List<Writable> c1 = l.get("col1"); assertEquals(3, c1.size()); assertTrue(c1.contains(new Text("state0")) && c1.contains(new Text("state1")) && c1.contains(new Text("state2"))); }