Python apache_beam.GroupByKey() Examples

The following are 30 code examples of apache_beam.GroupByKey(), drawn from open-source projects. You can go to the original project or source file by following the links above each example. You may also want to check out the other available functions and classes of the apache_beam module.
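Before the examples, a quick refresher: GroupByKey turns a PCollection of (key, value) pairs into a PCollection of (key, iterable-of-values) pairs, with all values for a given key grouped together. A minimal, self-contained sketch (data and step labels are illustrative):

import apache_beam as beam

with beam.Pipeline() as pipeline:
  _ = (
      pipeline
      | 'Create' >> beam.Create([('a', 1), ('b', 2), ('a', 3)])
      | 'Group' >> beam.GroupByKey()  # yields ('a', <iterable of 1, 3>) and ('b', <iterable of 2>)
      | 'Format' >> beam.MapTuple(lambda k, vs: (k, sorted(vs)))  # value order within a group is not guaranteed
      | 'Print' >> beam.Map(print))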
Example #1
Source File: preprocess.py    From professional-services with Apache License 2.0
def shuffle(p):
  """Shuffles data from PCollection.

  Args:
    p: PCollection.

  Returns:
    PCollection of shuffled data.
  """

  class _AddRandomKey(beam.DoFn):

    def process(self, element):
      yield random.random(), element

  shuffled_data = (
      p
      | 'PairWithRandom' >> beam.ParDo(_AddRandomKey())
      | 'GroupByRandom' >> beam.GroupByKey()
      | 'DropRandom' >> beam.FlatMap(lambda kv: kv[1]))
  return shuffled_data 
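A hypothetical call site for the shuffle helper above; the pipeline construction and the read step are illustrative, not part of the original file:

with beam.Pipeline() as p:
  examples = p | 'ReadExamples' >> beam.io.ReadFromTFRecord('examples.tfrecord')  # illustrative source
  shuffled = shuffle(examples)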
Example #2
Source File: linters.py    From data-linter with Apache License 2.0
def _lint(self, examples):
    feature_val_w_counts = (
        examples
        | 'Tuplize' >> beam.FlatMap(
            utils.example_tuplizer(self._counted_features))
        | 'FlattenFeatureVals' >> beam.FlatMap(self._flatten_feature_vals)
        | 'CountFeatureVals' >> beam.combiners.Count.PerElement())

    if hasattr(self, '_count_transformer'):
      feature_val_w_counts |= 'TransformCounts' >> self._count_transformer

    return (
        feature_val_w_counts
        | 'PairValWithCount' >> beam.Map(self._shift_key)
        | 'GroupByFeature' >> beam.GroupByKey()
        | 'ValCountsToDict' >> beam.Map(self._val_counts_as_dict)
        | 'GenResults' >> beam.Map(self._check_feature)
        | 'DropUnwarned' >> beam.Filter(bool)
        | 'AsList' >> beam.combiners.ToList()
        | 'ToResult' >> beam.Map(self._to_result)) 
Example #3
Source File: main.py    From professional-services with Apache License 2.0
def get_enriched_events(salesevent: beam.pvalue.PCollection, sideinput_collections: Dict[str, beam.pvalue.PCollection]) \
        -> beam.pvalue.PCollection:
    """Gets enriched events by
        a) Call a transform that combining primary event with corresponding side input values
        b) Group events by dummy key to combine all events in a window into one shard
        c) Discard dummy key

     Args:
        salesevent: Event representing sales transaction
        sideinput_collections: Set of Side Input Collections
    """
    # yapf: disable
    return (salesevent
             | "Enrich event" >> beam.Map(transforms.enrich_event,
                                       AsDict(sideinput_collections["bonuspoints"]),
                                       AsDict(sideinput_collections["discountpct"]),
                                       AsDict(sideinput_collections["category"]))
             | "Group events by dummy Key" >> beam.GroupByKey()
             | "Discard dummy Key" >> beam.Values()
          )
    # yapf: enable 
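transforms.enrich_event is defined elsewhere in the project and not shown here. A hedged sketch of what such an enrichment callable might look like, given that the three side inputs arrive as dicts via AsDict and that the next step is a GroupByKey on a dummy key; the field names (productid, bonuspoints, discountpct, category) are illustrative assumptions:

def enrich_event(event, bonuspoints, discountpct, category):
  """Hypothetical enrichment callable; field names are illustrative."""
  enriched = dict(event)
  enriched['bonuspoints'] = bonuspoints.get(event['productid'], 0)
  enriched['discountpct'] = discountpct.get(event['productid'], 0.0)
  enriched['category'] = category.get(event['productid'], 'unknown')
  # Return a (dummy key, value) pair so the following GroupByKey gathers all
  # events in the window into a single shard.
  return None, enriched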
Example #4
Source File: preprocess.py    From professional-services with Apache License 2.0
def shuffle_data(p):
  """Shuffles data from PCollection.

  Args:
    p: PCollection.

  Returns:
    PCollection of shuffled data.
  """

  class _AddRandomKey(beam.DoFn):

    def process(self, element):
      yield (random.random(), element)

  shuffled_data = (
      p
      | 'PairWithRandom' >> beam.ParDo(_AddRandomKey())
      | 'GroupByRandom' >> beam.GroupByKey()
      | 'DropRandom' >> beam.FlatMap(lambda kv: kv[1]))
  return shuffled_data 
Example #5
Source File: streaming_beam.py    From python-docs-samples with Apache License 2.0
def run(args, input_subscription, output_table, window_interval):
    """Build and run the pipeline."""
    options = PipelineOptions(args, save_main_session=True, streaming=True)

    with beam.Pipeline(options=options) as pipeline:

        # Read the messages from PubSub and process them.
        messages = (
            pipeline
            | 'Read from Pub/Sub' >> beam.io.ReadFromPubSub(
                subscription=input_subscription).with_output_types(bytes)
            | 'UTF-8 bytes to string' >> beam.Map(lambda msg: msg.decode('utf-8'))
            | 'Parse JSON messages' >> beam.Map(parse_json_message)
            | 'Fixed-size windows' >> beam.WindowInto(
                window.FixedWindows(int(window_interval), 0))
            | 'Add URL keys' >> beam.Map(lambda msg: (msg['url'], msg))
            | 'Group by URLs' >> beam.GroupByKey()
            | 'Get statistics' >> beam.Map(get_statistics))

        # Output the results into BigQuery table.
        _ = messages | 'Write to Big Query' >> beam.io.WriteToBigQuery(
            output_table, schema=SCHEMA) 
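parse_json_message, get_statistics, and SCHEMA are defined elsewhere in the sample. Hedged sketches of what they might look like; the field names and schema below are illustrative assumptions, not the sample's actual definitions:

import json

def parse_json_message(message):
  # Decode the JSON payload into a Python dict; a 'url' field is assumed below.
  return json.loads(message)

def get_statistics(url_and_messages):
  # Aggregate over one URL's grouped messages; the input is the
  # (key, iterable-of-values) pair produced by GroupByKey.
  url, messages = url_and_messages
  return {'url': url, 'num_messages': len(list(messages))}

# An illustrative BigQuery schema matching the dict above:
SCHEMA = 'url:STRING, num_messages:INTEGER'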
Example #6
Source File: PubSubToGCS.py    From python-docs-samples with Apache License 2.0
def expand(self, pcoll):
        return (
            pcoll
            # Assigns window info to each Pub/Sub message based on its
            # publish timestamp.
            | "Window into Fixed Intervals"
            >> beam.WindowInto(window.FixedWindows(self.window_size))
            | "Add timestamps to messages" >> beam.ParDo(AddTimestamps())
            # Use a dummy key to group the elements in the same window.
            # Note that all the elements in one window must fit into memory
            # for this. If the windowed elements do not fit into memory,
            # please consider using `beam.transforms.util.BatchElements`.
            # https://beam.apache.org/releases/pydoc/current/apache_beam.transforms.util.html#apache_beam.transforms.util.BatchElements
            | "Add Dummy Key" >> beam.Map(lambda elem: (None, elem))
            | "Groupby" >> beam.GroupByKey()
            | "Abandon Dummy Key" >> beam.MapTuple(lambda _, val: val)
        ) 
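As the comment above suggests, BatchElements can replace the dummy-key GroupByKey when a window's elements may not fit in memory. A minimal sketch of that variant; the batch sizes are illustrative and `messages` stands for the already-windowed PCollection:

batched = (
    messages
    | 'Batch elements' >> beam.BatchElements(min_batch_size=1, max_batch_size=500))

Unlike the dummy-key pattern, this emits several bounded-size batches per window rather than a single group containing everything.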
Example #7
Source File: extract_input_size.py    From gcp-variant-transforms with Apache License 2.0
def expand(self, estimates):
    return (estimates
            | 'MapSamplesToValueCount' >> beam.FlatMap(
                self._get_sample_ids)
            | 'GroupAllSamples' >> beam.GroupByKey()) 
Example #8
Source File: jackknife.py    From model-analysis with Apache License 2.0
def expand(self, sliced_extracts):

    def partition_fn(_, num_partitions):
      return self._random_state.randint(num_partitions)

    # Partition the data
    # List[PCollection[Tuple[slicer.SliceKeyType, types.Extracts]]]
    partitions = (
        sliced_extracts
        | 'Partition' >> beam.Partition(partition_fn,
                                        self._num_jackknife_samples))

    def add_partition_index(slice_key,
                            accumulator_and_size,
                            partition_index=None):
      accumulator, size = accumulator_and_size
      return slice_key, _PartitionInfo(accumulator, size, partition_index)

    # Within each partition, partially combine per slice key to get accumulators
    # and partition sizes; add partition_id for determinism.
    # List[PCollection[slicer.SliceKeyType, _PartitionInfo]]
    partition_accumulators = []
    for i, partition in enumerate(partitions):
      partition_accumulators.append(
          partition
          | 'CombinePartition[{}]'.format(i) >> beam.CombinePerKey(
              beam.transforms.combiners.SingleInputTupleCombineFn(
                  _AccumulateOnlyCombiner(combiner=self._combiner),
                  beam.transforms.combiners.CountCombineFn()))
          | 'AddPartitionId[{}]'.format(i) >> beam.MapTuple(
              add_partition_index, i))

    # Group partitions for the same slice, compute LOO metrics, and flatten back
    # into per-partition LOO metrics.
    # (slicer.SliceKeyType, Tuple[metric_types.MetricsDict])
    return (partition_accumulators
            | 'FlattenPartitionAccumulators' >> beam.Flatten()
            | 'CollectPerSlicePartitions' >> beam.GroupByKey()
            | 'MakeJackknifeSamples' >> beam.FlatMap(
                _make_jackknife_samples, combiner=self._combiner)) 
Example #9
Source File: preprocess_data.py    From cloudml-examples with Apache License 2.0
def variants_to_examples(input_data, samples_metadata, sample_to_example_fn):
  """Converts variants to TensorFlow Example protos.

  Args:
    input_data: variant call dictionary objects with keys from
      DATA_QUERY_REPLACEMENTS
    samples_metadata: metadata dictionary objects with keys from
      METADATA_QUERY_REPLACEMENTS
    sample_to_example_fn: the feature encoder strategy to use to
      convert the source data into TensorFlow Example protos.

  Returns:
    TensorFlow Example protos.
  """
  variant_kvs = input_data | 'BucketVariants' >> beam.Map(
      lambda row: (row[encoder.KEY_COLUMN], row))

  sample_variant_kvs = variant_kvs | 'GroupBySample' >> beam.GroupByKey()

  examples = (
      sample_variant_kvs
      | 'SamplesToExamples' >> beam.Map(
          lambda kv, samples_metadata: sample_to_example_fn(
              kv[0], kv[1], samples_metadata),
          beam.pvalue.AsSingleton(samples_metadata)))

  return examples 
Example #10
Source File: create_data.py    From conversational-datasets with Apache License 2.0
def _shuffle(pcollection):
    """Shuffles the input pcollection."""
    pcollection |= "add random key" >> beam.Map(
        lambda value: (uuid.uuid4(), value))
    pcollection |= "group by key" >> beam.GroupByKey()
    pcollection |= "get shuffled values" >> beam.FlatMap(lambda t: t[1])
    return pcollection 
Example #11
Source File: create_data.py    From conversational-datasets with Apache License 2.0
def _shuffle_examples(examples):
    examples |= "add random key" >> beam.Map(
        lambda example: (uuid.uuid4(), example)
    )
    examples |= "group by key" >> beam.GroupByKey()
    examples |= "get shuffled values" >> beam.FlatMap(lambda t: t[1])
    return examples 
Example #12
Source File: create_data.py    From conversational-datasets with Apache License 2.0
def _shuffle_examples(examples):
    examples |= ("add random key" >> beam.Map(
        lambda example: (uuid.uuid4(), example)))
    examples |= ("group by key" >> beam.GroupByKey())
    examples |= ("get shuffled values" >> beam.FlatMap(lambda t: t[1]))
    return examples 
Example #13
Source File: preprocess.py    From professional-services with Apache License 2.0
def shuffle(p):
    """Shuffles the given pCollection."""

    return (p
            | 'PairWithRandom' >> beam.Map(lambda x: (random.random(), x))
            | 'GroupByRandom' >> beam.GroupByKey()
            | 'DropRandom' >> beam.FlatMap(lambda x: x[1]))


# pylint: disable=expression-not-assigned
# pylint: disable=no-value-for-parameter 
Example #14
Source File: preprocess.py    From professional-services with Apache License 2.0
def Shuffle(p):
  """Shuffles the given pCollection."""
  return (p
          | "PairWithRandom" >> beam.Map(lambda x: (np.random.random(), x))
          | "GroupByRandom" >> beam.GroupByKey()
          | "DropRandom" >> beam.FlatMap(lambda x: x[1])) 
Example #15
Source File: deep_copy_test.py    From transform with Apache License 2.0
def testMultipleCopies(self):
    with beam.Pipeline() as p:
      grouped = (p
                 | beam.Create([(1, 'a'), (2, 'b'), (3, 'c')])
                 | beam.Map(lambda x: DeepCopyTest._CountingIdentityFn(
                     'PreGroup', x))
                 | beam.GroupByKey())
      modified = (
          grouped
          |
          'Add1' >> beam.Map(DeepCopyTest._MakeAdd1CountingIdentityFn('Add1'))
          |
          'Add2' >> beam.Map(DeepCopyTest._MakeAdd1CountingIdentityFn('Add2')))

      num_copies = 6

      first_copy = deep_copy.deep_copy(modified)
      self.assertEqual(first_copy.producer.full_label, 'Add2.Copy')
      self.assertEqual(first_copy.producer.inputs[0].producer.full_label,
                       'Add1.Copy')

      for i in range(num_copies - 1):
        copied = deep_copy.deep_copy(modified)
        self.assertEqual(copied.producer.full_label, 'Add2.Copy%d' % i)
        self.assertEqual(copied.producer.inputs[0].producer.full_label,
                         'Add1.Copy%d' % i)

    self.assertEqual(DeepCopyTest._counts['PreGroup'], 3)
    self.assertEqual(DeepCopyTest._counts['Add1'], 3 * (num_copies + 1))
    self.assertEqual(DeepCopyTest._counts['Add2'], 3 * (num_copies + 1)) 
Example #16
Source File: merge_variants.py    From gcp-variant-transforms with Apache License 2.0
def expand(self, pcoll):
    return (pcoll
            | 'MapVariantsByKey' >> beam.FlatMap(self._map_by_variant_keys)
            | 'GroupVariantsByKey' >> beam.GroupByKey()
            | 'MergeVariantsByKey' >> beam.FlatMap(self._merge_variants_by_key)) 
Example #17
Source File: limit_write.py    From gcp-variant-transforms with Apache License 2.0
def expand(self, pcoll):
    return (pcoll
            | beam.ParDo(_RoundRobinKeyFn(self._count))
            | beam.GroupByKey()
            | beam.FlatMap(lambda kv: kv[1])) 
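_RoundRobinKeyFn is defined elsewhere in limit_write.py. A sketch of what such a DoFn might look like, inferred from its name and from how it is used above (assigning each element one of `count` integer keys so the GroupByKey caps the number of groups); this is an illustration, not the project's actual code:

class _RoundRobinKeyFn(beam.DoFn):

  def __init__(self, count):
    self._count = count

  def start_bundle(self):
    self._counter = 0

  def process(self, element):
    # Cycle the key through 0..count-1 so downstream work is spread across
    # exactly `count` groups.
    self._counter = (self._counter + 1) % self._count
    yield self._counter, element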
Example #18
Source File: preprocess.py    From cloudml-samples with Apache License 2.0
def _Shuffle(pcoll):  # pylint: disable=invalid-name
  import random
  return (pcoll
          | 'PairWithRandom' >> beam.Map(lambda x: (random.random(), x))
          | 'GroupByRandom' >> beam.GroupByKey()
          | 'DropRandom' >> beam.FlatMap(lambda kv: kv[1]))
Example #19
Source File: preprocess.py    From cloudml-samples with Apache License 2.0
def _Shuffle(pcoll):  # pylint: disable=invalid-name
  """Shuffles a PCollection."""
  import random
  return (pcoll
          | 'PairWithRand' >> beam.Map(lambda x: (random.random(), x))
          | 'GroupByRand' >> beam.GroupByKey()
          | 'DropRand' >> beam.FlatMap(lambda kv: kv[1]))
Example #20
Source File: preprocess.py    From cloudml-samples with Apache License 2.0
def preprocess(pipeline, args):
  """Run pre-processing step as a pipeline.

  Args:
    pipeline: beam pipeline.
    args: parsed command line arguments.
  """
  from preproc import movielens  # pylint: disable=g-import-not-at-top

  # 1) Read the data into pcollections.
  movies_coder = tft_coders.CsvCoder(movielens.MOVIE_COLUMNS,
                                     movielens.make_movies_schema(),
                                     secondary_delimiter='|',
                                     multivalent_columns=['genres'])
  movies_data = (pipeline
                 | 'ReadMoviesData' >> beam.io.ReadFromText(
                     os.path.join(args.input_dir, 'movies.csv'),
                     coder=beam.coders.BytesCoder(),
                     # TODO(b/35653662): Obviate the need for setting this.
                     skip_header_lines=args.skip_header_lines)
                 | 'DecodeMovies' >> beam.Map(movies_coder.decode)
                 | 'KeyByMovie' >> beam.Map(lambda x: (x['movie_id'], x)))
  ratings_coder = tft_coders.CsvCoder(movielens.RATING_COLUMNS,
                                      movielens.make_ratings_schema())
  ratings_data = (pipeline
                  | 'ReadRatingsData' >> beam.io.ReadFromText(
                      os.path.join(args.input_dir, 'ratings*'),
                      skip_header_lines=args.skip_header_lines)
                  | 'DecodeRatings' >> beam.Map(ratings_coder.decode)
                  | 'KeyByUser' >> beam.Map(lambda x: (x['user_id'], x))
                  | 'GroupByUser' >> beam.GroupByKey())
  def train_eval_partition_fn(user_id_and_data, unused_num_partitions):
    user_id, _ = user_id_and_data
    return movielens.partition_fn(
        user_id, args.partition_random_seed, args.percent_eval)

  # Split train/eval data based on the integer user id. 
Example #21
Source File: preprocess.py    From cloudml-samples with Apache License 2.0
def _Shuffle(pcoll):  # pylint: disable=invalid-name
  import random
  return (pcoll
          | 'PairWithRandom' >> beam.Map(lambda x: (random.random(), x))
          | 'GroupByRandom' >> beam.GroupByKey()
          | 'DropRandom' >> beam.FlatMap(lambda kv: kv[1]))
Example #22
Source File: lift_stats_generator.py    From data-validation with Apache License 2.0
def expand(
      self,
      sliced_record_batchs: beam.pvalue.PCollection) -> beam.pvalue.PCollection:
    # Compute P(Y=y)
    # _SlicedYKey(slice, y), _YRate(y_count, example_count)
    y_rates = sliced_record_batchs | 'GetYRates' >> _GetYRates(
        self._y_path, self._y_boundaries, self._weight_column_name)
    y_keys = y_rates | 'ExtractYKeys' >> beam.Keys()

    # Compute P(Y=y | X=x)
    # _SlicedYKey(slice, y), _ConditionalYRate(x_path, x, xy_count, x_count)
    conditional_y_rates = ((sliced_record_batchs, y_keys)
                           | 'GetConditionalYRates' >> _GetConditionalYRates(
                               self._y_path, self._y_boundaries, self._x_paths,
                               self._min_x_count, self._weight_column_name))

    return (
        {
            'conditional_y_rate': conditional_y_rates,
            'y_rate': y_rates
        }
        | 'CoGroupByForLift' >> beam.CoGroupByKey()
        | 'ComputeLifts' >> beam.FlatMap(_compute_lifts)
        | 'FilterLifts' >> _FilterLifts(self._top_k_per_y, self._bottom_k_per_y)
        | 'GroupLiftsForOutput' >> beam.GroupByKey()
        | 'MakeProtos' >> beam.Map(_make_dataset_feature_stats_proto,
                                   self._y_path, self._y_boundaries,
                                   self._weight_column_name is not None,
                                   self._output_custom_stats)) 
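The CoGroupByKey step above joins the two keyed PCollections on their shared _SlicedYKey. A standalone sketch of CoGroupByKey's semantics with illustrative data (unrelated to the real _SlicedYKey/_YRate types):

import apache_beam as beam

with beam.Pipeline() as p:
  y_rate = p | 'CreateYRates' >> beam.Create([('y1', 0.4), ('y2', 0.6)])
  conditional = p | 'CreateCondYRates' >> beam.Create([('y1', 0.1), ('y1', 0.9)])
  _ = (
      {'y_rate': y_rate, 'conditional_y_rate': conditional}
      | 'CoGroup' >> beam.CoGroupByKey()
      # -> ('y1', {'y_rate': [0.4], 'conditional_y_rate': [0.1, 0.9]}),
      #    ('y2', {'y_rate': [0.6], 'conditional_y_rate': []})
      | 'Print' >> beam.Map(print))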
Example #23
Source File: transform.py    From pydatalab with Apache License 2.0
def shuffle(pcoll):  # pylint: disable=invalid-name
  import random
  return (pcoll
          | 'PairWithRandom' >> beam.Map(lambda x: (random.random(), x))
          | 'GroupByRandom' >> beam.GroupByKey()
          | 'DropRandom' >> beam.FlatMap(lambda kv: kv[1]))
Example #24
Source File: transform.py    From pydatalab with Apache License 2.0
def shuffle(pcoll):  # pylint: disable=invalid-name
  import random
  return (pcoll
          | 'PairWithRandom' >> beam.Map(lambda x: (random.random(), x))
          | 'GroupByRandom' >> beam.GroupByKey()
          | 'DropRandom' >> beam.FlatMap(lambda kv: kv[1]))
Example #25
Source File: sentiment_example.py    From transform with Apache License 2.0
def Shuffle(pcoll):
  """Shuffles a PCollection.  Collection should not contain duplicates."""
  return (pcoll
          | 'PairWithHash' >> beam.Map(lambda x: (hash(x), x))
          | 'GroupByHash' >> beam.GroupByKey()
          | 'DropHash' >> beam.FlatMap(
              lambda hash_and_values: hash_and_values[1]))


# pylint: disable=invalid-name 
Example #26
Source File: deep_copy_test.py    From transform with Apache License 2.0
def testBasicDeepCopy(self):
    with beam.Pipeline() as p:
      grouped = (p
                 | beam.Create([(1, 'a'), (2, 'b'), (3, 'c')])
                 | beam.Map(
                     lambda x: DeepCopyTest._CountingIdentityFn(
                         'PreGroup', x))
                 | beam.GroupByKey())
      modified = (
          grouped
          |
          'Add1' >> beam.Map(DeepCopyTest._MakeAdd1CountingIdentityFn('Add1'))
          |
          'Add2' >> beam.Map(DeepCopyTest._MakeAdd1CountingIdentityFn('Add2')))
      copied = deep_copy.deep_copy(modified)

      # pylint: disable=expression-not-assigned
      modified | 'Add3' >> beam.Map(
          DeepCopyTest._MakeAdd1CountingIdentityFn('Add3'))
      # pylint: enable=expression-not-assigned

      # Check labels.
      self.assertEqual(copied.producer.full_label, 'Add2.Copy')
      self.assertEqual(copied.producer.inputs[0].producer.full_label,
                       'Add1.Copy')

      # Check that deep copy was performed.
      self.assertIsNot(copied.producer.inputs[0], modified.producer.inputs[0])

      # Check that copy stops at materialization boundary.
      self.assertIs(copied.producer.inputs[0].producer.inputs[0],
                    modified.producer.inputs[0].producer.inputs[0])

    # Check counts of processed items.
    self.assertEqual(DeepCopyTest._counts['PreGroup'], 3)
    self.assertEqual(DeepCopyTest._counts['Add1'], 6)
    self.assertEqual(DeepCopyTest._counts['Add2'], 6)
    self.assertEqual(DeepCopyTest._counts['Add3'], 3) 
Example #27
Source File: beam_visualize.py    From exoplanet-ml with Apache License 2.0
def main(unused_argv):
  stdlogging.getLogger().setLevel(stdlogging.INFO)

  def pipeline(root):
    """Beam pipeline for creating animated GIFs using an AstroWaveNet model."""
    # Read filenames of all checkpoints.
    checkpoint_state = tf.train.get_checkpoint_state(FLAGS.model_dir)
    if not checkpoint_state:
      raise ValueError("Failed to load checkpoint state from {}".format(
          FLAGS.model_dir))
    checkpoint_paths = [
        os.path.join(FLAGS.model_dir, base_name)
        for base_name in checkpoint_state.all_model_checkpoint_paths
    ]
    logging.info("Found %d checkpoints in %s", len(checkpoint_paths),
                 FLAGS.model_dir)

    # Read filenames of all input files.
    input_files = []
    for file_pattern in FLAGS.input_files.split(","):
      matches = tf.gfile.Glob(file_pattern)
      if not matches:
        raise ValueError("Found no files matching {}".format(file_pattern))
      logging.info("Reading from %d files matching %s", len(matches),
                   file_pattern)
      input_files.extend(matches)
    input_files = input_files[:1]

    # Parse model configs.
    config = configdict.ConfigDict(configurations.get_config(FLAGS.config_name))
    config_overrides = json.loads(FLAGS.config_overrides)
    for key in config_overrides:
      if key not in ["dataset", "hparams"]:
        raise ValueError("Unrecognized config override: {}".format(key))
    config.hparams.update(config_overrides.get("hparams", {}))

    # Create output directory.
    if not tf.gfile.Exists(FLAGS.output_dir):
      tf.gfile.MakeDirs(FLAGS.output_dir)

    # Initialize DoFns.
    make_predictions = prediction_fns.MakePredictionsDoFn(
        config.hparams, config_overrides.get("dataset"))
    make_animations = MakeAnimationDoFn(FLAGS.output_dir)

    # pylint: disable=expression-not-assigned
    (root | beam.Create(itertools.product(checkpoint_paths, input_files))
     | "make_predictions" >> beam.ParDo(make_predictions)
     | "group_by_example_id" >> beam.GroupByKey()
     | "make_animations" >> beam.ParDo(make_animations))
    # pylint: enable=expression-not-assigned

  pipeline.run()
  logging.info("Job completed successfully") 
Example #28
Source File: create_data.py    From conversational-datasets with Apache License 2.0
def run(argv=None):
    """Run the beam pipeline."""
    args, pipeline_args = _parse_args(argv)

    pipeline_options = PipelineOptions(pipeline_args)
    pipeline_options.view_as(SetupOptions).save_main_session = True
    p = beam.Pipeline(options=pipeline_options)

    lines = p | "read qa files" >> ReadFromText(args.file_pattern)

    # The lines are not JSON, but the string representation of python
    # dictionary objects. Parse them with ast.literal_eval.
    json_objects = lines | "parsing dictionaries" >> beam.Map(ast.literal_eval)
    qa_tuples = json_objects | "create tuples" >> beam.FlatMap(
        partial(
            _create_tuples,
            min_words=args.min_words, max_words=args.max_words)
    )

    # Remove duplicate examples.
    qa_tuples |= "key by QA" >> beam.Map(lambda v: (v[1:], v))
    qa_tuples |= "group duplicates" >> beam.GroupByKey()
    qa_tuples |= "remove duplicates" >> beam.Map(lambda v: sorted(v[1])[0])

    # Create the examples.
    examples = qa_tuples | "create examples" >> beam.Map(
        lambda args: _create_example(*args)
    )
    examples = _shuffle_examples(examples)

    examples |= "split train and test" >> beam.ParDo(
        _TrainTestSplitFn(args.train_split)
    ).with_outputs(_TrainTestSplitFn.TEST_TAG, _TrainTestSplitFn.TRAIN_TAG)

    if args.dataset_format == _JSON_FORMAT:
        write_sink = WriteToText
        file_name_suffix = ".json"
        serialize_fn = json.dumps
    else:
        assert args.dataset_format == _TF_FORMAT
        write_sink = WriteToTFRecord
        file_name_suffix = ".tfrecord"
        serialize_fn = _features_to_serialized_tf_example

    for name, tag in [("train", _TrainTestSplitFn.TRAIN_TAG),
                      ("test", _TrainTestSplitFn.TEST_TAG)]:

        serialized_examples = examples[tag] | (
            "serialize {} examples".format(name) >> beam.Map(serialize_fn))
        (
            serialized_examples | ("write " + name)
            >> write_sink(
                os.path.join(args.output_dir, name),
                file_name_suffix=file_name_suffix,
                num_shards=args.num_shards_train,
            )
        )

    result = p.run()
    result.wait_until_finish() 
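_TrainTestSplitFn is defined elsewhere in create_data.py. A hedged sketch of a DoFn with tagged train/test outputs consistent with the with_outputs call above; the splitting rule below is illustrative, not the project's actual logic:

import random

import apache_beam as beam


class _TrainTestSplitFn(beam.DoFn):
  """Hypothetical sketch: routes each example to a 'train' or 'test' output."""

  TRAIN_TAG = 'train'
  TEST_TAG = 'test'

  def __init__(self, train_split):
    self._train_split = train_split

  def process(self, example):
    tag = self.TRAIN_TAG if random.random() < self._train_split else self.TEST_TAG
    yield beam.pvalue.TaggedOutput(tag, example)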
Example #29
Source File: jackknife.py    From model-analysis with Apache License 2.0
def _make_jackknife_samples(
    slice_partitions: Tuple[slicer.SliceKeyType,
                            Sequence[_PartitionInfo]], combiner: beam.CombineFn
) -> Iterator[Tuple[slicer.SliceKeyType, 'metric_types.MetricsDict']]:
  """Computes leave-one-out and unsampled ouputs for the combiner.

  This function creates leave-one-out combiner outputs by combining all but one
  accumulator and extracting the output. Second, it creates an unsampled output
  using all of the accumulators and extracts an unsampled output. The keys
  yielded by thus function are augmented versions of the input slice key in
  which the sample ID (or a special placeholder ID for the unsampled value) has
  been added.

  Args:
    slice_partitions: The result of GroupByKey in which the key is a slice_key,
      and the grouped stream consists of per-partition _PartitionInfo tuples in
      which the first element is an accumulator for that partition, the second
      element is the size of that partition, and the third element is the
      partition ID.
    combiner: The combiner to be used for converting accumulators to outputs.

  Yields:
    Tuples of the form (slice_key, metrics), for each jackknife sample and for
    the unsampled value.
  """
  slice_key, accumulators_sizes_and_ids = slice_partitions
  accumulators, sizes, partition_ids = zip(*accumulators_sizes_and_ids)
  unsampled_accumulator = None
  for i, loo_accumulator in enumerate(
      _make_loo_accumulators(list(accumulators), combiner)):
    # yield sampled output with sample_id of the leftout partition
    sample_id_key = (_JACKKNIFE_SAMPLE_ID_KEY, partition_ids[i])
    yield slice_key + (sample_id_key,), combiner.extract_output(loo_accumulator)
    if i == 0:
      # Create the unsampled accumulator from sample 0 and its complement.
      unsampled_accumulator = combiner.merge_accumulators(
          [loo_accumulator, accumulators[0]])

  # yield unsampled output along with total count as a special metric
  count_dict = {_JACKKNIFE_EXAMPLE_COUNT_METRIC_KEY: sum(sizes)}
  sample_id_key = ((_JACKKNIFE_SAMPLE_ID_KEY, _JACKKNIFE_FULL_SAMPLE_ID),)
  unsampled_output = combiner.extract_output(unsampled_accumulator)
  unsampled_key = slice_key + sample_id_key
  unsampled_val = unsampled_output + (count_dict,)
  yield unsampled_key, unsampled_val 
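_make_loo_accumulators is defined elsewhere in jackknife.py. A simplified, quadratic-time illustration of how leave-one-out accumulators could be generated; this is not the library's actual implementation:

def _make_loo_accumulators(accumulators, combiner):
  # For each index i, merge every accumulator except the i-th and yield the
  # result. An empty remainder falls back to a fresh accumulator.
  for i in range(len(accumulators)):
    rest = accumulators[:i] + accumulators[i + 1:]
    rest = rest or [combiner.create_accumulator()]
    yield combiner.merge_accumulators(rest)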
Example #30
Source File: poisson_bootstrap.py    From model-analysis with Apache License 2.0
def ComputeWithConfidenceIntervals(  # pylint: disable=invalid-name
    sliced_extracts: beam.pvalue.PCollection,
    compute_per_slice_metrics_cls: Type[beam.PTransform],
    num_bootstrap_samples: Optional[int] = DEFAULT_NUM_BOOTSTRAP_SAMPLES,
    random_seed_for_testing: Optional[int] = None,
    **kwargs) -> beam.pvalue.PCollection:
  """PTransform for computing metrics using T-Distribution values.

  Args:
    sliced_extracts: Incoming PCollection consisting of slice key and extracts.
    compute_per_slice_metrics_cls: PTransform class that takes a PCollection of
      (slice key, extracts) as input and returns (slice key, dict of metrics) as
      output. The class will be instantiated multiple times to compute metrics
      both with and without sampling. The class will be initialized using kwargs
      'compute_with_sampling' and 'random_seed_for_testing' along with any
      kwargs passed in **kwargs.
    num_bootstrap_samples: Number of replicas to use in calculating uncertainty
      using bootstrapping. If 1 is provided (default), aggregate metrics will be
      calculated with no uncertainty. If num_bootstrap_samples is > 1, multiple
      samples of each slice will be calculated using the Poisson bootstrap
      method. To calculate standard errors, num_bootstrap_samples should be 20
      or more in order to provide useful data. More is better, but you pay a
      performance cost.
    random_seed_for_testing: Seed to use for unit testing, because
      nondeterministic tests stink. Each partition will use this value + i.
    **kwargs: Additional args to pass to compute_per_slice_metrics_cls init.

  Returns:
    PCollection of (slice key, dict of metrics)
  """
  if not num_bootstrap_samples:
    num_bootstrap_samples = 1
  # TODO(ckuhn): Cap the number of bootstrap samples at 20.
  if num_bootstrap_samples < 1:
    raise ValueError('num_bootstrap_samples should be > 0, got %d' %
                     num_bootstrap_samples)

  output_results = (
      sliced_extracts
      | 'ComputeUnsampledMetrics' >> compute_per_slice_metrics_cls(
          compute_with_sampling=False, random_seed_for_testing=None, **kwargs))

  if num_bootstrap_samples > 1:
    multicombine = []
    for i in range(num_bootstrap_samples):
      seed = (None if random_seed_for_testing is None else
              random_seed_for_testing + i)
      multicombine.append(
          sliced_extracts
          | 'ComputeSampledMetrics%d' % i >> compute_per_slice_metrics_cls(
              compute_with_sampling=True,
              random_seed_for_testing=seed,
              **kwargs))
    output_results = (
        multicombine
        | 'FlattenBootstrapPartitions' >> beam.Flatten()
        | 'GroupBySlice' >> beam.GroupByKey()
        | 'MergeBootstrap' >> beam.ParDo(_MergeBootstrap(),
                                         beam.pvalue.AsDict(output_results)))
  return output_results
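The sampled replicas above are built with the Poisson bootstrap, where each element is included in a replicate a Poisson(1)-distributed number of times. A tiny, Beam-free sketch of that resampling idea (the function name and usage are illustrative):

import numpy as np

def poisson_resample(values, rng):
  # Each element is repeated k times with k ~ Poisson(1), so a replicate has
  # the same expected size as the original sample.
  counts = rng.poisson(lam=1.0, size=len(values))
  return [v for v, k in zip(values, counts) for _ in range(k)]

# Example: poisson_resample([1, 2, 3, 4], np.random.RandomState(0))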