Python apache_beam.Map() Examples

The following are 30 code examples of apache_beam.Map(). Each example is taken from an open-source project; the source file and project are noted above it, so you can follow them back to the original code. You may also want to check out all available functions and classes of the apache_beam module, or try the search function.
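Before the project examples, here is a minimal, self-contained sketch of the basic beam.Map() pattern they all build on (this snippet is illustrative and not taken from any of the projects below): each element of a PCollection is passed through a function, and the results form a new PCollection.

import apache_beam as beam

with beam.Pipeline() as pipeline:
  _ = (pipeline
       | 'Create' >> beam.Create([1, 2, 3])
       | 'Square' >> beam.Map(lambda x: x * x)   # 1, 4, 9
       | 'Print' >> beam.Map(print))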
Example #1
Source File: executor.py    From tfx with Apache License 2.0
def _PrestoToExample(  # pylint: disable=invalid-name
    pipeline: beam.Pipeline,
    exec_properties: Dict[Text, Any],
    split_pattern: Text) -> beam.pvalue.PCollection:
  """Read from Presto and transform to TF examples.

  Args:
    pipeline: beam pipeline.
    exec_properties: A dict of execution properties.
    split_pattern: Split.pattern in Input config, a Presto sql string.

  Returns:
    PCollection of TF examples.
  """
  conn_config = example_gen_pb2.CustomConfig()
  json_format.Parse(exec_properties['custom_config'], conn_config)
  presto_config = presto_config_pb2.PrestoConnConfig()
  conn_config.custom_config.Unpack(presto_config)

  client = _deserialize_conn_config(presto_config)
  return (pipeline
          | 'Query' >> beam.Create([split_pattern])
          | 'QueryTable' >> beam.ParDo(_ReadPrestoDoFn(client))
          | 'ToTFExample' >> beam.Map(_row_to_example)) 
Example #2
Source File: executor.py    From tfx with Apache License 2.0
def _DecodeInputs(pcoll: beam.pvalue.PCollection,
                    decode_fn: Any) -> beam.pvalue.PCollection:
    """Decodes the given PCollection while handling KV data.

    Args:
      pcoll: PCollection of data.
      decode_fn: Function used to decode data.

    Returns:
      PCollection of decoded data.
    """

    def decode_example(kv: Tuple[Optional[bytes], bytes]) -> Dict[Text, Any]:  # pylint: disable=invalid-name
      """Decodes a single example."""
      (key, value) = kv
      result = decode_fn(value)
      if _TRANSFORM_INTERNAL_FEATURE_FOR_KEY in result:
        raise ValueError('"{}" is a reserved feature name, '
                         'it should not be present in the dataset.'.format(
                             _TRANSFORM_INTERNAL_FEATURE_FOR_KEY))
      result[_TRANSFORM_INTERNAL_FEATURE_FOR_KEY] = key
      return result

    return pcoll | 'ApplyDecodeFn' >> beam.Map(decode_example) 
Example #3
Source File: avro_executor.py    From tfx with Apache License 2.0
def _AvroToExample(  # pylint: disable=invalid-name
    pipeline: beam.Pipeline, exec_properties: Dict[Text, Any],
    split_pattern: Text) -> beam.pvalue.PCollection:
  """Read Avro files and transform to TF examples.

  Note that each input split will be transformed by this function separately.

  Args:
    pipeline: beam pipeline.
    exec_properties: A dict of execution properties.
      - input_base: input dir that contains Avro data.
    split_pattern: Split.pattern in Input config, glob relative file pattern
      that maps to input files with root directory given by input_base.

  Returns:
    PCollection of TF examples.
  """
  input_base_uri = exec_properties[utils.INPUT_BASE_KEY]
  avro_pattern = os.path.join(input_base_uri, split_pattern)
  logging.info('Processing input avro data %s to TFExample.', avro_pattern)

  return (pipeline
          | 'ReadFromAvro' >> beam.io.ReadFromAvro(avro_pattern)
          | 'ToTFExample' >> beam.Map(utils.dict_to_example)) 
Example #4
Source File: executor.py    From tfx with Apache License 2.0
def _ReadExamples(
      pipeline: beam.Pipeline, dataset: _Dataset,
      input_dataset_metadata: dataset_metadata.DatasetMetadata
  ) -> beam.pvalue.PCollection:
    """Reads examples from the given `dataset`.

    Args:
      pipeline: beam pipeline.
      dataset: A `_Dataset` object that represents the data to read.
      input_dataset_metadata: A `dataset_metadata.DatasetMetadata`. Not used.

    Returns:
      A PCollection containing KV pairs of bytes.
    """
    del input_dataset_metadata
    assert dataset.file_format == labels.FORMAT_TFRECORD, dataset.file_format

    return (
        pipeline
        | 'Read' >> beam.io.ReadFromTFRecord(
            dataset.file_pattern,
            coder=beam.coders.BytesCoder(),
            # TODO(b/114938612): Eventually remove this override.
            validate=False)
        | 'AddKey' >> beam.Map(lambda x: (None, x))) 
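For readers unfamiliar with the (key, value) convention used here, this standalone sketch (not from tfx) shows the shape produced by the 'AddKey' step: every payload becomes a (None, bytes) pair, which a downstream decoder like the one in Example #2 can unpack.

import apache_beam as beam

with beam.Pipeline() as p:
  _ = (p
       | 'Create' >> beam.Create([b'record-1', b'record-2'])
       | 'AddKey' >> beam.Map(lambda x: (None, x))
       | 'Print' >> beam.Map(print))  # (None, b'record-1'), (None, b'record-2')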
Example #5
Source File: executor_test.py    From tfx with Apache License 2.0
def testImportExample(self):
    with beam.Pipeline() as pipeline:
      examples = (
          pipeline
          | 'ToSerializedRecord' >> executor._ImportSerializedRecord(
              exec_properties={utils.INPUT_BASE_KEY: self._input_data_dir},
              split_pattern='tfrecord/*')
          | 'ToTFExample' >> beam.Map(tf.train.Example.FromString))

      def check_result(got):
        # We use Python assertion here to avoid Beam serialization error in
        # pickling tf.test.TestCase.
        assert (15000 == len(got)), 'Unexpected example count'
        assert (18 == len(got[0].features.feature)), 'Example does not match'

      util.assert_that(examples, check_result) 
Example #6
Source File: preprocess.py    From cloudml-samples with Apache License 2.0
def configure_pipeline(p, opt):
  """Specify PCollection and transformations in pipeline."""
  read_input_source = beam.io.ReadFromText(
      opt.input_path, strip_trailing_newlines=True)
  read_label_source = beam.io.ReadFromText(
      opt.input_dict, strip_trailing_newlines=True)
  labels = (p | 'Read dictionary' >> read_label_source)
  _ = (p
       | 'Read input' >> read_input_source
       | 'Parse input' >> beam.Map(lambda line: csv.reader([line]).next())
       | 'Extract label ids' >> beam.ParDo(ExtractLabelIdsDoFn(),
                                           beam.pvalue.AsIter(labels))
       | 'Read and convert to JPEG'
       >> beam.ParDo(ReadImageAndConvertToJpegDoFn())
       | 'Embed and make TFExample' >> beam.ParDo(TFExampleFromImageDoFn())
       # TODO(b/35133536): Get rid of this Map and instead use
       # coder=beam.coders.ProtoCoder(tf.train.Example) in WriteToTFRecord
       # below.
       | 'SerializeToString' >> beam.Map(lambda x: x.SerializeToString())
       | 'Save to disk'
       >> beam.io.WriteToTFRecord(opt.output_path,
                                  file_name_suffix='.tfrecord.gz')) 
Example #7
Source File: run_pipeline_lib.py    From healthcare-deid with Apache License 2.0
def compare_bq_row(row, types_to_ignore):
  """Compare the findings in the given BigQuery row.

  Args:
    row: BQ row: Map containing (findings_record_id, findings_xml, golden_xml).
    types_to_ignore: List of strings representing types that should be excluded
      from the analysis.
  Returns:
    (IndividualResult, IndividualResult), where the first is for strict entity
    matching and the second is for binary token matching.
  Raises:
    Exception: If golden_xml doesn't exist.
  """
  findings, note_text = get_findings_from_text(row['findings_xml'],
                                               types_to_ignore)
  if 'golden_xml' not in row or row['golden_xml'] is None:
    raise Exception(
        'No golden found for record %s.' % row['findings_record_id'])
  golden_findings, golden_note_text = get_findings_from_text(row['golden_xml'],
                                                             types_to_ignore)
  record_id = row['findings_record_id']

  return compare_findings(findings, golden_findings, record_id, note_text,
                          golden_note_text) 
Example #8
Source File: preprocess.py    From cloudml-samples with Apache License 2.0
def expand(self, pvalue):
    if self._handle.endswith('.csv'):
      # The input is CSV file(s).
      schema = reddit.make_input_schema(mode=self._mode)
      csv_coder = reddit.make_csv_coder(schema, mode=self._mode)
      return (pvalue.pipeline
              | 'ReadFromText' >> beam.io.ReadFromText(
                  self._handle,
                  # TODO(b/35653662): Obviate the need for setting this.
                  coder=beam.coders.BytesCoder())
              | 'ParseCSV' >> beam.Map(csv_coder.decode))
    else:
      # The input is BigQuery table name(s).
      query = reddit.make_standard_sql(self._handle, mode=self._mode)
      return (pvalue.pipeline
              | 'ReadFromBigQuery' >> beam.io.Read(
                  beam.io.BigQuerySource(query=query, use_standard_sql=True)))


# TODO: Perhaps use Reshuffle (https://issues.apache.org/jira/browse/BEAM-1872)? 
Example #9
Source File: streaming_beam.py    From python-docs-samples with Apache License 2.0
def run(args, input_subscription, output_table, window_interval):
    """Build and run the pipeline."""
    options = PipelineOptions(args, save_main_session=True, streaming=True)

    with beam.Pipeline(options=options) as pipeline:

        # Read the messages from PubSub and process them.
        messages = (
            pipeline
            | 'Read from Pub/Sub' >> beam.io.ReadFromPubSub(
                subscription=input_subscription).with_output_types(bytes)
            | 'UTF-8 bytes to string' >> beam.Map(lambda msg: msg.decode('utf-8'))
            | 'Parse JSON messages' >> beam.Map(parse_json_message)
            | 'Fixed-size windows' >> beam.WindowInto(
                window.FixedWindows(int(window_interval), 0))
            | 'Add URL keys' >> beam.Map(lambda msg: (msg['url'], msg))
            | 'Group by URLs' >> beam.GroupByKey()
            | 'Get statistics' >> beam.Map(get_statistics))

        # Output the results into BigQuery table.
        _ = messages | 'Write to Big Query' >> beam.io.WriteToBigQuery(
            output_table, schema=SCHEMA) 
Example #10
Source File: batch_util.py    From data-validation with Apache License 2.0
def BatchExamplesToArrowRecordBatches(
    examples: beam.pvalue.PCollection,
    desired_batch_size: Optional[int] = constants
    .DEFAULT_DESIRED_INPUT_BATCH_SIZE
) -> beam.pvalue.PCollection:
  """Batches example dicts into Arrow record batches.

  Args:
    examples: A PCollection of example dicts.
    desired_batch_size: Batch size. The output Arrow record batches will have as
      many rows as the `desired_batch_size`.

  Returns:
    A PCollection of Arrow record batches.
  """
  return (
      examples
      | "BatchBeamExamples" >> beam.BatchElements(
          **batch_util.GetBatchElementsKwargs(desired_batch_size))
      | "DecodeExamplesToRecordBatch" >> beam.Map(
          # pylint: disable=unnecessary-lambda
          lambda x: decoded_examples_to_arrow.DecodedExamplesToRecordBatch(x)))
          # pylint: enable=unnecessary-lambda 
Example #11
Source File: PubSubToGCS.py    From python-docs-samples with Apache License 2.0
def expand(self, pcoll):
        return (
            pcoll
            # Assigns window info to each Pub/Sub message based on its
            # publish timestamp.
            | "Window into Fixed Intervals"
            >> beam.WindowInto(window.FixedWindows(self.window_size))
            | "Add timestamps to messages" >> beam.ParDo(AddTimestamps())
            # Use a dummy key to group the elements in the same window.
            # Note that all the elements in one window must fit into memory
            # for this. If the windowed elements do not fit into memory,
            # please consider using `beam.util.BatchElements`.
            # https://beam.apache.org/releases/pydoc/current/apache_beam.transforms.util.html#apache_beam.transforms.util.BatchElements
            | "Add Dummy Key" >> beam.Map(lambda elem: (None, elem))
            | "Groupby" >> beam.GroupByKey()
            | "Abandon Dummy Key" >> beam.MapTuple(lambda _, val: val)
        ) 
Example #12
Source File: impl.py    From transform with Apache License 2.0
def expand(self, pipeline):

    def _make_and_increment_counters(unused_element, analyzer_counter,
                                     mapper_counter):
      del unused_element
      for counter_prefix, counter in (('tft_analyzer_{}', analyzer_counter),
                                      ('tft_mapper_{}', mapper_counter)):
        for name, count in counter.items():
          beam.metrics.Metrics.counter(beam_common.METRICS_NAMESPACE,
                                       counter_prefix.format(name)).inc(count)

    _ = (
        pipeline
        | 'CreateSoleAPIUse' >> beam.Create([None])
        | 'CountAPIUse' >>
        beam.Map(_make_and_increment_counters, self._analyzer_use_counter,
                 self._mapper_use_counter)) 
Example #13
Source File: _util.py    From pydatalab with Apache License 2.0
def get_sources_from_dataset(p, dataset, mode):
  """get pcollection from dataset."""

  import apache_beam as beam
  import csv
  from google.datalab.ml import CsvDataSet, BigQueryDataSet

  check_dataset(dataset, mode)
  if type(dataset) is CsvDataSet:
    source_list = []
    for ii, input_path in enumerate(dataset.files):
      source_list.append(p | 'Read from Csv %d (%s)' % (ii, mode) >>
                         beam.io.ReadFromText(input_path, strip_trailing_newlines=True))
    return (source_list |
            'Flatten Sources (%s)' % mode >>
            beam.Flatten() |
            'Create Dict from Csv (%s)' % mode >>
            beam.Map(lambda line: csv.DictReader([line], fieldnames=['image_url',
                                                                     'label']).next()))
  elif type(dataset) is BigQueryDataSet:
    bq_source = (beam.io.BigQuerySource(table=dataset.table) if dataset.table is not None else
                 beam.io.BigQuerySource(query=dataset.query))
    return p | 'Read source from BigQuery (%s)' % mode >> beam.io.Read(bq_source)
  else:
    raise ValueError('Invalid DataSet. Expect CsvDataSet or BigQueryDataSet') 
Example #14
Source File: datagen_beam.py    From magenta with Apache License 2.0
def create_glyphazzn_dataset(filepattern, output_path):
  """Creates a glyphazzn dataset, from raw Parquetio to TFRecords."""
  def pipeline(root):
    """Pipeline for creating glyphazzn dataset."""
    attrs = ['uni', 'width', 'vwidth', 'sfd', 'id', 'binary_fp']

    examples = root | 'Read' >> beam.io.parquetio.ReadFromParquet(
        file_pattern=filepattern, columns=attrs)

    examples = examples | 'FilterBadIcons' >> beam.Filter(_is_valid_glyph)
    examples = examples | 'ConvertToPath' >> beam.Map(_convert_to_path)
    examples = examples | 'FilterBadPathLengths' >> beam.Filter(_is_valid_path)
    examples = examples | 'ProcessAndConvert' >> beam.Map(_create_example)
    (examples | 'WriteToTFRecord' >> beam.io.tfrecordio.WriteToTFRecord(
        output_path, num_shards=90))
  return pipeline 
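A hypothetical usage sketch for the builder returned above; the file paths and the use of the default runner are assumptions, not part of the magenta sample.

import apache_beam as beam

pipeline_fn = create_glyphazzn_dataset('/tmp/glyphazzn/*.parquet',
                                       '/tmp/glyphazzn/tfrecords/data')
with beam.Pipeline() as root:
  pipeline_fn(root)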
Example #15
Source File: gcs_to_bigquery_lib.py    From healthcare-deid with Apache License 2.0
def run_pipeline(input_pattern, output_table, pipeline_args):
  """Read the records from GCS and write them to BigQuery."""
  p = beam.Pipeline(options=PipelineOptions(pipeline_args))
  _ = (p |
       'match_files' >> beam.Create(f2pn.match_files(input_pattern)) |
       'to_records' >> beam.FlatMap(f2pn.map_file_to_records) |
       'parse_physionet_record' >> beam.Map(f2pn.parse_physionet_record) |
       'write' >> beam.io.Write(beam.io.BigQuerySink(
           output_table,
           schema='patient_id:INTEGER, record_number:INTEGER, note:STRING',
           write_disposition=beam.io.BigQueryDisposition.WRITE_TRUNCATE)))
  result = p.run().wait_until_finish()
  logging.info('GCS to BigQuery result: %s', result) 
Example #16
Source File: tft_benchmark_base.py    From tfx with Apache License 2.0
def expand(self, pipeline):
    # TODO(b/147620802): Consider making this (and other parameters)
    # configurable to test more variants (e.g. with and without deep-copy
    # optimisation, with and without cache, etc).
    with tft_beam.Context(temp_dir=tempfile.mkdtemp()):
      converter = tft.coders.ExampleProtoCoder(
          self._tf_metadata_schema, serialized=False)
      raw_data = (
          pipeline
          | "ReadDataset" >> beam.Create(self._dataset.read_raw_dataset())
          | "Decode" >> beam.Map(converter.decode))
      transform_fn, output_metadata = (
          (raw_data, self._transform_input_dataset_metadata)
          | "AnalyzeDataset" >> tft_beam.AnalyzeDataset(self._preprocessing_fn))

      if self._generate_dataset:
        _ = transform_fn | "CopySavedModel" >> _CopySavedModel(
            dest_path=self._dataset.tft_saved_model_path())

      (transformed_dataset, transformed_metadata) = (
          ((raw_data, self._transform_input_dataset_metadata),
           (transform_fn, output_metadata))
          | "TransformDataset" >> tft_beam.TransformDataset())
      return transformed_dataset, transformed_metadata


# Tuple for variables common to all benchmarks. 
Example #17
Source File: physionet_to_mae_lib.py    From healthcare-deid with Apache License 2.0
def run_pipeline(input_pattern, output_dir, mae_task_name, project,
                 pipeline_args):
  """Read the physionet records from GCS and write them out as MAE."""
  p = beam.Pipeline(options=PipelineOptions(pipeline_args))
  _ = (p |
       'match_files' >> beam.Create(f2pn.match_files(input_pattern)) |
       'to_records' >> beam.FlatMap(f2pn.map_phi_to_findings) |
       'generate_mae' >> beam.Map(mae.generate_mae, mae_task_name, {},
                                  ['patient_id', 'record_number']) |
       'write_mae' >> beam.Map(write_mae, project, output_dir)
      )
  result = p.run().wait_until_finish()
  logging.info('GCS to BigQuery result: %s', result) 
Example #18
Source File: bigquery_to_gcs_lib.py    From healthcare-deid with Apache License 2.0
def run_pipeline(input_query, output_file, pipeline_args):
  p = beam.Pipeline(options=PipelineOptions(pipeline_args))
  _ = (p
       | 'read' >> beam.io.Read(beam.io.BigQuerySource(query=input_query))
       | 'to_physionet' >> beam.Map(map_to_physionet_record)
       | 'write' >> beam.io.WriteToText(output_file))
  result = p.run().wait_until_finish()

  logging.info('BigQuery to GCS result: %s', result) 
Example #19
Source File: parquet_executor.py    From tfx with Apache License 2.0
def _ParquetToExample(  # pylint: disable=invalid-name
    pipeline: beam.Pipeline, exec_properties: Dict[Text, Any],
    split_pattern: Text) -> beam.pvalue.PCollection:
  """Read Parquet files and transform to TF examples.

  Note that each input split will be transformed by this function separately.

  Args:
    pipeline: beam pipeline.
    exec_properties: A dict of execution properties.
      - input_base: input dir that contains Parquet data.
    split_pattern: Split.pattern in Input config, glob relative file pattern
      that maps to input files with root directory given by input_base.

  Returns:
    PCollection of TF examples.
  """
  input_base_uri = exec_properties[utils.INPUT_BASE_KEY]
  parquet_pattern = os.path.join(input_base_uri, split_pattern)
  logging.info('Processing input parquet data %s to TFExample.',
               parquet_pattern)

  return (pipeline
          # TODO(jyzhao): support per column read by input_config.
          | 'ReadFromParquet' >> beam.io.ReadFromParquet(parquet_pattern)
          | 'ToTFExample' >> beam.Map(utils.dict_to_example)) 
Example #20
Source File: impl.py    From transform with Apache License 2.0
def expand(self, inputs):
    pcoll, = inputs

    def extract_keys(input_dict, keys):
      return (tuple(input_dict[k] for k in keys)
              if isinstance(keys, tuple) else input_dict[keys])

    return pcoll | 'ExtractKeys' >> beam.Map(extract_keys, keys=self._keys) 
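The 'ExtractKeys' step above passes keys=self._keys as a side argument to beam.Map, which forwards extra positional and keyword arguments to the mapped function after each element. A tiny self-contained illustration of that calling convention (not tf.Transform's code):

import apache_beam as beam

with beam.Pipeline() as p:
  _ = (p
       | 'Create' >> beam.Create([{'a': 1, 'b': 2}])
       | 'ExtractKeys' >> beam.Map(lambda d, keys: tuple(d[k] for k in keys),
                                   keys=('a', 'b'))
       | 'Print' >> beam.Map(print))  # prints (1, 2)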
Example #21
Source File: impl.py    From transform with Apache License 2.0
def expand(self, pbegin):
    # TODO(b/151921205): we have to do an identity map for unmodified
    # PCollections below because otherwise we get an error from beam.
    identity_map = 'Identity' >> beam.Map(lambda x: x)
    if self._dataset_key.is_flattened_dataset_key():
      if self._flat_pcollection:
        return self._flat_pcollection | identity_map
      else:
        return (
            list(self._pcollection_dict.values())
            | 'FlattenAnalysisInputs' >> beam.Flatten(pipeline=pbegin.pipeline))
    else:
      return self._pcollection_dict[self._dataset_key] | identity_map 
Example #22
Source File: impl.py    From transform with Apache License 2.0
def expand(self, inputs):
    pipeline = (inputs[0] if isinstance(inputs, tuple) else inputs).pipeline
    saved_model_dir_pcoll = pipeline | 'CreateSavedModel' >> beam.Create(
        [self._unbound_saved_model_dir])

    if isinstance(inputs, beam.pvalue.PBegin):
      return saved_model_dir_pcoll

    return saved_model_dir_pcoll | 'ReplaceWithConstants' >> beam.Map(
        _replace_tensors_with_constant_values, self._base_temp_dir,
        *[beam.pvalue.AsSingleton(pcoll) for pcoll in inputs]) 
Example #23
Source File: impl.py    From transform with Apache License 2.0
def expand(self, inputs):
    pcoll, = inputs
    return pcoll | 'ToTensorBinding' >> beam.Map(_TensorBinding, self._tensor,
                                                 self._is_asset_file) 
Example #24
Source File: analyzer_impls.py    From transform with Apache License 2.0
def _MutualInformationTransformMerge(  # pylint: disable=invalid-name
    pcol, use_adjusted_mutual_info, min_diff_from_avg):
  """Computes mutual information for each key using the given accumulators."""
  feature_accumulator_pcol = (
      pcol | 'VocabCountPerLabelPerTokenMerge' >> beam.CombinePerKey(
          _WeightedMeanCombineFn(output_shape=(None,))))

  accumulators_by_feature, global_accumulator = (
      feature_accumulator_pcol
      | 'ExtractSentinels' >> beam.FlatMap(_extract_sentinels).with_outputs(
          'feature', 'global'))
  if min_diff_from_avg is None:
    min_diff_from_avg = (
        global_accumulator | 'AutoMinDiffFromAvg' >>
        beam.Map(lambda acc: analyzers.calculate_recommended_min_diff_from_avg(  # pylint: disable=g-long-lambda
            acc.count * acc.weight)))
    min_diff_from_avg = beam.pvalue.AsSingleton(min_diff_from_avg)

  def _extract_merged_values(term, results):
    """Returns the key and tuple of (mutual information, frequency)."""
    # Ignore the second value, which is the Expected Mutual Info.
    (mi, _, frequency) = results
    return term, (mi, frequency)

  return (accumulators_by_feature
          | 'CalculateMutualInformationPerToken' >> beam.Map(
              _calculate_mutual_information_for_feature_value,
              beam.pvalue.AsSingleton(global_accumulator),
              use_adjusted_mutual_info=use_adjusted_mutual_info,
              min_diff_from_avg=min_diff_from_avg)
          | beam.MapTuple(_extract_merged_values)) 
Example #25
Source File: common.py    From transform with Apache License 2.0
def expand(self, pcoll):
    _ = (
        pcoll.pipeline
        | 'CreateSole' >> beam.Create([None])
        | 'Count' >> beam.Map(self._make_and_increment_counter))
    return pcoll 
Example #26
Source File: analyzer_impls.py    From transform with Apache License 2.0
def expand(self, inputs):
    pcoll, = inputs
    return pcoll | 'AddKey' >> beam.Map(lambda value: (self._key, value)) 
Example #27
Source File: analyzer_impls.py    From transform with Apache License 2.0
def expand(self, pbegin):
    del pbegin  # unused

    return (self._cache_pcoll
            | 'Decode' >> beam.Map(self._coder.decode_cache)
            | 'Count' >> common.IncrementCounter('cache_entries_decoded')) 
Example #28
Source File: analyzer_impls.py    From transform with Apache License 2.0
def expand(self, inputs):
    pcoll, = inputs

    return (pcoll
            | 'Encode' >> beam.Map(self._coder.encode_cache)
            | 'Count' >> common.IncrementCounter('cache_entries_encoded')) 
Example #29
Source File: sample_mapping_table.py    From gcp-variant-transforms with Apache License 2.0
def expand(self, pcoll):
    return (pcoll
            | 'ExtractIdNameTuples' >> beam.Map(self._extract_id_name)
            | 'CombineToDict' >> beam.combiners.ToDict()) 
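A small standalone illustration (not from gcp-variant-transforms) of the Map-then-ToDict pattern above, turning (id, name) pairs into a single dict:

import apache_beam as beam
from apache_beam.transforms import combiners

with beam.Pipeline() as p:
  _ = (p
       | 'Create' >> beam.Create([('sample_1', 'NA12878'), ('sample_2', 'NA12891')])
       | 'CombineToDict' >> combiners.ToDict()
       | 'Print' >> beam.Map(print))  # {'sample_1': 'NA12878', 'sample_2': 'NA12891'}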
Example #30
Source File: deep_copy_test.py    From transform with Apache License 2.0
def testBasicDeepCopy(self):
    with beam.Pipeline() as p:
      grouped = (p
                 | beam.Create([(1, 'a'), (2, 'b'), (3, 'c')])
                 | beam.Map(
                     lambda x: DeepCopyTest._CountingIdentityFn(
                         'PreGroup', x))
                 | beam.GroupByKey())
      modified = (
          grouped
          |
          'Add1' >> beam.Map(DeepCopyTest._MakeAdd1CountingIdentityFn('Add1'))
          |
          'Add2' >> beam.Map(DeepCopyTest._MakeAdd1CountingIdentityFn('Add2')))
      copied = deep_copy.deep_copy(modified)

      # pylint: disable=expression-not-assigned
      modified | 'Add3' >> beam.Map(
          DeepCopyTest._MakeAdd1CountingIdentityFn('Add3'))
      # pylint: enable=expression-not-assigned

      # Check labels.
      self.assertEqual(copied.producer.full_label, 'Add2.Copy')
      self.assertEqual(copied.producer.inputs[0].producer.full_label,
                       'Add1.Copy')

      # Check that deep copy was performed.
      self.assertIsNot(copied.producer.inputs[0], modified.producer.inputs[0])

      # Check that copy stops at materialization boundary.
      self.assertIs(copied.producer.inputs[0].producer.inputs[0],
                    modified.producer.inputs[0].producer.inputs[0])

    # Check counts of processed items.
    self.assertEqual(DeepCopyTest._counts['PreGroup'], 3)
    self.assertEqual(DeepCopyTest._counts['Add1'], 6)
    self.assertEqual(DeepCopyTest._counts['Add2'], 6)
    self.assertEqual(DeepCopyTest._counts['Add3'], 3)