Python tensorflow_datasets.as_numpy() Examples
The following are 24 code examples of tensorflow_datasets.as_numpy().
You can vote up the ones you like or vote down the ones you don't like,
and go to the original project or source file by following the links above each example.
You may also want to check out all available functions/classes of the module tensorflow_datasets, or try the search function.
Example #1
Source File: cache_tasks_main.py From text-to-text-transfer-transformer with Apache License 2.0 | 6 votes |
def _emit_tokenized_examples(self, shard_instruction):
  """Yields encoded examples for a single input shard.

  Loads the shard from the task's TFDS dataset, runs the task's text
  preprocessing, encodes the string features (keeping the plaintext), and
  yields every resulting example as numpy values. The "input-shards" counter
  is incremented once and the "examples" counter once per yielded example.

  Args:
    shard_instruction: shard descriptor handed to `tfds_dataset.load_shard`.

  Yields:
    Each example produced by `tfds.as_numpy` over the encoded dataset.
  """
  _import_modules(self._modules_to_import)
  logging.info("Processing shard: %s", shard_instruction)
  self._increment_counter("input-shards")

  shard_ds = self._task.tfds_dataset.load_shard(shard_instruction)

  if self._max_input_examples:
    # Split the overall example budget evenly across the input files; the
    # shard is repeated so `take` can always reach that per-shard budget.
    per_shard_limit = int(self._max_input_examples / len(self.files))
    shard_ds = shard_ds.repeat().take(per_shard_limit)

  text_ds = self._task.preprocess_text(shard_ds)
  encoded_ds = t5.data.encode_string_features(
      text_ds,
      self._task.output_features,
      keys=self._task.output_features,
      copy_plaintext=True)

  for example in tfds.as_numpy(encoded_ds):
    self._increment_counter("examples")
    yield example
Example #2
Source File: test_utils.py From text-to-text-transfer-transformer with Apache License 2.0 | 6 votes |
def _get_comparable_examples_from_ds(ds):
  """Puts dataset into format that allows examples to be compared in Py2/3.

  Args:
    ds: a dataset accepted by `tfds.as_numpy` whose examples are dicts.

  Returns:
    A list with one entry per example. Each entry is a tuple of
    `(key, value)` pairs sorted by key, where bytes values are decoded to
    text and numpy arrays are converted to tuples (of text when the array
    holds bytes).
  """
  def _clean_value(v):
    if isinstance(v, bytes):
      return tf.compat.as_text(v)
    if isinstance(v, np.ndarray):
      # Only peek at the first element when there is one: an empty array has
      # no v[0] and would otherwise raise IndexError. It becomes ().
      if v.size and isinstance(v[0], bytes):
        return tuple(tf.compat.as_text(s) for s in v)
      return tuple(v)
    return v

  return [
      tuple((k, _clean_value(v)) for k, v in sorted(ex.items()))
      for ex in tfds.as_numpy(ds)
  ]
Example #3
Source File: tf_inputs_test.py From trax with Apache License 2.0 | 6 votes |
def test_generic_text_dataset_preprocess_fn(self):
  """Checks that generic text preprocessing adds int64 inputs and targets."""
  feature_names = ('inputs', 'targets')

  def _squad_preprocess(ds, training):
    # `training` is accepted but not used, as in the preprocess-fn signature.
    return t5_processors.squad(ds)

  dataset = _load_dataset('squad')

  # Before preprocessing, neither feature is present.
  raw_example, = tfds.as_numpy(dataset.take(1))
  for feature in feature_names:
    self.assertNotIn(feature, raw_example)

  proc_dataset = tf_inputs.generic_text_dataset_preprocess_fn(
      dataset,
      spm_path=_spm_path(),
      text_preprocess_fns=[_squad_preprocess],
      copy_plaintext=True,
      debug_print_examples=True,
      debug_print_examples_rate=1.0)

  # After preprocessing, both features exist and hold int64 values.
  proc_example, = tfds.as_numpy(proc_dataset.take(1))
  for feature in feature_names:
    self.assertIn(feature, proc_example)
  for feature in feature_names:
    self.assertEqual(proc_example[feature].dtype, np.int64)

# TODO(afrozm): Why does this test take so much time?
Example #4
Source File: utils.py From text-to-text-transfer-transformer with Apache License 2.0 | 6 votes |
def _log_padding_fractions(dataset, sequence_length, num_examples=100):
  """Empirically compute the fraction of padding - log the results.

  The per-example value is averaged over the examples actually read, so the
  logged number does not grow with `num_examples`.

  Args:
    dataset: a tf.data.Dataset
    sequence_length: dict from string to int (packed lengths)
    num_examples: an integer, the maximum number of examples to inspect
  """
  logging.info("computing padding fractions")
  keys = list(sequence_length)
  padding_sum = {k: 0.0 for k in keys}
  examples_seen = 0
  for ex in tfds.as_numpy(dataset.take(num_examples)):
    examples_seen += 1
    for k in keys:
      # NOTE(review): this is 0 whenever len(ex[k]) == sequence_length[k];
      # confirm the per-example formula measures the intended padding.
      padding_sum[k] += 1 - (sequence_length[k] / len(ex[k]))
  for k in keys:
    # The dataset may hold fewer than `num_examples` examples (or none), so
    # divide by the number actually seen and guard against zero.
    mean_frac = padding_sum[k] / examples_seen if examples_seen else 0.0
    logging.info("%s padding fraction = %g", k, mean_frac)
Example #5
Source File: pixelcnn.py From jaxnet with Apache License 2.0 | 5 votes |
def dataset(batch_size):
    """Builds CIFAR-10 image batch sources.

    Args:
        batch_size: number of images per batch.

    Returns:
        A pair `(get_train_batches, test_batches)`: a zero-argument function
        that returns a fresh numpy iterable over shuffled train batches, and
        a numpy iterable over repeated, shuffled test batches.
    """
    import tensorflow_datasets as tfds
    import tensorflow as tf

    tf.random.set_random_seed(0)
    cifar = tfds.load('cifar10')

    def _images(split):
        # Keep only the image, cast to `image_dtype` from the enclosing scope.
        return cifar[split].map(lambda el: tf.cast(el['image'], image_dtype))

    def get_train_batches():
        train_pipeline = (
            _images('train').shuffle(1000).batch(batch_size).prefetch(1))
        return tfds.as_numpy(train_pipeline)

    test_pipeline = (
        _images('test').repeat().shuffle(1000).batch(batch_size).prefetch(1))
    test_batches = tfds.as_numpy(test_pipeline)
    return get_train_batches, test_batches
Example #6
Source File: hf_model.py From text-to-text-transfer-transformer with Apache License 2.0 | 5 votes |
def tokens_to_batches(dataset, sequence_length, batch_size, output_features):
  """Turns token-sequence examples into batches with per-feature masks.

  Args:
    dataset: tf.data.Dataset containing examples with token sequences.
    sequence_length: dict of int, a dict mapping feature name to length.
    batch_size: int, the number of padded sequences in each batch.
    output_features: list of str, features to include in the dataset.

  Returns:
    A generator that produces batches of numpy examples.
  """
  def _add_masks(ex):
    # For each feature add "<name>_mask": 1 where the token id is greater
    # than 0 and 0 elsewhere, in the same dtype as the tokens.
    for name in output_features:
      tokens = ex[name]
      ex[name + "_mask"] = tf.cast(tf.greater(tokens, 0), tokens.dtype)
    return ex

  padded = transformer_dataset.pack_or_pad(
      dataset,
      sequence_length,
      pack=False,
      feature_keys=output_features,
      ensure_eos=True,
  )
  masked = padded.map(
      _add_masks,
      num_parallel_calls=t5.data.preprocessors.num_parallel_calls(),
  )
  # The final, possibly smaller, batch is kept.
  batched = masked.batch(batch_size, drop_remainder=False)
  return tfds.as_numpy(batched)
Example #7
Source File: mesh_transformer_test.py From text-to-text-transfer-transformer with Apache License 2.0 | 5 votes |
def verify_mesh_dataset_fn(self, mixture_name, train, use_cached):
  """Exercises the mesh train or eval dataset function for a mixture.

  Args:
    mixture_name: name of the registered mixture to load.
    train: if true, check the train dataset fn; otherwise the eval one.
    use_cached: forwarded to the dataset fn; also selects the expected task
      name in eval mode.
  """
  if train:
    dataset_fn, split = (
        mesh_transformer.mesh_train_dataset_fn, tfds.Split.TRAIN)
  else:
    dataset_fn, split = (
        mesh_transformer.mesh_eval_dataset_fn, tfds.Split.VALIDATION)

  vocabulary = t5.data.MixtureRegistry.get(mixture_name).get_vocabulary()
  sequence_length = {"inputs": 13, "targets": 13}
  output = dataset_fn(
      mixture_name,
      sequence_length=sequence_length,
      vocabulary=vocabulary,
      dataset_split=split,
      use_cached=use_cached)

  if train:
    # In train mode the dataset fn returns the dataset itself.
    self.check_ds_shape(output, sequence_length)
    # Materialize a few batches to test for errors.
    list(zip(range(10), tfds.as_numpy(output)))
    return

  # In eval mode a single (name, dataset fn, postprocess fn, metrics) entry
  # is expected.
  self.assertLen(output, 1)
  name, dsfn, postprocess_fn, metric_fns = output[0]
  expected_name = "cached_task" if use_cached else "uncached_task"
  self.assertEqual(expected_name, name)
  eval_ds = dsfn()
  self.check_ds_shape(eval_ds, sequence_length)
  # No postprocess_fn is supplied so it should function as a pass-through.
  self.assertEqual("test", postprocess_fn("test"))
  # test_utils task has empty metric_fns list.
  self.assertEqual([], metric_fns)
  # Materialize the full dataset to test for errors.
  list(tfds.as_numpy(eval_ds))
Example #8
Source File: test_utils.py From text-to-text-transfer-transformer with Apache License 2.0 | 5 votes |
def dataset_as_text(ds):
  """Yields each example of `ds` with values run through `_maybe_as_text`.

  Args:
    ds: a dataset accepted by `tfds.as_numpy` whose examples are dicts.

  Yields:
    A new dict per example with the same keys and converted values.
  """
  for example in tfds.as_numpy(ds):
    converted = {}
    for key, value in example.items():
      converted[key] = _maybe_as_text(value)
    yield converted
Example #9
Source File: lm_dpsgd_tutorial.py From privacy with Apache License 2.0 | 5 votes |
def load_data():
  """Load training and validation data.

  When `FLAGS.data_dir` is unset, the first batch of the `lm1b/subwords8k`
  train and test splits is used and its 'text' feature flattened. Otherwise
  `train.txt` and `test.txt` under `FLAGS.data_dir` are split on whitespace
  and each token is mapped to its index in the sorted training vocabulary.

  Returns:
    A `(train_data, test_data)` pair of 1-D numpy arrays of token ids.

  Raises:
    KeyError: in file mode, if test.txt contains a token absent from
      train.txt (the id mapping is built from the training tokens only).
  """
  def _read_tokens(path):
    # Context manager closes the file even if reading fails.
    with open(path) as f:
      return f.read().split()

  if not FLAGS.data_dir:
    print('FLAGS.data_dir containing train.txt and test.txt was not specified, '
          'using a substitute dataset from the tensorflow_datasets module.')
    train_dataset = tfds.load(name='lm1b/subwords8k',
                              split=tfds.Split.TRAIN,
                              batch_size=NB_TRAIN,
                              shuffle_files=True)
    test_dataset = tfds.load(name='lm1b/subwords8k',
                             split=tfds.Split.TEST,
                             batch_size=10000)
    # Only the first batch of each split is used.
    train_data = next(tfds.as_numpy(train_dataset))
    test_data = next(tfds.as_numpy(test_dataset))
    train_data = train_data['text'].flatten()
    test_data = test_data['text'].flatten()
  else:
    train_txt = _read_tokens(os.path.join(FLAGS.data_dir, 'train.txt'))
    test_txt = _read_tokens(os.path.join(FLAGS.data_dir, 'test.txt'))
    keys = sorted(set(train_txt))
    remap = {k: i for i, k in enumerate(keys)}
    # NOTE(review): uint8 holds only 256 distinct ids; confirm the
    # vocabulary of the text files fits.
    train_data = np.array([remap[x] for x in train_txt], dtype=np.uint8)
    test_data = np.array([remap[x] for x in test_txt], dtype=np.uint8)
  return train_data, test_data
Example #10
Source File: loaders.py From neural-structured-learning with Apache License 2.0 | 5 votes |
def load_data_tf_datasets(dataset_name, target_num_train_per_class,
                          target_num_val, seed):
  """Loads a tensorflow_datasets image dataset and wraps it in a `Dataset`.

  Args:
    dataset_name: name passed to `tfds.load` and used as the `Dataset` name.
    target_num_train_per_class: forwarded to `split_train_val_unlabeled`.
    target_num_val: forwarded to `split_train_val_unlabeled`.
    seed: forwarded to `split_train_val_unlabeled`.

  Returns:
    The object built by `Dataset.build_from_splits` from the train, val,
    test and unlabeled partitions.
  """
  logging.info('Loading and preprocessing data from tensorflow datasets...')

  def _load_split(split):
    # batch_size=-1: per the tfds.load docs the full split comes back as a
    # single batch, which is why the result is indexed by feature name.
    arrays = tfds.as_numpy(
        tfds.load(dataset_name, split=split, batch_size=-1))
    return arrays['image'], arrays['label']

  train_inputs, train_labels = _load_split(tfds.Split.TRAIN)
  test_inputs, test_labels = _load_split(tfds.Split.TEST)

  # Remove extra dimensions of size 1.
  train_labels = np.squeeze(train_labels)
  test_labels = np.squeeze(test_labels)

  logging.info('Splitting data...')
  splits = split_train_val_unlabeled(train_inputs, train_labels,
                                     target_num_train_per_class,
                                     target_num_val, seed)

  logging.info('Converting data to Dataset format...')
  # `splits` is indexed as: train inputs/labels, val inputs/labels,
  # unlabeled inputs/labels.
  return Dataset.build_from_splits(
      name=dataset_name,
      inputs_train=splits[0],
      labels_train=splits[1],
      inputs_val=splits[2],
      labels_val=splits[3],
      inputs_test=test_inputs,
      labels_test=test_labels,
      inputs_unlabeled=splits[4],
      labels_unlabeled=splits[5],
      feature_preproc_fn=convert_image)
Example #11
Source File: mnist_classifier.py From jaxnet with Apache License 2.0 | 5 votes |
def mnist():
    """Loads MNIST as flat float images and one-hot labels.

    Returns:
        `(train_images, train_labels, test_images, test_labels)`. Images are
        float32, divided by 256 and reshaped to rows of 784 values; labels
        come from `_one_hot(label, 10)`.
    """
    # Hide all GPU devices from TensorFlow, see
    # https://github.com/google/jax/blob/master/docs/gpu_memory_allocation.rst
    import tensorflow as tf
    tf.config.experimental.set_visible_devices([], "GPU")

    import tensorflow_datasets as tfds

    dataset = tfds.load("mnist:1.0.0")

    def images(batch):
        return np.reshape(np.float32(batch['image']) / 256, (-1, 784))

    def labels(batch):
        return _one_hot(batch['label'], 10)

    # Each split is read as the first batch of the given size.
    train_pipeline = dataset['train'].shuffle(50000).batch(50000)
    train = next(tfds.as_numpy(train_pipeline))
    test = next(tfds.as_numpy(dataset['test'].batch(10000)))
    return images(train), labels(train), images(test), labels(test)
Example #12
Source File: inputs_test.py From BERT with Apache License 2.0 | 5 votes |
def test_batch_fun(self):
  """Checks `batch_fun` yields exactly one batch of 10 for one device."""
  dataset = test_dataset_ints([32]).repeat(10)
  batches = inputs.batch_fun(
      dataset, True, ([None], [None]), [], 1, batch_size=10)
  num_batches = 0
  for num_batches, example in enumerate(tfds.as_numpy(batches), start=1):
    self.assertEqual(example[0].shape[0], 10)  # Batch size = 10.
  self.assertEqual(num_batches, 1)  # Just one batch here.
Example #13
Source File: mnist_vae.py From jaxnet with Apache License 2.0 | 5 votes |
def mnist_images():
    """Loads MNIST train and test images as flat float arrays.

    Returns:
        `(train_images, test_images)`, each float32, divided by 256 and
        reshaped to rows of 784 values.
    """
    # Hide all GPU devices from TensorFlow, see
    # https://github.com/google/jax/blob/master/docs/gpu_memory_allocation.rst
    import tensorflow as tf
    tf.config.experimental.set_visible_devices([], "GPU")

    import tensorflow_datasets as tfds

    def prep(d):
        # Only the first batch of the pipeline is used.
        first_batch = next(tfds.as_numpy(d))
        return np.reshape(np.float32(first_batch['image']) / 256, (-1, 784))

    dataset = tfds.load("mnist:1.0.0")
    train_images = prep(dataset['train'].shuffle(50000).batch(50000))
    test_images = prep(dataset['test'].batch(10000))
    return (train_images, test_images)
Example #14
Source File: data.py From magenta with Apache License 2.0 | 5 votes |
def count_examples(examples_path, tfds_name, data_converter,
                   file_reader=tf.python_io.tf_record_iterator):
  """Counts the examples the converter produces from files or from TFDS.

  Args:
    examples_path: glob pattern of input files, used when `tfds_name` is
      falsy.
    tfds_name: name of a TFDS dataset; when truthy, its validation split is
      used instead of the files.
    data_converter: object providing `str_to_item_fn` and `to_tensors`.
    file_reader: callable returning an iterable of serialized items for a
      file name.

  Returns:
    The total of `len(tensors.inputs)` over all converted items.
  """
  def _items_from_files():
    for filename in tf.gfile.Glob(examples_path):
      tf.logging.info('Counting examples in %s.', filename)
      for item_str in file_reader(filename):
        yield data_converter.str_to_item_fn(item_str)

  def _items_from_tfds():
    validation_ds = tfds.load(
        tfds_name, split=tfds.Split.VALIDATION, try_gcs=True)
    # TODO(adarob): Generalize to other data types if needed.
    for ex in tfds.as_numpy(validation_ds):
      yield note_seq.midi_to_note_sequence(ex['midi'])

  item_source = _items_from_tfds if tfds_name else _items_from_files
  num_examples = sum(
      len(data_converter.to_tensors(item).inputs) for item in item_source())
  tf.logging.info('Total examples: %d', num_examples)
  return num_examples
Example #15
Source File: data.py From cnn-svm with Apache License 2.0 | 5 votes |
def load_tfds(
    name: str = "mnist"
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
    """
    Loads the train and test splits of a `tfds` image data set.

    Features are the "image" entries converted to float32 and divided by
    255.0; labels are the "label" entries, returned unchanged.

    Parameters
    ----------
    name : str
        The name of the TensorFlow data set to load.

    Returns
    -------
    train_features : np.ndarray
        The train features.
    test_features : np.ndarray
        The test features.
    train_labels : np.ndarray
        The train labels.
    test_labels : np.ndarray
        The test labels.
    """

    def _load_split(split):
        # batch_size=-1: the result is indexed by feature name below.
        arrays = tfds.as_numpy(
            tfds.load(name=name, split=split, batch_size=-1)
        )
        features = arrays["image"]
        labels = arrays["label"]
        return features.astype("float32") / 255.0, labels

    train_features, train_labels = _load_split(tfds.Split.TRAIN)
    test_features, test_labels = _load_split(tfds.Split.TEST)
    return train_features, test_features, train_labels, test_labels
Example #16
Source File: tf_inputs.py From trax with Apache License 2.0 | 5 votes |
def _train_and_eval_dataset_v1(problem_name, data_dir,
                               train_shuffle_files, eval_shuffle_files):
  """Builds train and eval datasets for a T2T problem, plus supervised keys.

  Args:
    problem_name: name of the problem to look up via `t2t_problems()`.
    data_dir: data directory passed to `problem.dataset`.
    train_shuffle_files: `shuffle_files` value for the train dataset.
    eval_shuffle_files: `shuffle_files` value for the eval dataset.

  Returns:
    `(train_dataset, eval_dataset, supervised_keys)`, where
    `supervised_keys` is `([input_key], ['targets'])`.
  """
  with tf.device('cpu:0'):
    problem = t2t_problems().problem(problem_name)

    hparams = None
    if problem_name == 'video_bair_robot_pushing':
      hparams = problem.get_hparams()
      bair_robot_pushing_hparams(hparams)

    def _build(mode, shuffle_files):
      raw = problem.dataset(mode, data_dir,
                            shuffle_files=shuffle_files,
                            hparams=hparams)
      return raw.map(_select_features)

    train_dataset = _build(tf.estimator.ModeKeys.TRAIN, train_shuffle_files)
    eval_dataset = _build(tf.estimator.ModeKeys.EVAL, eval_shuffle_files)

    # TODO(lukaszkaiser): remove this need for one example, just input_key.
    examples = list(tfds.as_numpy(train_dataset.take(1)))

  # We use 'inputs' as input except for purely auto-regressive tasks like
  # language models where 'targets' are used as input_key.
  if 'inputs' in examples[0]:
    input_key = 'inputs'
  else:
    input_key = 'targets'
  supervised_keys = ([input_key], ['targets'])
  return train_dataset, eval_dataset, supervised_keys

# Makes the function accessible in gin configs, even with all args blacklisted.
Example #17
Source File: tf_inputs_test.py From trax with Apache License 2.0 | 5 votes |
def test_c4_bare_preprocess_fn_denoising_objective(self):
  """Checks inputs/targets produced under the gin-configured t5 objective."""
  _t5_gin_config()

  dataset = tf_inputs.c4_bare_preprocess_fn(_c4_dataset(),
                                            spm_path=_spm_path())
  examples = list(tfds.as_numpy(dataset.take(1)))
  example = examples[0]

  # Both features are non-empty int64 arrays.
  for feature in ('targets', 'inputs'):
    self.assertIn(feature, example)
    token_ids = example[feature]
    self.assertIsInstance(token_ids, np.ndarray)
    self.assertEqual(token_ids.dtype, np.int64)
    self.assertGreater(len(token_ids), 0)

  input_ids = example['inputs']
  target_ids = example['targets']

  # WHP inputs will have the bulk of the text.
  self.assertGreater(len(input_ids), len(target_ids))

  # WHP there will be two sentinel tokens in the inputs and targets.
  inputs_counter = collections.Counter(input_ids.tolist())
  targets_counter = collections.Counter(target_ids.tolist())
  for counter in (inputs_counter, targets_counter):
    for sentinel in (31999, 31998):
      self.assertEqual(1, counter[sentinel])
Example #18
Source File: tf_inputs_test.py From trax with Apache License 2.0 | 5 votes |
def test_c4_bare_preprocess_fn(self):
  """Checks the bare C4 preprocessing tokenizes text into targets."""
  def _first_example(ds):
    return list(tfds.as_numpy(ds.take(1)))[0]

  dataset = _c4_dataset()
  raw_example = _first_example(dataset)

  # Targets are NOT in the example.
  self.assertNotIn('targets', raw_example)
  self.assertIn('text', raw_example)
  text = raw_example['text']

  # This should convert the dataset to an inputs/targets that are tokenized.
  dataset = tf_inputs.c4_bare_preprocess_fn(dataset, spm_path=_spm_path())
  example = _first_example(dataset)

  # Earlier text is now stored in targets_plaintext.
  self.assertIn('targets_plaintext', example)
  self.assertEqual(example['targets_plaintext'], text)

  # Targets are now tokenized.
  self.assertIn('targets', example)
  targets = example['targets']
  self.assertIsInstance(targets, np.ndarray)
  self.assertEqual(targets.dtype, np.int64)
  self.assertGreater(len(targets), 0)
  self.assertEqual(targets[-1], 1)  # we add EOS at the end.

  # Inputs exist but is empty because t5 preprocessors' unsupervised wasn't
  # gin configured with any.
  self.assertIn('inputs', example)
  self.assertEqual(len(example['inputs']), 0)
Example #19
Source File: shapenet_test.py From graphics with Apache License 2.0 | 5 votes |
def test_dataset_items(self):
  """Compares every item of every split against `EXPECTED_ITEMS`."""
  builder = shapenet.Shapenet(data_dir=self.tmp_dir)
  self._download_and_prepare_as_dataset(builder)

  for split_name in self.SPLITS:
    split_ds = builder.as_dataset(split=split_name)
    for item in tfds.as_numpy(split_ds):
      # Expected values are looked up by split, then by the item's model id.
      expected = self.EXPECTED_ITEMS[split_name][item['model_id']]
      expected_label = self._encode_synset(builder, expected['synset'])
      self.assertEqual(item['label'], expected_label)
      mesh = item['trimesh']
      self.assertLen(mesh['vertices'], expected['num_vertices'])
      self.assertLen(mesh['faces'], expected['num_faces'])
Example #20
Source File: inputs_test.py From BERT with Apache License 2.0 | 5 votes |
def test_batch_fun_n_devices(self):
  """Checks `batch_fun` yields exactly one batch of 9 for nine devices."""
  dataset = test_dataset_ints([32]).repeat(9)
  batches = inputs.batch_fun(
      dataset, True, ([None], [None]), [], 9, batch_size=10)
  num_batches = 0
  for num_batches, example in enumerate(tfds.as_numpy(batches), start=1):
    # Batch size adjusted to be divisible by n_devices.
    self.assertEqual(example[0].shape[0], 9)
  self.assertEqual(num_batches, 1)  # Just one batch here.
Example #21
Source File: get_data.py From qkeras with Apache License 2.0 | 4 votes |
def get_data(dataset_name, fast=False):
  """Loads an image dataset from tfds, scaled, centered and one-hot encoded.

  Images are converted to float32, given a trailing channel axis when they
  are 3-D, divided by 256 and shifted by the per-pixel training mean. Labels
  are one-hot encoded with `max(train label) + 1` classes.

  Args:
    dataset_name: name passed to `tfds.load`.
    fast: when truthy, used as the number of randomly chosen samples in the
      additional reduced train/test subsets.

  Returns:
    `(x_train, y_train), (x_test, y_test)` when `fast` is falsy; otherwise
    `(subset train pair, full train pair, subset test pair, full test pair)`.
  """
  ds_train = tfds.load(name=dataset_name, split="train", batch_size=-1)
  ds_test = tfds.load(name=dataset_name, split="test", batch_size=-1)

  train_arrays = tfds.as_numpy(ds_train)
  x_train = train_arrays["image"].astype(np.float32)
  y_train = train_arrays["label"]

  test_arrays = tfds.as_numpy(ds_test)
  x_test = test_arrays["image"].astype(np.float32)
  y_test = test_arrays["label"]

  # Images without a channel axis get one appended.
  if x_train.ndim == 3:
    x_train = np.expand_dims(x_train, -1)
    x_test = np.expand_dims(x_test, -1)

  x_train /= 256.0
  x_test /= 256.0

  # Center both splits with the mean computed on the training split only.
  x_mean = np.mean(x_train, axis=0)
  x_train -= x_mean
  x_test -= x_mean

  nb_classes = y_train.max() + 1
  y_train = to_categorical(y_train, nb_classes)
  y_test = to_categorical(y_test, nb_classes)

  print(x_train.shape[0], "train samples")
  print(x_test.shape[0], "test samples")

  if not fast:
    return (x_train, y_train), (x_test, y_test)

  # Draw `fast` random sample indices from each split.
  i_train = np.arange(x_train.shape[0])
  np.random.shuffle(i_train)
  i_test = np.arange(x_test.shape[0])
  np.random.shuffle(i_test)

  train_pick = i_train[:fast]
  test_pick = i_test[:fast]
  return ((x_train[train_pick], y_train[train_pick]),
          (x_train, y_train),
          (x_test[test_pick], y_test[test_pick]),
          (x_test, y_test))
Example #22
Source File: utils_test.py From text-to-text-transfer-transformer with Apache License 2.0 | 4 votes |
def test_invalid_token_preprocessors(self):
  """Checks dataset validation of token preprocessor outputs.

  Registers tasks whose token preprocessor returns a fixed example and
  verifies that a valid example materializes while a missing feature, a
  wrong dtype, a wrong rank and an unexpected EOS token each raise.
  """
  def _dummy_preprocessor(output):
    def _preprocess(_, **unused):
      return tf.data.Dataset.from_tensors(output)
    return _preprocess

  def i64_arr(x):
    return np.array(x, dtype=np.int64)

  def _materialize(task):
    ds = TaskRegistry.get_dataset(
        task, {"inputs": 13, "targets": 13}, "train", use_cached=False)
    list(tfds.as_numpy(ds))

  def _expect_failure(task, output, error_type, error_regex):
    # Register the task, then require materialization to raise.
    test_utils.add_tfds_task(
        task, token_preprocessor=_dummy_preprocessor(output))
    with self.assertRaisesRegex(error_type, error_regex):
      _materialize(task)

  test_utils.add_tfds_task(
      "token_prep_ok",
      token_preprocessor=_dummy_preprocessor(
          {"inputs": i64_arr([2, 3]), "targets": i64_arr([3]),
           "other": "test"}))
  _materialize("token_prep_ok")

  _expect_failure(
      "token_prep_missing_feature",
      {"inputs": i64_arr([2, 3])},
      ValueError,
      "Task dataset is missing expected output feature after token "
      "preprocessing: targets")

  _expect_failure(
      "token_prep_wrong_type",
      {"inputs": "a", "targets": i64_arr([3])},
      ValueError,
      "Task dataset has incorrect type for feature 'inputs' after token "
      "preprocessing: Got string, expected int64")

  _expect_failure(
      "token_prep_wrong_shape",
      {"inputs": i64_arr([2, 3]), "targets": i64_arr(1)},
      ValueError,
      "Task dataset has incorrect rank for feature 'targets' after token "
      "preprocessing: Got 0, expected 1")

  _expect_failure(
      "token_prep_has_eos",
      {"inputs": i64_arr([1, 3]), "targets": i64_arr([4])},
      tf.errors.InvalidArgumentError,
      r".*Feature \\'inputs\\' unexpectedly contains EOS=1 token after token "
      r"preprocessing\..*")
Example #23
Source File: utils.py From text-to-text-transfer-transformer with Apache License 2.0 | 4 votes |
def _log_mixing_proportions(
    tasks, datasets, rates, mixed_dataset,
    sequence_length, compute_stats_empirically):
  """Logs per-task mixing rates and the resulting example/token fractions.

  Called from Mixture.get_dataset.

  Args:
    tasks: a list of Task
    datasets: a list of tf.data.Dataset
    rates: a list of floats
    mixed_dataset: a tf.data.Dataset
    sequence_length: dict from string to int (packed lengths)
    compute_stats_empirically: a boolean - does not work on TPU
  """
  def _normalize(values):
    total = sum(values)
    return [v / total for v in values]

  # compute some stats about the mixture
  examples_fraction = _normalize(rates)

  if compute_stats_empirically:
    # Measure mean lengths from a fixed number of examples per dataset.
    stats_examples = 100
    mean_inputs_length = []
    mean_targets_length = []
    for dataset in datasets:
      inputs_sum = 0
      targets_sum = 0
      for ex in tfds.as_numpy(dataset.take(stats_examples)):
        inputs_sum += ex["inputs"].size
        targets_sum += ex["targets"].size
      mean_inputs_length.append(inputs_sum / float(stats_examples))
      mean_targets_length.append(targets_sum / float(stats_examples))
  else:
    def _estimated_mean_length(task, key):
      # With a token preprocessor the configured length is used as is;
      # otherwise the cached mean tokens per example, capped at that length.
      if task.token_preprocessor:
        return sequence_length[key]
      return min(sequence_length[key],
                 (task.get_cached_stats("train")[key + "_tokens"] /
                  task.get_cached_stats("train")["examples"]))

    mean_inputs_length = [
        _estimated_mean_length(task, "inputs") for task in tasks]
    mean_targets_length = [
        _estimated_mean_length(task, "targets") for task in tasks]

  # Share of tokens per task: mean length weighted by the mixing rate.
  inputs_fraction = _normalize(
      [length * rate for length, rate in zip(mean_inputs_length, rates)])
  targets_fraction = _normalize(
      [length * rate for length, rate in zip(mean_targets_length, rates)])

  logging.info("%12s %12s %12s %12s %12s %12s %s",
               "rate", "ex.frac.", "inp.frac.", "tgt.frac.",
               "inp.len.", "tgt.len", "task")
  for i, rate in enumerate(rates):
    logging.info("%12g %12g %12g %12g %12g %12g %s",
                 rate, examples_fraction[i],
                 inputs_fraction[i], targets_fraction[i],
                 mean_inputs_length[i], mean_targets_length[i],
                 tasks[i].name)

  if compute_stats_empirically:
    _log_padding_fractions(mixed_dataset, sequence_length)
Example #24
Source File: tf_inputs_test.py From trax with Apache License 2.0 | 4 votes |
def test_c4_preprocess(self):
  """Compares C4 preprocessing with the default and 'spc' tokenization."""
  def load_c4_dataset(split='train'):
    # Pair each example with its text so elements are (example, text).
    return _c4_dataset(split=split).map(
        lambda example: (example, example['text']))

  def examine_processed_dataset(proc_dataset):
    # Returns the number of examples and the length of each 'targets'.
    lengths = []
    for example in tfds.as_numpy(proc_dataset):
      features = example[0]
      # Targets are in the example.
      self.assertIn('targets', features)
      self.assertEqual(features['targets'].dtype, np.int64)
      lengths.append(len(features['targets']))
    return len(lengths), lengths

  unfiltered_count = 0
  for unfiltered_count, example in enumerate(
      tfds.as_numpy(load_c4_dataset()), start=1):
    # Targets are NOT in the example.
    self.assertNotIn('targets', example[0])

  proc_dataset = tf_inputs.c4_preprocess(load_c4_dataset(), False, 2048)

  # `examine_processed_dataset` has some asserts in it.
  proc_count, char_lengths = examine_processed_dataset(proc_dataset)

  # Both the original and filtered datasets have examples.
  self.assertGreater(unfiltered_count, 0)
  self.assertGreater(proc_count, 0)

  # Because we filter out some entries on length.
  self.assertLess(proc_count, unfiltered_count)

  # Preprocess using the sentencepiece model in testdata.
  spc_proc_dataset = tf_inputs.c4_preprocess(
      load_c4_dataset(), False, 2048,
      tokenization='spc', spm_path=_spm_path())
  spc_proc_count, spc_lengths = examine_processed_dataset(spc_proc_dataset)

  # spc shortens the target sequence a lot, should be almost equal to
  # unfiltered
  self.assertLessEqual(proc_count, spc_proc_count)
  self.assertEqual(unfiltered_count, spc_proc_count)

  # Assert all spc_lengths are lesser than their char counterparts.
  for spc_len, char_len in zip(spc_lengths, char_lengths):
    self.assertLessEqual(spc_len, char_len)