Python apache_beam.Pipeline() Examples
The following are 30 code examples of apache_beam.Pipeline(), collected from open-source projects. The source file, project, and license for each example are noted above its code. You may also want to check out the other available functions and classes of the apache_beam module.
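For orientation, here is a minimal, self-contained sketch of the beam.Pipeline() usage pattern that most of the examples below follow. It is not taken from any of the projects listed; the output path and element values are placeholders.

import apache_beam as beam
from apache_beam.options.pipeline_options import PipelineOptions

def run_minimal_pipeline():
    # Default options; with no runner flag this uses the DirectRunner.
    options = PipelineOptions()
    # The with-block runs the pipeline automatically when the context exits.
    with beam.Pipeline(options=options) as pipeline:
        _ = (
            pipeline
            | 'Create' >> beam.Create(['alpha', 'beta', 'gamma'])
            | 'Upper' >> beam.Map(str.upper)
            | 'Write' >> beam.io.WriteToText('/tmp/output'))  # placeholder path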
Example #1
Source File: streaming_beam.py From python-docs-samples with Apache License 2.0

def run(args, input_subscription, output_table, window_interval):
    """Build and run the pipeline."""
    options = PipelineOptions(args, save_main_session=True, streaming=True)

    with beam.Pipeline(options=options) as pipeline:
        # Read the messages from PubSub and process them.
        messages = (
            pipeline
            | 'Read from Pub/Sub' >> beam.io.ReadFromPubSub(
                subscription=input_subscription).with_output_types(bytes)
            | 'UTF-8 bytes to string' >> beam.Map(lambda msg: msg.decode('utf-8'))
            | 'Parse JSON messages' >> beam.Map(parse_json_message)
            | 'Fixed-size windows' >> beam.WindowInto(
                window.FixedWindows(int(window_interval), 0))
            | 'Add URL keys' >> beam.Map(lambda msg: (msg['url'], msg))
            | 'Group by URLs' >> beam.GroupByKey()
            | 'Get statistics' >> beam.Map(get_statistics))

        # Output the results into BigQuery table.
        _ = messages | 'Write to Big Query' >> beam.io.WriteToBigQuery(
            output_table, schema=SCHEMA)
Example #2
Source File: vcf_to_bq.py From gcp-variant-transforms with Apache License 2.0

def _read_variants(all_patterns,  # type: List[str]
                   pipeline,  # type: beam.Pipeline
                   known_args,  # type: argparse.Namespace
                   pipeline_mode,  # type: int
                   pre_infer_headers=False,  # type: bool
                   keep_raw_sample_names=False
                  ):
    # type: (...) -> pvalue.PCollection
    """Helper method for returning a PCollection of Variants from VCFs."""
    representative_header_lines = None
    if known_args.representative_header_file:
        representative_header_lines = vcf_header_parser.get_metadata_header_lines(
            known_args.representative_header_file)
    return pipeline_common.read_variants(
        pipeline,
        all_patterns,
        pipeline_mode,
        known_args.allow_malformed_records,
        representative_header_lines,
        pre_infer_headers=pre_infer_headers,
        sample_name_encoding=(
            SampleNameEncoding.NONE if keep_raw_sample_names else
            SampleNameEncoding[known_args.sample_name_encoding]),
        use_1_based_coordinate=known_args.use_1_based_coordinate)
Example #3
Source File: pipeline_common.py From gcp-variant-transforms with Apache License 2.0

def read_headers(pipeline,  # type: beam.Pipeline
                 pipeline_mode,  # type: int
                 all_patterns  # type: List[str]
                ):
    # type: (...) -> pvalue.PCollection
    """Creates an initial PCollection by reading the VCF file headers."""
    compression_type = get_compression_type(all_patterns)
    if pipeline_mode == PipelineModes.LARGE:
        headers = (pipeline
                   | beam.Create(all_patterns)
                   | vcf_header_io.ReadAllVcfHeaders(
                       compression_type=compression_type))
    else:
        headers = pipeline | vcf_header_io.ReadVcfHeaders(
            all_patterns[0],
            compression_type=compression_type)
    return headers
Example #4
Source File: beam_utils.py From lingvo with Apache License 2.0

def GetPipelineRoot(options=None):
    """Return the root of the beam pipeline.

    Typical usage looks like:

      with GetPipelineRoot() as root:
        _ = (root | beam.ParDo() | ...)

    In this example, the pipeline is automatically executed when the context
    is exited, though one can manually run the pipeline built from the root
    object as well.

    Args:
      options: A beam.options.pipeline_options.PipelineOptions object.

    Returns:
      A beam.Pipeline root object.
    """
    return beam.Pipeline(options=options)
Example #5
Source File: stats_api_test.py From data-validation with Apache License 2.0

def test_stats_pipeline_with_zero_examples(self):
    expected_result = text_format.Parse(
        """
        datasets {
          num_examples: 0
        }
        """, statistics_pb2.DatasetFeatureStatisticsList())
    with beam.Pipeline() as p:
        options = stats_options.StatsOptions(
            num_top_values=1,
            num_rank_histogram_buckets=1,
            num_values_histogram_buckets=2,
            num_histogram_buckets=1,
            num_quantiles_histogram_buckets=1,
            epsilon=0.001)
        result = (p | beam.Create([]) | stats_api.GenerateStatistics(options))
        util.assert_that(
            result,
            test_util.make_dataset_feature_stats_list_proto_equal_fn(
                self, expected_result))
Example #6
Source File: stats_api_test.py From data-validation with Apache License 2.0

def test_stats_pipeline_with_sample_rate(self):
    record_batches = [
        pa.RecordBatch.from_arrays(
            [pa.array([np.linspace(1, 3000, 3000, dtype=np.int32)])], ['c']),
    ]
    with beam.Pipeline() as p:
        options = stats_options.StatsOptions(
            sample_rate=1.0,
            num_top_values=2,
            num_rank_histogram_buckets=2,
            num_values_histogram_buckets=2,
            num_histogram_buckets=2,
            num_quantiles_histogram_buckets=2,
            epsilon=0.001)
        result = (
            p | beam.Create(record_batches)
            | stats_api.GenerateStatistics(options))
        util.assert_that(
            result,
            test_util.make_dataset_feature_stats_list_proto_equal_fn(
                self, self._sampling_test_expected_result))
Example #7
Source File: stats_api_test.py From data-validation with Apache License 2.0

def test_write_stats_to_text(self):
    stats = text_format.Parse(
        """
        datasets {
          name: 'x'
          num_examples: 100
        }
        """, statistics_pb2.DatasetFeatureStatisticsList())
    output_path = os.path.join(self._get_temp_dir(), 'stats')
    with beam.Pipeline() as p:
        _ = (p | beam.Create([stats])
             | stats_api.WriteStatisticsToText(output_path))
    stats_from_file = statistics_pb2.DatasetFeatureStatisticsList()
    serialized_stats = io_util.read_file_to_string(
        output_path, binary_mode=True)
    stats_from_file.ParseFromString(serialized_stats)
    self.assertLen(stats_from_file.datasets, 1)
    test_util.assert_dataset_feature_stats_proto_equal(
        self, stats_from_file.datasets[0], stats.datasets[0])
Example #8
Source File: stats_impl_test.py From data-validation with Apache License 2.0

def test_stats_impl(self, record_batches, options,
                    expected_result_proto_text, schema=None):
    expected_result = text_format.Parse(
        expected_result_proto_text,
        statistics_pb2.DatasetFeatureStatisticsList())
    if schema is not None:
        options.schema = schema
    with beam.Pipeline() as p:
        result = (
            p | beam.Create(record_batches, reshuffle=False)
            | stats_impl.GenerateStatisticsImpl(options))
        util.assert_that(
            result,
            test_util.make_dataset_feature_stats_list_proto_equal_fn(
                self, expected_result))
Example #9
Source File: csv_decoder_test.py From data-validation with Apache License 2.0

def test_csv_decoder(self,
                     input_lines,
                     expected_result,
                     column_names,
                     delimiter=',',
                     skip_blank_lines=True,
                     schema=None,
                     multivalent_columns=None,
                     secondary_delimiter=None):
    with beam.Pipeline() as p:
        result = (
            p | beam.Create(input_lines, reshuffle=False)
            | csv_decoder.DecodeCSV(
                column_names=column_names,
                delimiter=delimiter,
                skip_blank_lines=skip_blank_lines,
                schema=schema,
                multivalent_columns=multivalent_columns,
                secondary_delimiter=secondary_delimiter))
        util.assert_that(
            result,
            test_util.make_arrow_record_batches_equal_fn(self, expected_result))
Example #10
Source File: _local.py From pydatalab with Apache License 2.0

def preprocess(train_dataset, output_dir, eval_dataset, checkpoint):
    """Preprocess data locally."""
    import apache_beam as beam
    from google.datalab.utils import LambdaJob
    from . import _preprocess

    if checkpoint is None:
        checkpoint = _util._DEFAULT_CHECKPOINT_GSURL

    job_id = ('preprocess-image-classification-' +
              datetime.datetime.now().strftime('%y%m%d-%H%M%S'))

    # Project is needed for bigquery data source, even in local run.
    options = {
        'project': _util.default_project(),
    }
    opts = beam.pipeline.PipelineOptions(flags=[], **options)
    p = beam.Pipeline('DirectRunner', options=opts)
    _preprocess.configure_pipeline(p, train_dataset, eval_dataset, checkpoint,
                                   output_dir, job_id)
    job = LambdaJob(lambda: p.run().wait_until_finish(), job_id)
    return job
Example #11
Source File: executor_test.py From tfx with Apache License 2.0

def testPrestoToExample(self):
    with beam.Pipeline() as pipeline:
        examples = (
            pipeline | 'ToTFExample' >> executor._PrestoToExample(
                exec_properties={
                    'input_config':
                        json_format.MessageToJson(
                            example_gen_pb2.Input(),
                            preserving_proto_field_name=True),
                    'custom_config':
                        json_format.MessageToJson(
                            example_gen_pb2.CustomConfig(),
                            preserving_proto_field_name=True)
                },
                split_pattern='SELECT i, f, s FROM `fake`'))

        feature = {}
        feature['i'] = tf.train.Feature(int64_list=tf.train.Int64List(value=[1]))
        feature['f'] = tf.train.Feature(
            float_list=tf.train.FloatList(value=[2.0]))
        feature['s'] = tf.train.Feature(
            bytes_list=tf.train.BytesList(value=[tf.compat.as_bytes('abc')]))
        example_proto = tf.train.Example(
            features=tf.train.Features(feature=feature))
        util.assert_that(examples, util.equal_to([example_proto]))
Example #12
Source File: datagen_beam.py From magenta with Apache License 2.0

def create_glyphazzn_dataset(filepattern, output_path):
    """Creates a glyphazzn dataset, from raw Parquet data to TFRecords."""

    def pipeline(root):
        """Pipeline for creating glyphazzn dataset."""
        attrs = ['uni', 'width', 'vwidth', 'sfd', 'id', 'binary_fp']
        examples = root | 'Read' >> beam.io.parquetio.ReadFromParquet(
            file_pattern=filepattern, columns=attrs)
        examples = examples | 'FilterBadIcons' >> beam.Filter(_is_valid_glyph)
        examples = examples | 'ConvertToPath' >> beam.Map(_convert_to_path)
        examples = examples | 'FilterBadPathLengths' >> beam.Filter(_is_valid_path)
        examples = examples | 'ProcessAndConvert' >> beam.Map(_create_example)
        (examples | 'WriteToTFRecord' >> beam.io.tfrecordio.WriteToTFRecord(
            output_path, num_shards=90))

    return pipeline
Example #13
Source File: tft_unit.py From transform with Apache License 2.0

def convert_to_tfxio_api_inputs(
        self, legacy_input_data, legacy_input_metadata, label='input_data'):
    """Converts from the legacy TFT API inputs to TFXIO-based inputs.

    Args:
      legacy_input_data: a PCollection of instance dicts.
      legacy_input_metadata: a tft.DatasetMetadata.
      label: label for the PTransform that translates `legacy_input_data` into
        the TFXIO input data. Set to different values if this method is called
        multiple times in a beam Pipeline.

    Returns:
      A tuple of a PCollection of `pyarrow.RecordBatch` and a
      `tensor_adapter.TensorAdapterConfig`. This tuple can be fed directly to
      TFT's `{Analyze,Transform,AnalyzeAndTransform}Dataset` APIs.
    """
    tfxio_impl = _LegacyCompatibilityTFXIO(legacy_input_metadata.schema)
    input_data = (
        legacy_input_data
        | ('LegacyFormatToTfxio[%s]' % label >> tfxio_impl.BeamSource(
            beam_impl.Context.get_desired_batch_size())))
    return input_data, tfxio_impl.TensorAdapterConfig()
Example #14
Source File: impl.py From transform with Apache License 2.0

def _clear_shared_state_after_barrier(pipeline, input_barrier):
    """Clears any shared state from within a pipeline context.

    This will only be cleared once input_barrier becomes available.

    Args:
      pipeline: A `beam.Pipeline` object.
      input_barrier: A `PCollection` which the pipeline should wait for.

    Returns:
      An empty `PCollection`.
    """
    empty_pcoll = input_barrier | 'MakeCheapBarrier' >> beam.FlatMap(
        lambda x: None)
    return (pipeline
            | 'PrepareToClearSharedKeepAlives' >> beam.Create([None])
            | 'WaitAndClearSharedKeepAlives' >> beam.Map(
                lambda x, empty_side_input: shared.Shared().acquire(lambda: None),
                beam.pvalue.AsIter(empty_pcoll)))
Example #15
Source File: transform_fn_io_test.py From transform with Apache License 2.0

def testReadTransformFn(self):
    path = self.get_temp_dir()
    # NOTE: we don't need to create or write to the transform_fn directory
    # since ReadTransformFn never inspects this directory.
    transform_fn_dir = os.path.join(
        path, tft.TFTransformOutput.TRANSFORM_FN_DIR)
    transformed_metadata_dir = os.path.join(
        path, tft.TFTransformOutput.TRANSFORMED_METADATA_DIR)
    metadata_io.write_metadata(test_metadata.COMPLETE_METADATA,
                               transformed_metadata_dir)

    with beam.Pipeline() as pipeline:
        saved_model_dir_pcoll, metadata = (
            pipeline | transform_fn_io.ReadTransformFn(path))
        beam_test_util.assert_that(
            saved_model_dir_pcoll,
            beam_test_util.equal_to([transform_fn_dir]),
            label='AssertSavedModelDir')
        # NOTE: metadata is currently read in a non-deferred manner.
        self.assertEqual(metadata, test_metadata.COMPLETE_METADATA)
Example #16
Source File: transform_fn_io_test.py From transform with Apache License 2.0

def testWriteTransformFnIsIdempotent(self):
    transform_output_dir = os.path.join(self.get_temp_dir(), 'output')

    def mock_write_metadata_expand(unused_self, unused_metadata):
        raise ArithmeticError('Some error')

    with beam.Pipeline() as pipeline:
        # Create an empty directory for the source saved model dir.
        saved_model_dir = os.path.join(self.get_temp_dir(), 'source')
        saved_model_dir_pcoll = (
            pipeline | 'CreateSavedModelDir' >> beam.Create([saved_model_dir]))

        with mock.patch.object(transform_fn_io.beam_metadata_io.WriteMetadata,
                               'expand', mock_write_metadata_expand):
            with self.assertRaisesRegexp(ArithmeticError, 'Some error'):
                _ = ((saved_model_dir_pcoll, object())
                     | transform_fn_io.WriteTransformFn(transform_output_dir))

    self.assertFalse(file_io.file_exists(transform_output_dir))
Example #17
Source File: _local.py From pydatalab with Apache License 2.0

def batch_predict(dataset, model_dir, output_csv, output_bq_table):
    """Batch predict running locally."""
    import apache_beam as beam
    from google.datalab.utils import LambdaJob
    from . import _predictor

    if output_csv is None and output_bq_table is None:
        raise ValueError('output_csv and output_bq_table cannot both be None.')

    job_id = ('batch-predict-image-classification-' +
              datetime.datetime.now().strftime('%y%m%d-%H%M%S'))

    # Project is needed for bigquery data source, even in local run.
    options = {
        'project': _util.default_project(),
    }
    opts = beam.pipeline.PipelineOptions(flags=[], **options)
    p = beam.Pipeline('DirectRunner', options=opts)
    _predictor.configure_pipeline(p, dataset, model_dir, output_csv,
                                  output_bq_table)
    job = LambdaJob(lambda: p.run().wait_until_finish(), job_id)
    return job
Example #18
Source File: executor.py From tfx with Apache License 2.0

def _PrestoToExample(  # pylint: disable=invalid-name
        pipeline: beam.Pipeline,
        exec_properties: Dict[Text, Any],
        split_pattern: Text) -> beam.pvalue.PCollection:
    """Read from Presto and transform to TF examples.

    Args:
      pipeline: beam pipeline.
      exec_properties: A dict of execution properties.
      split_pattern: Split.pattern in Input config, a Presto sql string.

    Returns:
      PCollection of TF examples.
    """
    conn_config = example_gen_pb2.CustomConfig()
    json_format.Parse(exec_properties['custom_config'], conn_config)
    presto_config = presto_config_pb2.PrestoConnConfig()
    conn_config.custom_config.Unpack(presto_config)

    client = _deserialize_conn_config(presto_config)
    return (pipeline
            | 'Query' >> beam.Create([split_pattern])
            | 'QueryTable' >> beam.ParDo(_ReadPrestoDoFn(client))
            | 'ToTFExample' >> beam.Map(_row_to_example))
Example #19
Source File: PubSubToGCS.py From python-docs-samples with Apache License 2.0

def run(input_topic, output_path, window_size=1.0, pipeline_args=None):
    # `save_main_session` is set to true because some DoFn's rely on
    # globally imported modules.
    pipeline_options = PipelineOptions(
        pipeline_args, streaming=True, save_main_session=True
    )

    with beam.Pipeline(options=pipeline_options) as pipeline:
        (
            pipeline
            | "Read PubSub Messages"
            >> beam.io.ReadFromPubSub(topic=input_topic)
            | "Window into" >> GroupWindowsIntoBatches(window_size)
            | "Write to GCS" >> beam.ParDo(WriteBatchesToGCS(output_path))
        )
Example #20
Source File: bigquery_to_gcs_lib.py From healthcare-deid with Apache License 2.0

def run_pipeline(input_query, output_file, pipeline_args):
    p = beam.Pipeline(options=PipelineOptions(pipeline_args))
    _ = (p
         | 'read' >> beam.io.Read(beam.io.BigQuerySource(query=input_query))
         | 'to_physionet' >> beam.Map(map_to_physionet_record)
         | 'write' >> beam.io.WriteToText(output_file))
    result = p.run().wait_until_finish()
    logging.info('BigQuery to GCS result: %s', result)
Example #21
Source File: gcs_to_bigquery_lib.py From healthcare-deid with Apache License 2.0

def run_pipeline(input_pattern, output_table, pipeline_args):
    """Read the records from GCS and write them to BigQuery."""
    p = beam.Pipeline(options=PipelineOptions(pipeline_args))
    _ = (p
         | 'match_files' >> beam.Create(f2pn.match_files(input_pattern))
         | 'to_records' >> beam.FlatMap(map_file_to_records)
         | 'map_to_bq_inputs' >> beam.Map(map_to_bq_inputs)
         | 'write' >> beam.io.Write(beam.io.BigQuerySink(
             output_table,
             schema='patient_id:INTEGER, note:STRING',
             write_disposition=beam.io.BigQueryDisposition.WRITE_TRUNCATE)))
    result = p.run().wait_until_finish()
    logging.info('GCS to BigQuery result: %s', result)
Example #22
Source File: vcf_to_bq.py From gcp-variant-transforms with Apache License 2.0

def _add_inferred_headers(all_patterns,  # type: List[str]
                          pipeline,  # type: beam.Pipeline
                          known_args,  # type: argparse.Namespace
                          merged_header,  # type: pvalue.PCollection
                          pipeline_mode  # type: int
                         ):
    # type: (...) -> pvalue.PCollection
    annotation_fields_to_infer = (known_args.annotation_fields
                                  if known_args.infer_annotation_types else [])
    inferred_headers = (
        _read_variants(all_patterns,
                       pipeline,
                       known_args,
                       pipeline_mode,
                       pre_infer_headers=known_args.infer_headers)
        | 'FilterVariants' >> filter_variants.FilterVariants(
            reference_names=known_args.reference_names)
        | 'InferHeaderFields' >> infer_headers.InferHeaderFields(
            pvalue.AsSingleton(merged_header),
            known_args.allow_incompatible_records,
            known_args.infer_headers,
            annotation_fields_to_infer))
    merged_header = (
        (inferred_headers, merged_header)
        | 'FlattenHeaders' >> beam.Flatten()
        | 'MergeHeadersFromVcfAndVariants' >> merge_headers.MergeHeaders(
            known_args.split_alternate_allele_info_fields,
            known_args.allow_incompatible_records))
    return merged_header
Example #23
Source File: tft_unit.py From transform with Apache License 2.0

def metrics(self):
    if not self.has_ran:
        raise RuntimeError('Pipeline has to run before accessing its metrics')
    return self._run_result.metrics()
Example #24
Source File: sentiment_example.py From transform with Apache License 2.0

def read_and_shuffle_data(
        train_neg_filepattern, train_pos_filepattern, test_neg_filepattern,
        test_pos_filepattern, working_dir):
    """Read and shuffle the data and write out as a TFRecord of Example protos.

    Read in the data from the positive and negative examples on disk, shuffle
    it and write it out in TFRecord format. Transform it using a preprocessing
    pipeline that removes punctuation, tokenizes and maps tokens to int64
    value indices.

    Args:
      train_neg_filepattern: Filepattern for training data negative examples
      train_pos_filepattern: Filepattern for training data positive examples
      test_neg_filepattern: Filepattern for test data negative examples
      test_pos_filepattern: Filepattern for test data positive examples
      working_dir: Directory to write shuffled data to
    """
    with beam.Pipeline() as pipeline:
        coder = tft.coders.ExampleProtoCoder(RAW_DATA_METADATA.schema)

        # pylint: disable=no-value-for-parameter
        _ = (
            pipeline
            | 'ReadAndShuffleTrain' >> ReadAndShuffleData(
                (train_neg_filepattern, train_pos_filepattern))
            | 'EncodeTrainData' >> beam.Map(coder.encode)
            | 'WriteTrainData' >> beam.io.WriteToTFRecord(
                os.path.join(working_dir, SHUFFLED_TRAIN_DATA_FILEBASE)))
        _ = (
            pipeline
            | 'ReadAndShuffleTest' >> ReadAndShuffleData(
                (test_neg_filepattern, test_pos_filepattern))
            | 'EncodeTestData' >> beam.Map(coder.encode)
            | 'WriteTestData' >> beam.io.WriteToTFRecord(
                os.path.join(working_dir, SHUFFLED_TEST_DATA_FILEBASE)))
        # pylint: enable=no-value-for-parameter
Example #25
Source File: staffline_patches_dofn_test.py From moonlight with Apache License 2.0

def testPipeline_corpusImage(self):
    filename = os.path.join(tf.resource_loader.get_data_files_path(),
                            '../../testdata/IMSLP00747-000.png')
    with tempfile.NamedTemporaryFile() as output_examples:
        # Run the pipeline to get the staffline patches.
        with beam.Pipeline() as pipeline:
            dofn = staffline_patches_dofn.StafflinePatchesDoFn(
                PATCH_HEIGHT, PATCH_WIDTH, NUM_STAFFLINES, TIMEOUT_MS,
                MAX_PATCHES_PER_PAGE)
            # pylint: disable=expression-not-assigned
            (pipeline | beam.transforms.Create([filename])
             | beam.transforms.ParDo(dofn)
             | beam.io.WriteToTFRecord(
                 output_examples.name,
                 beam.coders.ProtoCoder(tf.train.Example),
                 shard_name_template=''))
        # Get the staffline images from a local TensorFlow session.
        extractor = staffline_extractor.StafflinePatchExtractor(
            staffline_extractor.DEFAULT_NUM_SECTIONS, PATCH_HEIGHT, PATCH_WIDTH)
        with tf.Session(graph=extractor.graph):
            expected_patches = [
                tuple(patch.ravel())
                for unused_key, patch in extractor.page_patch_iterator(filename)
            ]
        for example_bytes in tf_record.tf_record_iterator(output_examples.name):
            example = tf.train.Example()
            example.ParseFromString(example_bytes)
            patch_pixels = tuple(
                example.features.feature['features'].float_list.value)
            if patch_pixels not in expected_patches:
                self.fail('Missing patch {}'.format(patch_pixels))
Example #26
Source File: pipeline_flags.py From moonlight with Apache License 2.0

def create_pipeline(**kwargs):
    return apache_beam.Pipeline(FLAGS.runner, **kwargs)
Example #27
Source File: onsets_frames_transcription_create_tfrecords.py From magenta with Apache License 2.0

def main(argv):
    del argv

    flags.mark_flags_as_required(['csv', 'output_directory'])

    tf.io.gfile.makedirs(FLAGS.output_directory)

    with tf.io.gfile.GFile(FLAGS.csv) as f:
        reader = csv.DictReader(f)

        splits = collections.defaultdict(list)
        for row in reader:
            splits[row['split']].append(
                (os.path.join(FLAGS.midi_dir, row['midi_filename']),
                 os.path.join(FLAGS.wav_dir, row['audio_filename'])))

    if sorted(splits.keys()) != sorted(FLAGS.expected_splits.split(',')):
        raise ValueError('Got unexpected set of splits: %s' % list(splits.keys()))

    pipeline_options = beam.options.pipeline_options.PipelineOptions(
        FLAGS.pipeline_options)
    with beam.Pipeline(options=pipeline_options) as p:
        for split in splits:
            split_p = p | 'prepare_split_%s' % split >> beam.Create(splits[split])
            split_p |= 'create_examples_%s' % split >> beam.ParDo(
                CreateExampleDoFn(FLAGS.wav_dir, FLAGS.midi_dir,
                                  FLAGS.add_wav_glob))
            split_p |= 'write_%s' % split >> beam.io.WriteToTFRecord(
                os.path.join(FLAGS.output_directory, '%s.tfrecord' % split),
                coder=beam.coders.ProtoCoder(tf.train.Example),
                num_shards=FLAGS.num_shards)
Example #28
Source File: datagen_beam.py From magenta with Apache License 2.0

def main(_):
    pipeline_options = beam.options.pipeline_options.PipelineOptions(
        FLAGS.pipeline_options.split(','))

    pipeline = create_glyphazzn_dataset(
        FLAGS.raw_data_file + '*',
        FLAGS.final_data_file)
    with beam.Pipeline(options=pipeline_options) as root:
        pipeline(root)

    pipeline = get_stats_of_glyphazzn(
        FLAGS.final_data_file + '*',
        FLAGS.final_stats_file)
    with beam.Pipeline(options=pipeline_options) as root:
        pipeline(root)
Example #29
Source File: datagen_beam.py From magenta with Apache License 2.0

def get_stats_of_glyphazzn(filepattern, output_path):
    """Computes the Mean and Std across examples in glyphazzn dataset."""

    def pipeline(root):
        """Pipeline for computing means/std from dataset."""
        examples = root | 'Read' >> beam.io.tfrecordio.ReadFromTFRecord(
            filepattern)
        examples = examples | 'Deserialize' >> beam.Map(_decode_tfexample)
        examples = examples | 'GetMeanStdev' >> beam.CombineGlobally(MeanStddev())
        examples = examples | 'MeanStdevToSerializedTFRecord' >> beam.Map(
            _mean_to_example)
        (examples | 'WriteToTFRecord' >> beam.io.tfrecordio.WriteToTFRecord(
            output_path, coder=beam.coders.ProtoCoder(tf.train.Example)))

    return pipeline
Example #30
Source File: beam_metadata_io_test.py From transform with Apache License 2.0

def testWriteMetadataDeferred(self):
    # Write metadata to disk using WriteMetadata PTransform, combining
    # incomplete metadata with (deferred) complete metadata.
    with beam.Pipeline() as pipeline:
        path = self.get_temp_dir()
        deferred_metadata = pipeline | 'CreateDeferredMetadata' >> beam.Create(
            [test_metadata.COMPLETE_METADATA])
        metadata = beam_metadata_io.BeamDatasetMetadata(
            test_metadata.INCOMPLETE_METADATA, deferred_metadata)
        _ = metadata | beam_metadata_io.WriteMetadata(path, pipeline)

    # Load from disk and check that it is as expected.
    metadata = metadata_io.read_metadata(path)
    self.assertEqual(metadata, test_metadata.COMPLETE_METADATA)