Python apache_beam.PTransform() Examples
The following are 30 code examples of apache_beam.PTransform(). You can vote up the ones you like or vote down the ones you don't like, and go to the original project or source file by following the links above each example. You may also want to check out all available functions/classes of the module apache_beam, or try the search function.
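For orientation, here is a minimal sketch (not taken from any of the projects below) of how apache_beam.PTransform is typically subclassed and applied in a pipeline; the class name, labels, and data are invented for illustration.

import apache_beam as beam


class FilterAndCount(beam.PTransform):
  """Composite transform: keep truthy elements and count them."""

  def expand(self, pcoll):
    return (pcoll
            | 'DropEmpty' >> beam.Filter(bool)
            | 'Count' >> beam.combiners.Count.Globally())


with beam.Pipeline() as pipeline:
  _ = (pipeline
       | 'Create' >> beam.Create(['a', '', 'b'])
       | 'Apply' >> FilterAndCount()
       | 'Print' >> beam.Map(print))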
Example #1
Source File: executor.py From tfx with Apache License 2.0 | 6 votes |
def __init__(self, input_cache_dir: Text, output_cache_dir: Text,
             analyze_data_list: List[_Dataset],
             feature_spec_or_typespec: Mapping[Text, Any],
             preprocessing_fn: Any,
             cache_source: beam.PTransform):
  # pyformat: enable
  self._input_cache_dir = input_cache_dir
  self._output_cache_dir = output_cache_dir
  self._analyze_data_list = analyze_data_list
  self._feature_spec_or_typespec = feature_spec_or_typespec
  self._preprocessing_fn = preprocessing_fn
  self._cache_source = cache_source

# TODO(zoy): Remove this method once beam no longer pickles PTransforms,
# once https://issues.apache.org/jira/browse/BEAM-3812 is resolved.
Example #2
Source File: tfxio.py From tfx-bsl with Apache License 2.0 | 6 votes |
def BeamSource(self, batch_size: Optional[int] = None) -> beam.PTransform:
  """Returns a beam `PTransform` that produces `PCollection[pa.RecordBatch]`.

  May NOT raise an error if the TFMD schema was not provided at construction
  time. If a TFMD schema was provided at construction time, all the
  `pa.RecordBatch`es in the result `PCollection` must be of the same schema
  returned by `self.ArrowSchema`. If a TFMD schema was not provided, the
  `pa.RecordBatch`es might not be of the same schema (they may contain
  different numbers of columns).

  Args:
    batch_size: if not None, the `pa.RecordBatch` produced will be of the
      specified size. Otherwise it's automatically tuned by Beam.
  """
Example #3
Source File: record_based_tfxio.py From tfx-bsl with Apache License 2.0 | 6 votes |
def RawRecordToRecordBatch(self, batch_size: Optional[int] = None
                          ) -> beam.PTransform:
  """Returns a PTransform that converts raw records to Arrow RecordBatches.

  The input PCollection must be from self.RawRecordBeamSource() (also see
  the documentation for that method).

  Args:
    batch_size: if not None, the `pa.RecordBatch` produced will be of the
      specified size. Otherwise it's automatically tuned by Beam.
  """

  @beam.typehints.with_input_types(bytes)
  @beam.typehints.with_output_types(pa.RecordBatch)
  def _PTransformFn(pcoll: beam.pvalue.PCollection):
    return (pcoll
            | "RawRecordToRecordBatch" >>
            self._RawRecordToRecordBatchInternal(batch_size)
            | "CollectRecordBatchTelemetry" >>
            telemetry.ProfileRecordBatches(self._telemetry_descriptors,
                                           self._logical_format,
                                           self._physical_format))

  return beam.ptransform_fn(_PTransformFn)()
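Examples #3 and #4 build their transforms by calling beam.ptransform_fn(_PTransformFn)() on an inner function. The decorator form below is a small, hedged sketch of that same pattern; the function name, labels, and data are invented for illustration and are not part of tfx-bsl.

import apache_beam as beam


@beam.ptransform_fn
@beam.typehints.with_input_types(bytes)
@beam.typehints.with_output_types(str)
def DecodeUtf8(pcoll, errors='strict'):
  # The decorated function receives the input PCollection as its first
  # argument; calling DecodeUtf8(...) returns a PTransform.
  return pcoll | 'Decode' >> beam.Map(lambda b: b.decode('utf-8', errors))


with beam.Pipeline() as pipeline:
  _ = (pipeline
       | beam.Create([b'hello', b'world'])
       | 'DecodeUtf8' >> DecodeUtf8(errors='replace')
       | beam.Map(print))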
Example #4
Source File: tf_example_record.py From tfx-bsl with Apache License 2.0 | 6 votes |
def _RawRecordToRecordBatchInternal(self,
                                    batch_size: Optional[int] = None
                                   ) -> beam.PTransform:

  @beam.typehints.with_input_types(bytes)
  @beam.typehints.with_output_types(pa.RecordBatch)
  def _PTransformFn(raw_records_pcoll: beam.pvalue.PCollection):
    return (raw_records_pcoll
            | "Batch" >> beam.BatchElements(
                **batch_util.GetBatchElementsKwargs(batch_size))
            | "Decode" >> beam.ParDo(
                _DecodeBatchExamplesDoFn(self._GetSchemaForDecoding(),
                                         self.raw_record_column_name,
                                         self._can_produce_large_types)))

  return beam.ptransform_fn(_PTransformFn)()
Example #5
Source File: base_example_gen_executor.py From tfx with Apache License 2.0 | 6 votes |
def GetInputSourceToExamplePTransform(self) -> beam.PTransform:
  """Returns PTransform for converting input source to records.

  The record is by default assumed to be tf.train.Example protos; subclasses
  can serialize any protocol buffer into bytes as the output PCollection, so
  long as the downstream component can consume it.

  Note that each input split will be transformed by this function separately.
  For complex use cases, consider overriding 'GenerateExamplesByBeam' instead.

  Here is an example PTransform:
    @beam.ptransform_fn
    @beam.typehints.with_input_types(beam.Pipeline)
    @beam.typehints.with_output_types(Union[tf.train.Example,
                                            tf.train.SequenceExample,
                                            bytes])
    def ExamplePTransform(
        pipeline: beam.Pipeline,
        exec_properties: Dict[Text, Any],
        split_pattern: Text) -> beam.pvalue.PCollection
  """
  pass
Example #6
Source File: writer.py From model-analysis with Apache License 2.0 | 6 votes |
def Write(evaluation_or_validation: Union[evaluator.Evaluation,
                                          validator.Validation],
          key: Text,
          ptransform: beam.PTransform) -> beam.pvalue.PDone:
  """Writes given Evaluation or Validation data using given writer PTransform.

  Args:
    evaluation_or_validation: Evaluation or Validation data.
    key: Key for Evaluation or Validation output to write. It is valid for the
      key to not exist in the dict (in which case the write is a no-op).
    ptransform: PTransform to use for writing.

  Raises:
    ValueError: If Evaluation or Validation is empty. The key does not need to
      exist in the Evaluation or Validation, but the dict must not be empty.

  Returns:
    beam.pvalue.PDone.
  """
  if not evaluation_or_validation:
    raise ValueError('Evaluations and Validations cannot be empty')
  if key in evaluation_or_validation:
    return evaluation_or_validation[key] | ptransform
  return beam.pvalue.PDone(list(evaluation_or_validation.values())[0].pipeline)
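As a rough, hypothetical usage of the Write helper above (the dict contents, label, and output path are invented; only the call shape follows the signature shown):

import apache_beam as beam

# Assumes Write() from the model-analysis writer module above is in scope.
with beam.Pipeline() as pipeline:
  metrics_pcoll = pipeline | 'CreateMetrics' >> beam.Create(['accuracy: 0.9'])
  evaluation = {'metrics': metrics_pcoll}
  # Write() pipes evaluation['metrics'] into the given writer PTransform.
  _ = Write(evaluation, 'metrics',
            'WriteMetrics' >> beam.io.WriteToText('/tmp/metrics'))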
Example #7
Source File: data_linter.py From data-linter with Apache License 2.0 | 6 votes |
def expand(self, examples): """Runs the linters on the data and writes out the results. The order in which the linters run is unspecified. Args: examples: A `PTransform` that yields a `PCollection` of `tf.Examples`. Returns: A pipeline containing the `DataLinter` `PTransform`s. """ coders = (beam.coders.coders.StrUtf8Coder(), beam.coders.coders.ProtoCoder(lint_result_pb2.LintResult)) return ( [examples | linter for linter in self._linters if linter.should_run()] | 'MergeResults' >> beam.Flatten() | 'DropEmpty' >> beam.Filter(lambda (_, r): r and len(r.warnings)) | 'ToDict' >> beam.combiners.ToDict() | 'WriteResults' >> beam.io.textio.WriteToText( self._results_path, coder=beam.coders.coders.PickleCoder(), shard_name_template=''))
Example #8
Source File: linters.py From data-linter with Apache License 2.0 | 6 votes |
def expand(self, examples): """Implements the interface required by `PTransform`. Args: examples: A `PTransform` that yields a `PCollection` of tf.Examples. Returns: A `PTransform` that yields a `PCollection` containing at most one tuple in which the first element is the `LintDetector` name and the second is the `LintResult`. """ result = self._lint(examples) if not isinstance(result, (beam.pvalue.PCollection, beam.transforms.PTransform)): result_pcoll = beam.Create([result] if result else []) result = examples.pipeline | 'Materialize' >> result_pcoll return result | 'PairWithName' >> beam.Map( lambda r: (type(self).__name__, r))
Example #9
Source File: linters.py From data-linter with Apache License 2.0 | 6 votes |
def _lint(self, examples):
  """Returns the result of the `TokenizableStringDetector` linter.

  Args:
    examples: A `PTransform` that yields a `PCollection` of `tf.Example`s.

  Returns:
    A `LintResult` of the format
      warnings: [feature names]
      lint_samples: [{ strings=[vals..] } for each warning]
  """
  result = self._make_result()
  string_features = utils.get_string_features(self._stats)
  for feature in self._stats.features:
    if feature.name not in string_features:
      continue
    str_stats = feature.string_stats
    if (str_stats.avg_length > self._length_threshold and
        str_stats.unique > self._enum_threshold):
      result.warnings.append(feature.name)
      samples = [bucket.label
                 for bucket in str_stats.rank_histogram.buckets
                 if len(bucket.label) > self._length_threshold]
      result.lint_samples.add(strings=samples[:self.N_LINT_SAMPLES])
  return result
Example #10
Source File: linters.py From data-linter with Apache License 2.0 | 6 votes |
def _lint(self, examples):
  """Returns the result of the CircularDomainDetector linter.

  Args:
    examples: A `PTransform` that yields a `PCollection` of `tf.Example`s.

  Returns:
    A `LintResult` of the format
      warnings: [feature names]
      lint_sample: None
  """
  result = self._make_result()
  numeric_features = utils.get_numeric_features(self._stats)
  for feature in self._stats.features:
    name = feature.name
    if name in numeric_features and self._name_is_suspicious(name):
      result.warnings.append(name)
  return result
Example #11
Source File: analyzer_cache.py From transform with Apache License 2.0 | 6 votes |
def __init__(self, pipeline, cache_base_dir, dataset_keys=None, sink=None):
  """Init method.

  Args:
    pipeline: A beam Pipeline.
    cache_base_dir: A str, the path that the cache should be stored in.
    dataset_keys: (Optional) An iterable of strings.
    sink: (Optional) A PTransform class that takes a path in its constructor,
      and is used to write the cache. If not provided this uses a GZipped
      TFRecord sink.
  """
  self.pipeline = pipeline
  self._cache_base_dir = cache_base_dir
  if dataset_keys is None:
    self._sorted_dataset_keys = None
  else:
    self._sorted_dataset_keys = sorted(dataset_keys)
  self._sink = sink
  if self._sink is None:
    # TODO(b/37788560): Possibly use Riegeli as a default file format once
    # possible.
    self._sink = _WriteToTFRecordGzip
Example #12
Source File: stats_impl_test.py From data-validation with Apache License 2.0 | 6 votes |
def test_generate_statistics_in_memory_invalid_custom_generator(self):

  # Dummy PTransform that does nothing.
  class CustomPTransform(beam.PTransform):

    def expand(self, pcoll):
      pass

  record_batch = pa.RecordBatch.from_arrays([pa.array([[1.0]])], ['a'])
  custom_generator = stats_generator.TransformStatsGenerator(
      name='CustomStatsGenerator', ptransform=CustomPTransform())
  options = stats_options.StatsOptions(
      generators=[custom_generator], enable_semantic_domain_stats=True)
  with self.assertRaisesRegexp(
      TypeError, 'Statistics generator.* found object of type '
      'TransformStatsGenerator.'):
    stats_impl.generate_statistics_in_memory(record_batch, options)
Example #13
Source File: linters.py From data-linter with Apache License 2.0 | 6 votes |
def _lint(self, examples):
  """Returns the `PTransform` for the EmptyExampleDetector linter.

  Args:
    examples: A `PTransform` that yields a `PCollection` of `tf.Example`s.

  Returns:
    A `PTransform` that yields a `LintResult` of the format
      warnings: [num empties]
      lint_sample: None
  """
  n_empties = (
      examples
      | 'DetectEmpties' >> beam.Map(self._example_is_empty)
      | 'Count' >> beam.CombineGlobally(sum)
      | 'NoZero' >> beam.Filter(bool)
      | 'ToResult' >> beam.Map(
          lambda w: self._make_result(warnings=[str(w)])))
  return n_empties
Example #14
Source File: analyzer_cache.py From transform with Apache License 2.0 | 6 votes |
def __init__(self,
             cache_base_dir,
             dataset_keys,
             cache_entry_keys=None,
             source=None):
  """Init method.

  Args:
    cache_base_dir: A string, the path that the cache should be stored in.
    dataset_keys: An iterable of `DatasetKey`s.
    cache_entry_keys: (Optional) An iterable of cache entry key strings. If
      provided, only cache entries that exist in `cache_entry_keys` will be
      read.
    source: (Optional) A PTransform class that takes a path argument in its
      constructor, and is used to read the cache.
  """
  self._cache_base_dir = cache_base_dir
  if not all(isinstance(d, DatasetKey) for d in dataset_keys):
    raise ValueError('Expected dataset_keys to be of type DatasetKey')
  self._sorted_dataset_keys = sorted(dataset_keys)
  self._filtered_cache_entry_keys = (None if cache_entry_keys is None
                                     else set(cache_entry_keys))
  # TODO(b/37788560): Possibly use Riegeli as a default file format once
  # possible.
  self._source = source if source is not None else beam.io.ReadFromTFRecord
Example #15
Source File: telemetry.py From tfx-bsl with Apache License 2.0 | 5 votes |
def ProfileRecordBatches(
    pcoll: beam.pvalue.PCollection,
    telemetry_descriptors: Optional[List[Text]],
    logical_format: Text,
    physical_format: Text,
    distribution_update_probability: float = 0.1) -> beam.PTransform:
  """An identity transform to profile RecordBatches and update Beam metrics.

  Args:
    pcoll: a PCollection[pa.RecordBatch]
    telemetry_descriptors: a set of descriptors that identify the component
      that invokes this PTransform. These will be used to construct the
      namespace to contain the beam metrics created within this PTransform.
      All such namespaces will be prefixed by "tfxio.". If None, a default
      "unknown" descriptor will be used.
    logical_format: the logical format of the data (before being parsed into
      RecordBatches). Used to construct metric names.
    physical_format: the physical format in which the data is stored on disk.
      Used to construct metric names.
    distribution_update_probability: probability to update the expensive,
      per-row distributions.

  Returns:
    `pcoll` (identity function).
  """
  assert 0 < distribution_update_probability <= 1.0, (
      "Invalid probability: {}".format(distribution_update_probability))
  return pcoll | "ProfileRecordBatches" >> beam.ParDo(
      _ProfileRecordBatchDoFn(telemetry_descriptors, logical_format,
                              physical_format,
                              distribution_update_probability))
Example #16
Source File: tfxio.py From tfx-bsl with Apache License 2.0 | 5 votes |
def BeamSource(self, batch_size: Optional[int] = None) -> beam.PTransform:
  return self.projected.BeamSource(batch_size)
Example #17
Source File: linters.py From data-linter with Apache License 2.0 | 5 votes |
def _lint(self, examples):
  """Returns the result of the TailedDistributionDetector linter.

  Args:
    examples: A `PTransform` that yields a `PCollection` of `tf.Example`s.

  Returns:
    A `PTransform` that yields a `LintResult` of the format
      warnings: [feature names]
      lint_samples: [
        [stats: {min: feature_min if outlying, max: feature_max if outlying}]
        for each warning
      ]
  """
  feature_values = (
      examples
      | 'FlattenFeatureValue' >> beam.FlatMap(
          self._flatten_feature_vals(self.numeric_features)))
  feature_min_trimmed_mean = (
      feature_values | self._make_trimmed_averager(self._MIN))
  feature_max_trimmed_mean = (
      feature_values | self._make_trimmed_averager(self._MAX))
  return (
      (feature_min_trimmed_mean, feature_max_trimmed_mean)
      | 'MergeTrimmedMeans' >> beam.CoGroupByKey()
      | 'AsList' >> beam.combiners.ToList()
      | 'ToResult' >> beam.Map(self._to_result))
Example #18
Source File: record_based_tfxio.py From tfx-bsl with Apache License 2.0 | 5 votes |
def _RawRecordBeamSourceInternal(self) -> beam.PTransform:
  """Returns a PTransform that produces a PCollection[bytes]."""
Example #19
Source File: record_based_tfxio.py From tfx-bsl with Apache License 2.0 | 5 votes |
def _RawRecordToRecordBatchInternal(self,
                                    batch_size: Optional[int] = None
                                   ) -> beam.PTransform:
  """Returns a PTransform that converts raw records to Arrow RecordBatches."""
  pass
Example #20
Source File: record_based_tfxio.py From tfx-bsl with Apache License 2.0 | 5 votes |
def BeamSource(self, batch_size: Optional[int] = None) -> beam.PTransform:

  @beam.typehints.with_input_types(beam.Pipeline)
  @beam.typehints.with_output_types(pa.RecordBatch)
  def _PTransformFn(pipeline: beam.pvalue.PCollection):
    """Converts raw records to RecordBatches."""
    return (
        pipeline
        | "RawRecordBeamSource" >> self.RawRecordBeamSource()
        | "RawRecordToRecordBatch" >> self.RawRecordToRecordBatch(batch_size))

  return beam.ptransform_fn(_PTransformFn)()
Example #21
Source File: csv_tfxio.py From tfx-bsl with Apache License 2.0 | 5 votes |
def _CSVSource(self) -> beam.PTransform:
  """Returns a PTransform that produces PCollection[bytes]."""
Example #22
Source File: csv_tfxio.py From tfx-bsl with Apache License 2.0 | 5 votes |
def _RawRecordToRecordBatchInternal(self,
                                    batch_size: Optional[int] = None
                                   ) -> beam.PTransform:

  @beam.typehints.with_input_types(List[bytes])
  @beam.typehints.with_output_types(pa.RecordBatch)
  def _PTransformFn(raw_records_pcoll: beam.pvalue.PCollection):
    """Returns RecordBatch of csv lines."""
    # Decode raw csv lines to record batches.
    record_batches = (
        raw_records_pcoll
        | "CSVToRecordBatch" >> csv_decoder.CSVToRecordBatch(
            column_names=self._column_names,
            delimiter=self._delimiter,
            skip_blank_lines=self._skip_blank_lines,
            schema=self._schema,
            desired_batch_size=batch_size,
            multivalent_columns=self._multivalent_columns,
            secondary_delimiter=self._secondary_delimiter,
            produce_large_types=self._can_produce_large_types,
            raw_record_column_name=self._raw_record_column_name))
    return record_batches

  return beam.ptransform_fn(_PTransformFn)()
Example #23
Source File: csv_tfxio.py From tfx-bsl with Apache License 2.0 | 5 votes |
def _CSVSource(self) -> beam.PTransform:
  """Returns a PTransform that produces PCollection[bytes]."""
  return beam.io.ReadFromText(
      self._file_pattern,
      coder=beam.coders.BytesCoder(),
      validate=self._validate)
Example #24
Source File: telemetry.py From tfx-bsl with Apache License 2.0 | 5 votes |
def ProfileRawRecords(
    pcoll: beam.pvalue.PCollection,
    telemetry_descriptors: Optional[List[Text]],
    logical_format: Text,
    physical_format: Text) -> beam.PTransform:
  """An identity transform to profile raw records for record based TFXIO."""
  return pcoll | "ProfileRawRecords" >> beam.ParDo(_ProfileRawRecordDoFn(
      telemetry_descriptors, logical_format, physical_format))
Example #25
Source File: linters.py From data-linter with Apache License 2.0 | 5 votes |
def _count_transformer(self):
  """Returns a `PTransform` that modifies the raw feature-value counts.

  The `PTransform` will receive as its pipeline input a `PCollection`
  containing entries of the format ((feature_name, feature_val), count)
  and must produce a `PCollection` containing entries of the same format.
  """
  raise NotImplementedError()
Example #26
Source File: verifier_lib.py From model-analysis with Apache License 2.0 | 5 votes |
def Validate(  # pylint: disable=invalid-name
    extracts: beam.pvalue.PCollection,
    alternatives: Dict[Text, beam.PTransform],
    validators: List[validator.Validator]) -> validator.Validation:
  """Performs validation of alternative evaluations.

  Args:
    extracts: PCollection of extracts.
    alternatives: Dict of PTransforms (Extracts -> Evaluation) whose output
      will be compared for validation purposes (e.g. 'baseline' vs
      'candidate').
    validators: List of validators for validating the output from running the
      alternatives. The Validation outputs produced by the validators will be
      merged into a single output. If there are overlapping output keys, later
      outputs will replace earlier outputs sharing the same key.

  Returns:
    Validation dict.
  """
  evaluations = {}
  for key in alternatives:
    evaluations[key] = extracts | 'Evaluate(%s)' % key >> alternatives[key]
  validation = {}
  for v in validators:
    validation.update(evaluations | v.stage_name >> v.ptransform)
  return validation
Example #27
Source File: raw_tf_record.py From tfx-bsl with Apache License 2.0 | 5 votes |
def _RawRecordToRecordBatchInternal(self,
                                    batch_size: Optional[int] = None
                                   ) -> beam.PTransform:

  @beam.typehints.with_input_types(beam.Pipeline)
  @beam.typehints.with_output_types(pa.RecordBatch)
  def _PTransformFn(raw_record_pcoll: beam.pvalue.PCollection):
    return (raw_record_pcoll
            | "Batch" >> beam.BatchElements(
                **batch_util.GetBatchElementsKwargs(batch_size))
            | "ToRecordBatch" >> beam.Map(_BatchedRecordsToArrow,
                                          self.raw_record_column_name,
                                          self._can_produce_large_types))

  return beam.ptransform_fn(_PTransformFn)()
Example #28
Source File: linters.py From data-linter with Apache License 2.0 | 5 votes |
def _lint(self, examples):
  """Returns the result of the `NumberAsStringDetector` linter.

  Args:
    examples: A `PTransform` that yields a `PCollection` of `tf.Example`s.

  Returns:
    A `LintResult` of the format
      warnings: [feature names]
      lint_samples: [{ strings=[vals..] } for each warning]
  """
  result = self._make_result()
  string_features = utils.get_string_features(self._stats)
  lint_samples = collections.defaultdict(set)
  for feature in self._stats.features:
    if feature.name not in string_features:
      continue
    str_stats = feature.string_stats
    n_samples = str_stats.common_stats.num_non_missing
    if n_samples == 0:
      continue
    num_numeric = 0
    for bucket in str_stats.rank_histogram.buckets:
      try:
        nums_only = re.sub(r'\D', '', bucket.label)
        if len(nums_only) / len(bucket.label) >= 1 - self._non_num_tol:
          num_numeric += bucket.sample_count
          samples = lint_samples[feature.name]
          if len(samples) < self.N_LINT_SAMPLES:
            samples.add(bucket.label)
      except (ValueError, ZeroDivisionError):
        pass
    if num_numeric / n_samples > 0.5:
      result.warnings.append(feature.name)
      result.lint_samples.add(strings=lint_samples[feature.name])
  return result
Example #29
Source File: linters.py From data-linter with Apache License 2.0 | 5 votes |
def _lint(self, examples):
  """Performs linting and returns the result.

  This must be implemented by `LintDetector` subclasses.

  Args:
    examples: A `PTransform` that yields a `PCollection` of `tf.Example`s.

  Returns:
    If this linter has results, this method must return either a `LintResult`
    or a `PTransform` that yields a `PCollection` containing exactly one.
    Otherwise, this function may return None, an empty `PCollection`, or a
    `LintResult` with an empty `warnings` list.
  """
  raise NotImplementedError()
Example #30
Source File: linters.py From data-linter with Apache License 2.0 | 5 votes |
def _lint(self, examples):
  """Returns the result of the `DateTimeAsStringDetector` linter.

  Args:
    examples: A `PTransform` that yields a `PCollection` of `tf.Example`s.

  Returns:
    A `LintResult` of the format
      warnings: [feature names]
      lint_sample: [{ strings=[vals..] } for each warning]
  """
  result = self._make_result()
  string_features = utils.get_string_features(self._stats)
  lint_samples = collections.defaultdict(set)
  for feature in self._stats.features:
    if feature.name not in string_features:
      continue
    str_stats = feature.string_stats
    n_samples = str_stats.common_stats.num_non_missing
    if n_samples == 0:
      continue
    num_date_parsable = 0
    for bucket in str_stats.rank_histogram.buckets:
      if self._string_is_datetime(bucket.label):
        num_date_parsable += bucket.sample_count
        samples = lint_samples[feature.name]
        if len(samples) < self.N_LINT_SAMPLES:
          samples.add(bucket.label)
    if num_date_parsable / n_samples > 0.5:
      result.warnings.append(feature.name)
      result.lint_samples.add(strings=lint_samples[feature.name])
  return result