Python apache_beam.Pipeline() Examples

The following are 30 code examples of apache_beam.Pipeline(). You can go to the original project or source file by following the links above each example. You may also want to check out all available functions and classes of the module apache_beam, or try the search function.
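Before the examples, here is a minimal, self-contained sketch of constructing and running a pipeline with the DirectRunner; the elements and step labels are purely illustrative.

import apache_beam as beam
from apache_beam.options.pipeline_options import PipelineOptions

# The pipeline runs when the `with` block exits.
with beam.Pipeline(options=PipelineOptions(['--runner=DirectRunner'])) as pipeline:
    _ = (
        pipeline
        | 'Create' >> beam.Create(['hello', 'apache', 'beam'])
        | 'Uppercase' >> beam.Map(str.upper)
        | 'Print' >> beam.Map(print))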
Example #1
Source File: streaming_beam.py    From python-docs-samples with Apache License 2.0
def run(args, input_subscription, output_table, window_interval):
    """Build and run the pipeline."""
    options = PipelineOptions(args, save_main_session=True, streaming=True)

    with beam.Pipeline(options=options) as pipeline:

        # Read the messages from PubSub and process them.
        messages = (
            pipeline
            | 'Read from Pub/Sub' >> beam.io.ReadFromPubSub(
                subscription=input_subscription).with_output_types(bytes)
            | 'UTF-8 bytes to string' >> beam.Map(lambda msg: msg.decode('utf-8'))
            | 'Parse JSON messages' >> beam.Map(parse_json_message)
            | 'Fixed-size windows' >> beam.WindowInto(
                window.FixedWindows(int(window_interval), 0))
            | 'Add URL keys' >> beam.Map(lambda msg: (msg['url'], msg))
            | 'Group by URLs' >> beam.GroupByKey()
            | 'Get statistics' >> beam.Map(get_statistics))

        # Output the results to a BigQuery table.
        _ = messages | 'Write to Big Query' >> beam.io.WriteToBigQuery(
            output_table, schema=SCHEMA) 
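Example #1 refers to parse_json_message, get_statistics, and SCHEMA, which are defined elsewhere in streaming_beam.py. A hypothetical sketch of what such helpers could look like follows; the field names and BigQuery schema are assumptions, not the sample's actual definitions.

import json
import time

# Hypothetical stand-ins for the helpers the example assumes; field names
# and the BigQuery schema below are illustrative assumptions.
SCHEMA = 'url:STRING,num_reviews:INTEGER,score:FLOAT,first_date:TIMESTAMP,last_date:TIMESTAMP'

def parse_json_message(message):
    """Parse a JSON string into a dict and attach a processing timestamp."""
    row = json.loads(message)
    return {
        'url': row['url'],
        'score': 1.0 if row.get('review') == 'positive' else 0.0,
        'processing_time': int(time.time()),
    }

def get_statistics(url_messages):
    """Aggregate all messages grouped under one URL into a single row."""
    url, messages = url_messages
    messages = list(messages)
    return {
        'url': url,
        'num_reviews': len(messages),
        'score': sum(msg['score'] for msg in messages) / len(messages),
        'first_date': min(msg['processing_time'] for msg in messages),
        'last_date': max(msg['processing_time'] for msg in messages),
    }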
Example #2
Source File: vcf_to_bq.py    From gcp-variant-transforms with Apache License 2.0
def _read_variants(all_patterns,  # type: List[str]
                   pipeline,  # type: beam.Pipeline
                   known_args,  # type: argparse.Namespace
                   pipeline_mode,  # type: int
                   pre_infer_headers=False,  # type: bool
                   keep_raw_sample_names=False
                  ):
  # type: (...) -> pvalue.PCollection
  """Helper method for returning a PCollection of Variants from VCFs."""
  representative_header_lines = None
  if known_args.representative_header_file:
    representative_header_lines = vcf_header_parser.get_metadata_header_lines(
        known_args.representative_header_file)
  return pipeline_common.read_variants(
      pipeline,
      all_patterns,
      pipeline_mode,
      known_args.allow_malformed_records,
      representative_header_lines,
      pre_infer_headers=pre_infer_headers,
      sample_name_encoding=(
          SampleNameEncoding.NONE if keep_raw_sample_names else
          SampleNameEncoding[known_args.sample_name_encoding]),
      use_1_based_coordinate=known_args.use_1_based_coordinate) 
Example #3
Source File: pipeline_common.py    From gcp-variant-transforms with Apache License 2.0
def read_headers(
    pipeline,  #type: beam.Pipeline
    pipeline_mode,  #type: int
    all_patterns  #type: List[str]
    ):
  # type: (...) -> pvalue.PCollection
  """Creates an initial PCollection by reading the VCF file headers."""
  compression_type = get_compression_type(all_patterns)
  if pipeline_mode == PipelineModes.LARGE:
    headers = (pipeline
               | beam.Create(all_patterns)
               | vcf_header_io.ReadAllVcfHeaders(
                   compression_type=compression_type))
  else:
    headers = pipeline | vcf_header_io.ReadVcfHeaders(
        all_patterns[0],
        compression_type=compression_type)

  return headers 
Example #4
Source File: beam_utils.py    From lingvo with Apache License 2.0
def GetPipelineRoot(options=None):
  """Return the root of the beam pipeline.

  Typical usage looks like:

    with GetPipelineRoot() as root:
      _ = (root | beam.ParDo() | ...)

  In this example, the pipeline is automatically executed when the context is
  exited, though one can manually run the pipeline built from the root object as
  well.

  Args:
    options: A beam.options.pipeline_options.PipelineOptions object.

  Returns:
    A beam.Pipeline root object.
  """
  return beam.Pipeline(options=options) 
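A short usage sketch for the helper above, mirroring its docstring; the transform and elements are illustrative.

with GetPipelineRoot() as root:
    _ = (
        root
        | 'Create' >> beam.Create([1, 2, 3])
        | 'Square' >> beam.Map(lambda x: x * x))
# The pipeline executes when the `with` block exits.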
Example #5
Source File: stats_api_test.py    From data-validation with Apache License 2.0
def test_stats_pipeline_with_zero_examples(self):
    expected_result = text_format.Parse(
        """
        datasets {
          num_examples: 0
        }
        """, statistics_pb2.DatasetFeatureStatisticsList())
    with beam.Pipeline() as p:
      options = stats_options.StatsOptions(
          num_top_values=1,
          num_rank_histogram_buckets=1,
          num_values_histogram_buckets=2,
          num_histogram_buckets=1,
          num_quantiles_histogram_buckets=1,
          epsilon=0.001)
      result = (p | beam.Create([]) | stats_api.GenerateStatistics(options))
      util.assert_that(
          result,
          test_util.make_dataset_feature_stats_list_proto_equal_fn(
              self, expected_result)) 
Example #6
Source File: stats_api_test.py    From data-validation with Apache License 2.0
def test_stats_pipeline_with_sample_rate(self):
    record_batches = [
        pa.RecordBatch.from_arrays(
            [pa.array([np.linspace(1, 3000, 3000, dtype=np.int32)])], ['c']),
    ]

    with beam.Pipeline() as p:
      options = stats_options.StatsOptions(
          sample_rate=1.0,
          num_top_values=2,
          num_rank_histogram_buckets=2,
          num_values_histogram_buckets=2,
          num_histogram_buckets=2,
          num_quantiles_histogram_buckets=2,
          epsilon=0.001)
      result = (
          p | beam.Create(record_batches)
          | stats_api.GenerateStatistics(options))
      util.assert_that(
          result,
          test_util.make_dataset_feature_stats_list_proto_equal_fn(
              self, self._sampling_test_expected_result)) 
Example #7
Source File: stats_api_test.py    From data-validation with Apache License 2.0
def test_write_stats_to_text(self):
    stats = text_format.Parse(
        """
        datasets {
          name: 'x'
          num_examples: 100
        }
        """, statistics_pb2.DatasetFeatureStatisticsList())
    output_path = os.path.join(self._get_temp_dir(), 'stats')
    with beam.Pipeline() as p:
      _ = (p | beam.Create([stats]) | stats_api.WriteStatisticsToText(
          output_path))
    stats_from_file = statistics_pb2.DatasetFeatureStatisticsList()
    serialized_stats = io_util.read_file_to_string(
        output_path, binary_mode=True)
    stats_from_file.ParseFromString(serialized_stats)
    self.assertLen(stats_from_file.datasets, 1)
    test_util.assert_dataset_feature_stats_proto_equal(
        self, stats_from_file.datasets[0], stats.datasets[0]) 
Example #8
Source File: stats_impl_test.py    From data-validation with Apache License 2.0
def test_stats_impl(self,
                      record_batches,
                      options,
                      expected_result_proto_text,
                      schema=None):
    expected_result = text_format.Parse(
        expected_result_proto_text,
        statistics_pb2.DatasetFeatureStatisticsList())
    if schema is not None:
      options.schema = schema
    with beam.Pipeline() as p:
      result = (
          p | beam.Create(record_batches, reshuffle=False)
          | stats_impl.GenerateStatisticsImpl(options))
      util.assert_that(
          result,
          test_util.make_dataset_feature_stats_list_proto_equal_fn(
              self, expected_result)) 
Example #9
Source File: csv_decoder_test.py    From data-validation with Apache License 2.0
def test_csv_decoder(self,
                       input_lines,
                       expected_result,
                       column_names,
                       delimiter=',',
                       skip_blank_lines=True,
                       schema=None,
                       multivalent_columns=None,
                       secondary_delimiter=None):
    with beam.Pipeline() as p:
      result = (
          p | beam.Create(input_lines, reshuffle=False)
          | csv_decoder.DecodeCSV(
              column_names=column_names,
              delimiter=delimiter,
              skip_blank_lines=skip_blank_lines,
              schema=schema,
              multivalent_columns=multivalent_columns,
              secondary_delimiter=secondary_delimiter))
      util.assert_that(
          result,
          test_util.make_arrow_record_batches_equal_fn(self, expected_result)) 
Example #10
Source File: _local.py    From pydatalab with Apache License 2.0
def preprocess(train_dataset, output_dir, eval_dataset, checkpoint):
    """Preprocess data locally."""

    import apache_beam as beam
    from google.datalab.utils import LambdaJob
    from . import _preprocess

    if checkpoint is None:
      checkpoint = _util._DEFAULT_CHECKPOINT_GSURL
    job_id = ('preprocess-image-classification-' +
              datetime.datetime.now().strftime('%y%m%d-%H%M%S'))
    # Project is needed for the BigQuery data source, even in a local run.
    options = {
        'project': _util.default_project(),
    }
    opts = beam.pipeline.PipelineOptions(flags=[], **options)
    p = beam.Pipeline('DirectRunner', options=opts)
    _preprocess.configure_pipeline(p, train_dataset, eval_dataset, checkpoint, output_dir, job_id)
    job = LambdaJob(lambda: p.run().wait_until_finish(), job_id)
    return job 
Example #11
Source File: executor_test.py    From tfx with Apache License 2.0
def testPrestoToExample(self):
    with beam.Pipeline() as pipeline:
      examples = (
          pipeline | 'ToTFExample' >> executor._PrestoToExample(
              exec_properties={
                  'input_config':
                      json_format.MessageToJson(
                          example_gen_pb2.Input(),
                          preserving_proto_field_name=True),
                  'custom_config':
                      json_format.MessageToJson(
                          example_gen_pb2.CustomConfig(),
                          preserving_proto_field_name=True)
              },
              split_pattern='SELECT i, f, s FROM `fake`'))

      feature = {}
      feature['i'] = tf.train.Feature(int64_list=tf.train.Int64List(value=[1]))
      feature['f'] = tf.train.Feature(
          float_list=tf.train.FloatList(value=[2.0]))
      feature['s'] = tf.train.Feature(
          bytes_list=tf.train.BytesList(value=[tf.compat.as_bytes('abc')]))
      example_proto = tf.train.Example(
          features=tf.train.Features(feature=feature))
      util.assert_that(examples, util.equal_to([example_proto])) 
Example #12
Source File: datagen_beam.py    From magenta with Apache License 2.0
def create_glyphazzn_dataset(filepattern, output_path):
  """Creates a glyphazzn dataset, from raw Parquetio to TFRecords."""
  def pipeline(root):
    """Pipeline for creating glyphazzn dataset."""
    attrs = ['uni', 'width', 'vwidth', 'sfd', 'id', 'binary_fp']

    examples = root | 'Read' >> beam.io.parquetio.ReadFromParquet(
        file_pattern=filepattern, columns=attrs)

    examples = examples | 'FilterBadIcons' >> beam.Filter(_is_valid_glyph)
    examples = examples | 'ConvertToPath' >> beam.Map(_convert_to_path)
    examples = examples | 'FilterBadPathLengths' >> beam.Filter(_is_valid_path)
    examples = examples | 'ProcessAndConvert' >> beam.Map(_create_example)
    (examples | 'WriteToTFRecord' >> beam.io.tfrecordio.WriteToTFRecord(
        output_path, num_shards=90))
  return pipeline 
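The function above builds and returns a pipeline-constructing callable rather than running anything itself. A hedged usage sketch (the GCS paths are placeholders) applies it to a beam.Pipeline root, as Example #28 below also does.

pipeline_fn = create_glyphazzn_dataset(
    'gs://my-bucket/glyphazzn/raw*',            # placeholder input pattern
    'gs://my-bucket/glyphazzn/train.tfrecord')  # placeholder output path
with beam.Pipeline() as root:
    pipeline_fn(root)  # the pipeline runs when the `with` block exits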
Example #13
Source File: tft_unit.py    From transform with Apache License 2.0
def convert_to_tfxio_api_inputs(
      self, legacy_input_data, legacy_input_metadata, label='input_data'):
    """Converts from the legacy TFT API inputs to TFXIO-based inputs.

    Args:
      legacy_input_data: a PCollection of instance dicts.
      legacy_input_metadata: a tft.DatasetMetadata.
      label: label for the PTransform that translates `legacy_input_data` into
        the TFXIO input data. Set to different values if this method is called
        multiple times in a beam Pipeline.
    Returns:
      A tuple of a PCollection of `pyarrow.RecordBatch` and a
      `tensor_adapter.TensorAdapterConfig`. This tuple can be fed directly to
      TFT's `{Analyze,Transform,AnalyzeAndTransform}Dataset` APIs.
    """
    tfxio_impl = _LegacyCompatibilityTFXIO(legacy_input_metadata.schema)
    input_data = (
        legacy_input_data |
        ('LegacyFormatToTfxio[%s]' % label >> tfxio_impl.BeamSource(
            beam_impl.Context.get_desired_batch_size())))
    return input_data, tfxio_impl.TensorAdapterConfig() 
Example #14
Source File: impl.py    From transform with Apache License 2.0
def _clear_shared_state_after_barrier(pipeline, input_barrier):
  """Clears any shared state from within a pipeline context.

  This will only be cleared once input_barrier becomes available.

  Args:
    pipeline: A `beam.Pipeline` object.
    input_barrier: A `PCollection` which the pipeline should wait for.

  Returns:
    An empty `PCollection`.
  """
  empty_pcoll = input_barrier | 'MakeCheapBarrier' >> beam.FlatMap(
      lambda x: None)
  return (pipeline
          | 'PrepareToClearSharedKeepAlives' >> beam.Create([None])
          | 'WaitAndClearSharedKeepAlives' >> beam.Map(
              lambda x, empty_side_input: shared.Shared().acquire(lambda: None),
              beam.pvalue.AsIter(empty_pcoll))) 
Example #15
Source File: transform_fn_io_test.py    From transform with Apache License 2.0
def testReadTransformFn(self):
    path = self.get_temp_dir()
    # NOTE: we don't need to create or write to the transform_fn directory since
    # ReadTransformFn never inspects this directory.
    transform_fn_dir = os.path.join(
        path, tft.TFTransformOutput.TRANSFORM_FN_DIR)
    transformed_metadata_dir = os.path.join(
        path, tft.TFTransformOutput.TRANSFORMED_METADATA_DIR)
    metadata_io.write_metadata(test_metadata.COMPLETE_METADATA,
                               transformed_metadata_dir)

    with beam.Pipeline() as pipeline:
      saved_model_dir_pcoll, metadata = (
          pipeline | transform_fn_io.ReadTransformFn(path))
      beam_test_util.assert_that(
          saved_model_dir_pcoll,
          beam_test_util.equal_to([transform_fn_dir]),
          label='AssertSavedModelDir')
      # NOTE: metadata is currently read in a non-deferred manner.
      self.assertEqual(metadata, test_metadata.COMPLETE_METADATA) 
Example #16
Source File: transform_fn_io_test.py    From transform with Apache License 2.0
def testWriteTransformFnIsIdempotent(self):
    transform_output_dir = os.path.join(self.get_temp_dir(), 'output')

    def mock_write_metadata_expand(unused_self, unused_metadata):
      raise ArithmeticError('Some error')

    with beam.Pipeline() as pipeline:
      # Create an empty directory for the source saved model dir.
      saved_model_dir = os.path.join(self.get_temp_dir(), 'source')
      saved_model_dir_pcoll = (
          pipeline | 'CreateSavedModelDir' >> beam.Create([saved_model_dir]))

      with mock.patch.object(transform_fn_io.beam_metadata_io.WriteMetadata,
                             'expand', mock_write_metadata_expand):
        with self.assertRaisesRegexp(ArithmeticError, 'Some error'):
          _ = ((saved_model_dir_pcoll, object())
               | transform_fn_io.WriteTransformFn(transform_output_dir))

    self.assertFalse(file_io.file_exists(transform_output_dir)) 
Example #17
Source File: _local.py    From pydatalab with Apache License 2.0
def batch_predict(dataset, model_dir, output_csv, output_bq_table):
    """Batch predict running locally."""

    import apache_beam as beam
    from google.datalab.utils import LambdaJob
    from . import _predictor

    if output_csv is None and output_bq_table is None:
      raise ValueError('output_csv and output_bq_table cannot both be None.')

    job_id = ('batch-predict-image-classification-' +
              datetime.datetime.now().strftime('%y%m%d-%H%M%S'))

    # Project is needed for the BigQuery data source, even in a local run.
    options = {
        'project': _util.default_project(),
    }
    opts = beam.pipeline.PipelineOptions(flags=[], **options)
    p = beam.Pipeline('DirectRunner', options=opts)
    _predictor.configure_pipeline(p, dataset, model_dir, output_csv, output_bq_table)
    job = LambdaJob(lambda: p.run().wait_until_finish(), job_id)
    return job 
Example #18
Source File: executor.py    From tfx with Apache License 2.0
def _PrestoToExample(  # pylint: disable=invalid-name
    pipeline: beam.Pipeline,
    exec_properties: Dict[Text, Any],
    split_pattern: Text) -> beam.pvalue.PCollection:
  """Read from Presto and transform to TF examples.

  Args:
    pipeline: Beam pipeline.
    exec_properties: A dict of execution properties.
    split_pattern: Split.pattern in Input config, a Presto SQL string.

  Returns:
    PCollection of TF examples.
  """
  conn_config = example_gen_pb2.CustomConfig()
  json_format.Parse(exec_properties['custom_config'], conn_config)
  presto_config = presto_config_pb2.PrestoConnConfig()
  conn_config.custom_config.Unpack(presto_config)

  client = _deserialize_conn_config(presto_config)
  return (pipeline
          | 'Query' >> beam.Create([split_pattern])
          | 'QueryTable' >> beam.ParDo(_ReadPrestoDoFn(client))
          | 'ToTFExample' >> beam.Map(_row_to_example)) 
Example #19
Source File: PubSubToGCS.py    From python-docs-samples with Apache License 2.0
def run(input_topic, output_path, window_size=1.0, pipeline_args=None):
    # `save_main_session` is set to true because some DoFns rely on
    # globally imported modules.
    pipeline_options = PipelineOptions(
        pipeline_args, streaming=True, save_main_session=True
    )

    with beam.Pipeline(options=pipeline_options) as pipeline:
        (
            pipeline
            | "Read PubSub Messages"
            >> beam.io.ReadFromPubSub(topic=input_topic)
            | "Window into" >> GroupWindowsIntoBatches(window_size)
            | "Write to GCS" >> beam.ParDo(WriteBatchesToGCS(output_path))
        ) 
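Example #19's GroupWindowsIntoBatches and WriteBatchesToGCS are defined elsewhere in PubSubToGCS.py. One plausible way to wire run() up from the command line is sketched below; the flag names and the window-size unit are assumptions, not necessarily the sample's own.

import argparse

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--input_topic', required=True,
                        help='Pub/Sub topic, e.g. projects/PROJECT/topics/TOPIC')
    parser.add_argument('--output_path', required=True,
                        help='GCS prefix, e.g. gs://BUCKET/samples/output')
    parser.add_argument('--window_size', type=float, default=1.0,
                        help='Window size (assumed here to be in minutes).')
    known_args, pipeline_args = parser.parse_known_args()

    run(known_args.input_topic, known_args.output_path,
        known_args.window_size, pipeline_args)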
Example #20
Source File: bigquery_to_gcs_lib.py    From healthcare-deid with Apache License 2.0
def run_pipeline(input_query, output_file, pipeline_args):
  p = beam.Pipeline(options=PipelineOptions(pipeline_args))
  _ = (p
       | 'read' >> beam.io.Read(beam.io.BigQuerySource(query=input_query))
       | 'to_physionet' >> beam.Map(map_to_physionet_record)
       | 'write' >> beam.io.WriteToText(output_file))
  result = p.run().wait_until_finish()

  logging.info('BigQuery to GCS result: %s', result) 
Example #21
Source File: gcs_to_bigquery_lib.py    From healthcare-deid with Apache License 2.0
def run_pipeline(input_pattern, output_table, pipeline_args):
  """Read the records from GCS and write them to BigQuery."""
  p = beam.Pipeline(options=PipelineOptions(pipeline_args))
  _ = (p |
       'match_files' >> beam.Create(f2pn.match_files(input_pattern)) |
       'to_records' >> beam.FlatMap(map_file_to_records) |
       'map_to_bq_inputs' >> beam.Map(map_to_bq_inputs) |
       'write' >> beam.io.Write(beam.io.BigQuerySink(
           output_table,
           schema='patient_id:INTEGER, note:STRING',
           write_disposition=beam.io.BigQueryDisposition.WRITE_TRUNCATE)))
  result = p.run().wait_until_finish()
  logging.info('GCS to BigQuery result: %s', result) 
Example #22
Source File: vcf_to_bq.py    From gcp-variant-transforms with Apache License 2.0
def _add_inferred_headers(all_patterns,  # type: List[str]
                          pipeline,  # type: beam.Pipeline
                          known_args,  # type: argparse.Namespace
                          merged_header,  # type: pvalue.PCollection
                          pipeline_mode  # type: int
                         ):
  # type: (...) -> pvalue.PCollection
  annotation_fields_to_infer = (known_args.annotation_fields if
                                known_args.infer_annotation_types else [])
  inferred_headers = (
      _read_variants(all_patterns,
                     pipeline,
                     known_args,
                     pipeline_mode,
                     pre_infer_headers=known_args.infer_headers)
      | 'FilterVariants' >> filter_variants.FilterVariants(
          reference_names=known_args.reference_names)
      | 'InferHeaderFields' >> infer_headers.InferHeaderFields(
          pvalue.AsSingleton(merged_header),
          known_args.allow_incompatible_records,
          known_args.infer_headers,
          annotation_fields_to_infer))
  merged_header = (
      (inferred_headers, merged_header)
      | 'FlattenHeaders' >> beam.Flatten()
      | 'MergeHeadersFromVcfAndVariants' >> merge_headers.MergeHeaders(
          known_args.split_alternate_allele_info_fields,
          known_args.allow_incompatible_records))
  return merged_header 
Example #23
Source File: tft_unit.py    From transform with Apache License 2.0
def metrics(self):
      if not self.has_ran:
        raise RuntimeError('Pipeline has to run before accessing its metrics')
      return self._run_result.metrics() 
Example #24
Source File: sentiment_example.py    From transform with Apache License 2.0
def read_and_shuffle_data(
    train_neg_filepattern, train_pos_filepattern, test_neg_filepattern,
    test_pos_filepattern, working_dir):
  """Read and shuffle the data and write out as a TFRecord of Example protos.

  Read in the data from the positive and negative examples on disk, shuffle it,
  and write it out in TFRecord format.

  Args:
    train_neg_filepattern: Filepattern for training data negative examples
    train_pos_filepattern: Filepattern for training data positive examples
    test_neg_filepattern: Filepattern for test data negative examples
    test_pos_filepattern: Filepattern for test data positive examples
    working_dir: Directory to write shuffled data to
  """
  with beam.Pipeline() as pipeline:
    coder = tft.coders.ExampleProtoCoder(RAW_DATA_METADATA.schema)

    # pylint: disable=no-value-for-parameter
    _ = (
        pipeline
        | 'ReadAndShuffleTrain' >> ReadAndShuffleData(
            (train_neg_filepattern, train_pos_filepattern))
        | 'EncodeTrainData' >> beam.Map(coder.encode)
        | 'WriteTrainData' >> beam.io.WriteToTFRecord(
            os.path.join(working_dir, SHUFFLED_TRAIN_DATA_FILEBASE)))

    _ = (
        pipeline
        | 'ReadAndShuffleTest' >> ReadAndShuffleData(
            (test_neg_filepattern, test_pos_filepattern))
        | 'EncodeTestData' >> beam.Map(coder.encode)
        | 'WriteTestData' >> beam.io.WriteToTFRecord(
            os.path.join(working_dir, SHUFFLED_TEST_DATA_FILEBASE)))
    # pylint: enable=no-value-for-parameter 
Example #25
Source File: staffline_patches_dofn_test.py    From moonlight with Apache License 2.0
def testPipeline_corpusImage(self):
    filename = os.path.join(tf.resource_loader.get_data_files_path(),
                            '../../testdata/IMSLP00747-000.png')
    with tempfile.NamedTemporaryFile() as output_examples:
      # Run the pipeline to get the staffline patches.
      with beam.Pipeline() as pipeline:
        dofn = staffline_patches_dofn.StafflinePatchesDoFn(
            PATCH_HEIGHT, PATCH_WIDTH, NUM_STAFFLINES, TIMEOUT_MS,
            MAX_PATCHES_PER_PAGE)
        # pylint: disable=expression-not-assigned
        (pipeline | beam.transforms.Create([filename])
         | beam.transforms.ParDo(dofn) | beam.io.WriteToTFRecord(
             output_examples.name,
             beam.coders.ProtoCoder(tf.train.Example),
             shard_name_template=''))
      # Get the staffline images from a local TensorFlow session.
      extractor = staffline_extractor.StafflinePatchExtractor(
          staffline_extractor.DEFAULT_NUM_SECTIONS, PATCH_HEIGHT, PATCH_WIDTH)
      with tf.Session(graph=extractor.graph):
        expected_patches = [
            tuple(patch.ravel())
            for unused_key, patch in extractor.page_patch_iterator(filename)
        ]
      for example_bytes in tf_record.tf_record_iterator(output_examples.name):
        example = tf.train.Example()
        example.ParseFromString(example_bytes)
        patch_pixels = tuple(
            example.features.feature['features'].float_list.value)
        if patch_pixels not in expected_patches:
          self.fail('Missing patch {}'.format(patch_pixels)) 
Example #26
Source File: pipeline_flags.py    From moonlight with Apache License 2.0
def create_pipeline(**kwargs):
  return apache_beam.Pipeline(FLAGS.runner, **kwargs) 
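A usage sketch for the one-line helper above, assuming the module's FLAGS have already been parsed; the transform is illustrative.

pipeline = create_pipeline(
    options=apache_beam.options.pipeline_options.PipelineOptions())
_ = (pipeline
     | 'Create' >> apache_beam.Create(['a', 'b', 'c'])
     | 'Upper' >> apache_beam.Map(str.upper))
pipeline.run().wait_until_finish()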
Example #27
Source File: onsets_frames_transcription_create_tfrecords.py    From magenta with Apache License 2.0
def main(argv):
  del argv

  flags.mark_flags_as_required(['csv', 'output_directory'])

  tf.io.gfile.makedirs(FLAGS.output_directory)

  with tf.io.gfile.GFile(FLAGS.csv) as f:
    reader = csv.DictReader(f)

    splits = collections.defaultdict(list)
    for row in reader:
      splits[row['split']].append(
          (os.path.join(FLAGS.midi_dir, row['midi_filename']),
           os.path.join(FLAGS.wav_dir, row['audio_filename'])))

  if sorted(splits.keys()) != sorted(FLAGS.expected_splits.split(',')):
    raise ValueError('Got unexpected set of splits: %s' % list(splits.keys()))

  pipeline_options = beam.options.pipeline_options.PipelineOptions(
      FLAGS.pipeline_options)
  with beam.Pipeline(options=pipeline_options) as p:
    for split in splits:
      split_p = p | 'prepare_split_%s' % split >> beam.Create(splits[split])
      split_p |= 'create_examples_%s' % split >> beam.ParDo(
          CreateExampleDoFn(FLAGS.wav_dir, FLAGS.midi_dir, FLAGS.add_wav_glob))
      split_p |= 'write_%s' % split >> beam.io.WriteToTFRecord(
          os.path.join(FLAGS.output_directory, '%s.tfrecord' % split),
          coder=beam.coders.ProtoCoder(tf.train.Example),
          num_shards=FLAGS.num_shards) 
Example #28
Source File: datagen_beam.py    From magenta with Apache License 2.0
def main(_):
  pipeline_options = beam.options.pipeline_options.PipelineOptions(
      FLAGS.pipeline_options.split(','))

  pipeline = create_glyphazzn_dataset(
      FLAGS.raw_data_file + '*',
      FLAGS.final_data_file)
  with beam.Pipeline(options=pipeline_options) as root:
    pipeline(root)

  pipeline = get_stats_of_glyphazzn(
      FLAGS.final_data_file + '*',
      FLAGS.final_stats_file)
  with beam.Pipeline(options=pipeline_options) as root:
    pipeline(root) 
Example #29
Source File: datagen_beam.py    From magenta with Apache License 2.0
def get_stats_of_glyphazzn(filepattern, output_path):
  """Computes the Mean and Std across examples in glyphazzn dataset."""
  def pipeline(root):
    """Pipeline for computing means/std from dataset."""
    examples = root | 'Read' >> beam.io.tfrecordio.ReadFromTFRecord(filepattern)
    examples = examples | 'Deserialize' >> beam.Map(_decode_tfexample)
    examples = examples | 'GetMeanStdev' >> beam.CombineGlobally(MeanStddev())
    examples = examples | 'MeanStdevToSerializedTFRecord' >> beam.Map(
        _mean_to_example)
    (examples | 'WriteToTFRecord' >> beam.io.tfrecordio.WriteToTFRecord(
        output_path, coder=beam.coders.ProtoCoder(tf.train.Example)))
  return pipeline 
Example #30
Source File: beam_metadata_io_test.py    From transform with Apache License 2.0
def testWriteMetadataDeferred(self):
    # Write metadata to disk using WriteMetadata PTransform, combining
    # incomplete metadata with (deferred) complete metadata.
    with beam.Pipeline() as pipeline:
      path = self.get_temp_dir()
      deferred_metadata = pipeline | 'CreateDeferredMetadata' >> beam.Create(
          [test_metadata.COMPLETE_METADATA])
      metadata = beam_metadata_io.BeamDatasetMetadata(
          test_metadata.INCOMPLETE_METADATA, deferred_metadata)
      _ = metadata | beam_metadata_io.WriteMetadata(path, pipeline)

    # Load from disk and check that it is as expected.
    metadata = metadata_io.read_metadata(path)
    self.assertEqual(metadata, test_metadata.COMPLETE_METADATA)