Python tensorflow.Dataset() Examples
The following are 30 code examples of tensorflow.Dataset(). You can vote up the ones you like or vote down the ones you don't like, and go to the original project or source file by following the links above each example. You may also want to check out all available functions/classes of the module tensorflow, or try the search function.
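
Before working through the examples, here is a minimal, self-contained sketch of the tf.data.Dataset API that most of the snippets below build on. It assumes TF 2.x eager execution, and the array contents are made up purely for illustration:

import numpy as np
import tensorflow as tf

# Hypothetical in-memory features and labels, for illustration only.
features = np.random.rand(100, 8).astype(np.float32)
labels = np.random.randint(0, 2, size=(100,)).astype(np.int64)

# Build a pipeline: slice the arrays, shuffle, map a transformation, batch, prefetch.
dataset = (tf.data.Dataset.from_tensor_slices((features, labels))
           .shuffle(buffer_size=100)
           .map(lambda x, y: (x * 2.0, y))
           .batch(32)
           .prefetch(tf.data.experimental.AUTOTUNE))

# In TF 2.x the dataset is directly iterable.
for batch_x, batch_y in dataset:
    print(batch_x.shape, batch_y.shape)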
Example #1
Source File: datasets.py From deepchem with MIT License | 6 votes |
def transform(self, fn, **args):
    """Construct a new dataset by applying a transformation to every sample in this dataset.

    The argument is a function that can be called as follows:
    >> newx, newy, neww = fn(x, y, w)

    It might be called only once with the whole dataset, or multiple times
    with different subsets of the data. Each time it is called, it should
    transform the samples and return the transformed data.

    Parameters
    ----------
    fn: function
        A function to apply to each sample in the dataset

    Returns
    -------
    a newly constructed Dataset object
    """
    raise NotImplementedError()
Example #2
Source File: datasets.py From deepchem with MIT License | 6 votes |
def make_pytorch_dataset(self, epochs=1, deterministic=False):
    """Create a torch.utils.data.IterableDataset that iterates over the data in this Dataset.

    Each value returned by the Dataset's iterator is a tuple of (X, y, w, id)
    for one sample.

    Parameters
    ----------
    epochs: int
        the number of times to iterate over the Dataset
    deterministic: bool
        if True, the data is produced in order. If False, a different random
        permutation of the data is used for each epoch.

    Returns
    -------
    `torch.utils.data.IterableDataset` that iterates over the data in this dataset.
    """
    raise NotImplementedError()
Example #3
Source File: datasets.py From deepchem with MIT License | 6 votes |
def transform(self, fn, **args):
    """Construct a new dataset by applying a transformation to every sample in this dataset.

    The argument is a function that can be called as follows:
    >> newx, newy, neww = fn(x, y, w)

    It might be called only once with the whole dataset, or multiple times
    with different subsets of the data. Each time it is called, it should
    transform the samples and return the transformed data.

    Parameters
    ----------
    fn: function
        A function to apply to each sample in the dataset

    Returns
    -------
    a newly constructed Dataset object
    """
    newx, newy, neww = fn(self._X, self._y, self._w)
    return NumpyDataset(newx, newy, neww, self._ids[:])
Example #4
Source File: tf_runner.py From ray with Apache License 2.0 | 6 votes |
def __init__(self, model_creator, data_creator, config=None, verbose=False):
    """Initializes the runner.

    Args:
        model_creator (dict -> Model): see tf_trainer.py.
        data_creator (dict -> tf.Dataset, tf.Dataset): see tf_trainer.py.
        config (dict): see tf_trainer.py.
        verbose (bool): Outputs training data if true.
    """
    self.model_creator = model_creator
    self.data_creator = data_creator
    self.config = {} if config is None else config
    self.epoch = 0
    self.verbose = verbose
Example #5
Source File: datasets.py From deepchem with MIT License | 6 votes |
def transform(self, fn, **args):
    """Construct a new dataset by applying a transformation to every sample in this dataset.

    The argument is a function that can be called as follows:
    >> newx, newy, neww = fn(x, y, w)

    It might be called only once with the whole dataset, or multiple times
    with different subsets of the data. Each time it is called, it should
    transform the samples and return the transformed data.

    Parameters
    ----------
    fn: function
        A function to apply to each sample in the dataset

    Returns
    -------
    a newly constructed Dataset object
    """
    newx, newy, neww = fn(self.X, self.y, self.w)
    return NumpyDataset(newx, newy, neww, self.ids[:])
Example #6
Source File: computations_test.py From federated with Apache License 2.0 | 6 votes |
def test_with_tf_datasets(self):

    @computations.tf_computation(computation_types.SequenceType(tf.int64))
    def consume(ds):
        return ds.reduce(np.int64(0), lambda x, y: x + y)

    self.assertEqual(str(consume.type_signature), '(int64* -> int64)')

    @computations.tf_computation
    def produce():
        return tf.data.Dataset.range(10)

    self.assertEqual(str(produce.type_signature), '( -> int64*)')
    self.assertEqual(consume(produce()), 45)

# TODO(b/131363314): The reference executor should support generating and
# returning infinite datasets
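
For comparison, the same reduction can be written directly against tf.data outside of TFF; this is just a sketch assuming TF 2.x eager execution:

import numpy as np
import tensorflow as tf

# Sum the elements 0..9 of a dataset with Dataset.reduce.
total = tf.data.Dataset.range(10).reduce(np.int64(0), lambda x, y: x + y)
print(int(total))  # 45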
Example #7
Source File: inputs.py From BERT with Apache License 2.0 | 6 votes |
def dataset_to_stream(dataset, input_name, n_chunks=0, append_targets=False):
    """Takes a tf.Dataset and creates a numpy stream of ready batches."""
    for example in backend.dataset_as_numpy(dataset):
        inp, out = example[0][input_name], example[1]
        # Some accelerators don't handle uint8 well, cast to int.
        if isinstance(inp, np.uint8):
            inp = inp.astype(np.int32)
        if isinstance(out, np.uint8):
            out = out.astype(np.int32)
        if len(out.shape) > 1 and out.shape[-1] == 1:
            out = np.squeeze(out, axis=-1)
        if n_chunks > 0:
            inp = tuple(np.split(inp, n_chunks, axis=1))
            out = tuple(np.split(out, n_chunks, axis=1))
        if append_targets:
            inp = (inp, out)
        yield inp, out
Example #8
Source File: model.py From graph2gauss with MIT License | 6 votes |
def __dataset_generator(self, hops, scale_terms):
    """
    Generates a set of triplets and associated scaling terms by:
        1. Sampling for each node a set of nodes from each of its neighborhoods
        2. Forming all implied pairwise constraints

    Uses tf.Dataset API to perform the sampling in a separate thread for increased speed.

    Parameters
    ----------
    hops : dict
        A dictionary where each 1, 2, ... K, neighborhoods are saved as sparse matrices
    scale_terms : dict
        The appropriate up-scaling terms to ensure unbiased estimates for each neighbourhood

    Returns
    -------
    """
    def gen():
        while True:
            yield to_triplets(sample_all_hops(hops), scale_terms)

    dataset = tf.data.Dataset.from_generator(gen, (tf.int32, tf.float32),
                                             ([None, 3], [None]))
    self.triplets, self.scale_terms = \
        dataset.prefetch(1).make_one_shot_iterator().get_next()
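
A stripped-down version of the same from_generator pattern, written against the TF 2.4+ API (output_signature instead of the older output_types/output_shapes pair), with a dummy generator standing in for the project's to_triplets sampler:

import numpy as np
import tensorflow as tf

def gen():
    # Dummy sampler: yields (triplets, scale_terms) pairs of varying length.
    while True:
        n = np.random.randint(1, 5)
        yield np.zeros((n, 3), dtype=np.int32), np.ones((n,), dtype=np.float32)

dataset = tf.data.Dataset.from_generator(
    gen,
    output_signature=(
        tf.TensorSpec(shape=(None, 3), dtype=tf.int32),
        tf.TensorSpec(shape=(None,), dtype=tf.float32),
    ),
).prefetch(1)

# In eager mode, pull one sampled (triplets, scale_terms) pair.
triplets, scale_terms = next(iter(dataset))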
Example #9
Source File: data_provider.py From g-tensorflow-models with Apache License 2.0 | 5 votes |
def provide_custom_data(image_file_patterns, batch_size, shuffle=True,
                        num_threads=1, patch_size=128):
    """Provides multiple batches of custom image data.

    Args:
        image_file_patterns: A list of glob patterns of image files.
        batch_size: The number of images in each batch.
        shuffle: Whether to shuffle the read images. Defaults to True.
        num_threads: Number of prefetching threads. Defaults to 1.
        patch_size: Size of the patch to extract from the image. Defaults to 128.

    Returns:
        A list of float `Tensor`s with the same size of `image_file_patterns`.
        Each of the `Tensor` in the list has a shape of
        [batch_size, patch_size, patch_size, 3] representing a batch of images.
        As a side effect, the tf.Dataset initializer is added to the
        tf.GraphKeys.TABLE_INITIALIZERS collection.

    Raises:
        ValueError: If image_file_patterns is not a list or tuple.
    """
    datasets = provide_custom_datasets(
        image_file_patterns, batch_size, shuffle, num_threads, patch_size)

    tensors = []
    for ds in datasets:
        iterator = ds.make_initializable_iterator()
        tf.add_to_collection(tf.GraphKeys.TABLE_INITIALIZERS, iterator.initializer)
        tensors.append(iterator.get_next())

    return tensors
Example #10
Source File: data_pipeline.py From ml-on-gcp with Apache License 2.0 | 5 votes |
def make_input_fn(is_training):
    """Construct training input_fn that uses synthetic data."""

    def input_fn(params):
        """Generated input_fn for the given epoch."""
        batch_size = (params["batch_size"] if is_training
                      else params["eval_batch_size"])
        num_users = params["num_users"]
        num_items = params["num_items"]

        users = tf.random_uniform([batch_size], dtype=tf.int32,
                                  minval=0, maxval=num_users)
        items = tf.random_uniform([batch_size], dtype=tf.int32,
                                  minval=0, maxval=num_items)

        if is_training:
            valid_point_mask = tf.cast(tf.random_uniform(
                [batch_size], dtype=tf.int32, minval=0, maxval=2), tf.bool)
            labels = tf.cast(tf.random_uniform(
                [batch_size], dtype=tf.int32, minval=0, maxval=2), tf.bool)
            data = {
                movielens.USER_COLUMN: users,
                movielens.ITEM_COLUMN: items,
                rconst.VALID_POINT_MASK: valid_point_mask,
            }, labels
        else:
            dupe_mask = tf.cast(tf.random_uniform(
                [batch_size], dtype=tf.int32, minval=0, maxval=2), tf.bool)
            data = {
                movielens.USER_COLUMN: users,
                movielens.ITEM_COLUMN: items,
                rconst.DUPLICATE_MASK: dupe_mask,
            }

        dataset = tf.data.Dataset.from_tensors(data).repeat(
            rconst.SYNTHETIC_BATCHES_PER_EPOCH * params["batches_per_step"])
        dataset = dataset.prefetch(32)
        return dataset

    return input_fn
Example #11
Source File: datasets.py From spherical-cnn with MIT License | 5 votes |
def from_cached_tfrecords(args):
    """ Use tf.Dataset, but feeding it using a placeholder w/ the whole dataset. """
    # This may seem a bit weird: we take tfrecords but load them into
    # placeholders during training. We found that it loaded faster this way
    # when this was first implemented, although letting tf.Dataset load
    # everything simultaneously is conceptually better.
    res, nch = args.input_res, args.nchannels
    x = tf.placeholder(args.dtype, (None, res, res, nch))
    y = tf.placeholder('int64', (None))
    dataset = tf.contrib.data.Dataset.from_tensor_slices((x, y))

    # Inputs are complex numbers: magnitude is ray length, phase is the angle
    # between ray and normal. We found that it is best to treat them
    # independently, though.
    dataset = dataset.map(
        lambda x, y: (tf.concat([tf.abs(x),
                                 tf.imag(x / (tf.cast(tf.abs(x), 'complex64') + 1e-8))],
                                axis=-1), y))

    # We use the same batch size for train/val/test.
    dataset = dataset.batch(args.train_bsize)
    iterator = dataset.make_initializable_iterator()

    fnames = {}
    for t in ['train', 'test', 'val']:
        fnames[t] = glob.glob(args.dset_dir + '/{}*.tfrecord'.format(t))

    out = {'x': x, 'y': y, 'fnames': fnames}
    print('loading dataset; number of tfrecords: {}'
          .format({k: len(v) for k, v in out['fnames'].items()}))

    return iterator, out
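
The initializable iterator returned above is meant to be fed with the in-memory arrays at session time. A usage sketch, assuming TF 1.x graph mode and hypothetical X_train/y_train arrays already decoded from the tfrecords:

# Hypothetical usage of the (iterator, out) pair returned by from_cached_tfrecords.
iterator, out = from_cached_tfrecords(args)
batch_x, batch_y = iterator.get_next()

with tf.Session() as sess:
    # Feed the whole training split into the placeholders, then drain batches.
    sess.run(iterator.initializer,
             feed_dict={out['x']: X_train, out['y']: y_train})
    while True:
        try:
            xb, yb = sess.run([batch_x, batch_y])
        except tf.errors.OutOfRangeError:
            break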
Example #12
Source File: datasets.py From paccmann with MIT License | 5 votes |
def generate_dataset(filepath, buffer_size=int(256e+6), num_parallel_reads=None):
    """
    Generate a tf.Dataset given a path.

    Args:
        - filepath: path to a file or a folder containing data (<string>).
        - buffer_size: size of the buffer in bytes (<int>). Defaults to 256MB.
    Returns:
        A tf.Dataset iterator over file/s in .tfrecords format.
    """
    if os.path.isdir(filepath):
        filenames = get_sorted_filelist(filepath)
    else:
        filenames = [filepath]
    logger.debug(
        'Parsing examples from the following files: {}'.format(filenames)
    )
    return tf.data.TFRecordDataset(
        filenames,
        buffer_size=buffer_size,
        num_parallel_reads=num_parallel_reads
    )
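
A TFRecordDataset only yields serialized protos; downstream code typically maps a parsing function over it. A hedged sketch with a made-up feature spec and path (the real paccmann feature names will differ):

import tensorflow as tf

# Hypothetical feature specification; the actual keys/dtypes depend on how
# the .tfrecords files were written.
feature_spec = {
    'sequence': tf.io.FixedLenFeature([128], tf.int64),
    'label': tf.io.FixedLenFeature([1], tf.float32),
}

def parse_example(serialized):
    return tf.io.parse_single_example(serialized, feature_spec)

dataset = generate_dataset('data/records/')  # path is illustrative
dataset = dataset.map(parse_example).batch(32)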
Example #13
Source File: parse_sdf_utils.py From deep-molecular-massspec with Apache License 2.0 | 5 votes |
def get_dataset_in_one_batch(dataset, total_data_length):
    """Return all data in tf.Dataset in a single batch."""
    # Note that this line may raise some runtime warnings, since in general
    # composing .prefetch() and .cache() this way could be dropping data.
    # However, in our use case it is assumed that all of the data in the
    # dataset is contained in a single batch, so this order of caching and
    # prefetching is acceptable.
    dataset = dataset.prefetch(1).cache().repeat()

    # For downstream usages where we want the entire dataset in one batch we
    # also want the batch shape to be statically inferrable. Below, we set
    # that. Note that the only reason the set_shape command will fail is if
    # the size of the data is not what was provided in
    # data_info['num_examples'].
    def _set_static_batch_dimension(data):

        def _set_static_batch_dimension_for_tensor(tensor):
            shape = tensor.shape.as_list()
            shape[0] = total_data_length
            tensor.set_shape(shape)
            return tensor

        return tf.contrib.framework.nest.map_structure(
            _set_static_batch_dimension_for_tensor, data)

    return dataset.map(_set_static_batch_dimension)
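
One way to use this helper (a sketch only, assuming the TF 1.x environment the original targets, and that the dataset has already been batched into a single batch of known length):

import numpy as np
import tensorflow as tf

num_examples = 1000  # illustrative; must match the true dataset size
features = np.random.rand(num_examples, 16).astype(np.float32)

# Put everything into one batch, then pin the static batch dimension.
dataset = tf.data.Dataset.from_tensor_slices(features).batch(num_examples)
dataset = get_dataset_in_one_batch(dataset, total_data_length=num_examples)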
Example #14
Source File: data_testing_lib.py From task_adaptation with Apache License 2.0 | 5 votes |
def test_base_class(self):
    """Tests that the dataset wrapper inherits from base.ImageData."""
    self.assertIsInstance(self.data_wrapper, base.ImageData,
                          "Dataset class must inherit from `base.ImageData`.")
Example #15
Source File: data_testing_lib.py From task_adaptation with Apache License 2.0 | 5 votes |
def test_dataset_output(self):
    """Tests that the final tf.Dataset object has expected output shapes."""
    batch_size = 2
    for split in ("train", "val", "trainval", "test"):
        tf_data = self.data_wrapper.get_tf_data(split, batch_size)
        self.assertIsInstance(tf_data.output_shapes, dict)
        for tensor_name, expected_shape in self.required_tensors_shapes.items():
            self.assertIn(tensor_name, tf_data.output_shapes.keys())
            expected_shape = [batch_size] + list(expected_shape)
            actual_shape = tf_data.output_shapes[tensor_name].as_list()
            self.assertEqual(
                actual_shape, expected_shape,
                msg=("Tensor {!r} for split {!r} does not match the expected "
                     "value".format(tensor_name, split)))
Example #16
Source File: data_testing_lib.py From task_adaptation with Apache License 2.0 | 5 votes |
def test_base_class(self):
    """Tests that the dataset wrapper inherits from base.ImageData."""
    self.assertIsInstance(self.data_wrapper, base.ImageTfdsData,
                          "Dataset class must inherit from `base.ImageData`.")
Example #17
Source File: data_provider.py From g-tensorflow-models with Apache License 2.0 | 5 votes |
def _provide_custom_dataset(image_file_pattern,
                            batch_size,
                            shuffle=True,
                            num_threads=1,
                            patch_size=128):
    """Provides batches of custom image data.

    Args:
        image_file_pattern: A string of glob pattern of image files.
        batch_size: The number of images in each batch.
        shuffle: Whether to shuffle the read images. Defaults to True.
        num_threads: Number of mapping threads. Defaults to 1.
        patch_size: Size of the patch to extract from the image. Defaults to 128.

    Returns:
        A tf.data.Dataset with Tensors of shape
        [batch_size, patch_size, patch_size, 3] representing a batch of images.

    Raises:
        ValueError: If no files match `image_file_pattern`.
    """
    if not tf.gfile.Glob(image_file_pattern):
        raise ValueError('No file patterns found.')

    filenames_ds = tf.data.Dataset.list_files(image_file_pattern)
    bytes_ds = filenames_ds.map(tf.io.read_file, num_parallel_calls=num_threads)
    images_ds = bytes_ds.map(
        tf.image.decode_image, num_parallel_calls=num_threads)
    patches_ds = images_ds.map(
        lambda img: full_image_to_patch(img, patch_size),
        num_parallel_calls=num_threads)
    patches_ds = patches_ds.repeat()

    if shuffle:
        patches_ds = patches_ds.shuffle(5 * batch_size)

    patches_ds = patches_ds.prefetch(5 * batch_size)
    patches_ds = patches_ds.batch(batch_size)
    return patches_ds
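
In TF 1.x graph mode the returned dataset would typically be consumed through a one-shot iterator; a minimal usage sketch with a hypothetical glob pattern:

import tensorflow as tf

ds = _provide_custom_dataset('images/*.png', batch_size=4)  # pattern is illustrative
images = tf.compat.v1.data.make_one_shot_iterator(ds).get_next()

with tf.compat.v1.Session() as sess:
    batch = sess.run(images)  # shape: [4, 128, 128, 3]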
Example #18
Source File: eval_lib.py From ylg with GNU General Public License v3.0 | 5 votes |
def get_activations_from_dataset(image_ds, num_batches, get_logits=False):
    """Get Inception activations.

    Args:
        image_ds: tf.Dataset for images.
        num_batches: The number of batches to fetch at a time.
        get_logits: If `True`, return (logits, pools). Otherwise just return pools.

    Returns:
        1 or 2 Tensors of Inception activations.
    """
    iterator = tf.compat.v1.data.make_one_shot_iterator(image_ds)
    get_images_fn = iterator.get_next
    return get_activations(get_images_fn, num_batches, get_logits)
Example #19
Source File: data_pipeline.py From g-tensorflow-models with Apache License 2.0 | 5 votes |
def make_input_fn(is_training):
    """Construct training input_fn that uses synthetic data."""

    def input_fn(params):
        """Generated input_fn for the given epoch."""
        batch_size = (params["batch_size"] if is_training
                      else params["eval_batch_size"])
        num_users = params["num_users"]
        num_items = params["num_items"]

        users = tf.random_uniform([batch_size], dtype=tf.int32,
                                  minval=0, maxval=num_users)
        items = tf.random_uniform([batch_size], dtype=tf.int32,
                                  minval=0, maxval=num_items)

        if is_training:
            valid_point_mask = tf.cast(tf.random_uniform(
                [batch_size], dtype=tf.int32, minval=0, maxval=2), tf.bool)
            labels = tf.cast(tf.random_uniform(
                [batch_size], dtype=tf.int32, minval=0, maxval=2), tf.bool)
            data = {
                movielens.USER_COLUMN: users,
                movielens.ITEM_COLUMN: items,
                rconst.VALID_POINT_MASK: valid_point_mask,
            }, labels
        else:
            dupe_mask = tf.cast(tf.random_uniform(
                [batch_size], dtype=tf.int32, minval=0, maxval=2), tf.bool)
            data = {
                movielens.USER_COLUMN: users,
                movielens.ITEM_COLUMN: items,
                rconst.DUPLICATE_MASK: dupe_mask,
            }

        dataset = tf.data.Dataset.from_tensors(data).repeat(
            rconst.SYNTHETIC_BATCHES_PER_EPOCH * params["batches_per_step"])
        dataset = dataset.prefetch(32)
        return dataset

    return input_fn
Example #20
Source File: data_loader.py From Distributed-Tensorflow-Template with MIT License | 5 votes |
def __init__(self, config: dict, mode: str) -> None:
    """
    The Dataset will be dependent on the mode (train, eval etc)

    :param config: global configuration settings
    :param mode: current training mode (train, test, predict)
    """
    self.config = config
    self.mode = mode
Example #21
Source File: data_loader.py From Distributed-Tensorflow-Template with MIT License | 5 votes |
def input_fn(self) -> tf.data.Dataset:
    """
    Create a dataset which reads in some data source (e.g. tfrecords, csv etc)
    """
    raise NotImplementedError
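
A concrete loader would override input_fn to build a real pipeline. A minimal sketch (this toy class does not inherit from the template's base loader, and the arrays and config keys are hypothetical):

import numpy as np
import tensorflow as tf

class InMemoryDataLoader:
    """Toy loader that serves data from in-memory numpy arrays."""

    def __init__(self, config: dict, mode: str) -> None:
        self.config = config
        self.mode = mode
        # Illustrative data; a real loader would read tfrecords/csv here.
        self.features = np.random.rand(256, 10).astype(np.float32)
        self.labels = np.random.randint(0, 2, size=(256,)).astype(np.int64)

    def input_fn(self) -> tf.data.Dataset:
        dataset = tf.data.Dataset.from_tensor_slices((self.features, self.labels))
        if self.mode == "train":
            dataset = dataset.shuffle(256).repeat()
        return dataset.batch(self.config.get("train_batch_size", 32))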
Example #22
Source File: train.py From Distributed-Tensorflow-Template with MIT License | 5 votes |
def run(self) -> None:
    # allow memory usage to be scaled based on usage
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True

    # get number of steps required for one pass of data
    steps_pre_epoch = len(self.train) / self.config["train_batch_size"]

    # save_checkpoints_steps is the number of batches before eval
    run_config = tf.estimator.RunConfig(
        session_config=config,
        save_checkpoints_steps=steps_pre_epoch * 10,  # number of batches before eval/checkpoint
        log_step_count_steps=steps_pre_epoch,  # number of steps in epoch
    )

    # set output directory
    run_config = run_config.replace(model_dir=self.config["job_dir"])

    # initialise the estimator with your model
    estimator = tf.estimator.Estimator(model_fn=self.model.model, config=run_config)

    # create train and eval specs for the estimator; it will automatically
    # convert the tf.Dataset into an input_fn
    train_spec = tf.estimator.TrainSpec(
        lambda: self.train.input_fn(),
        max_steps=self.config["num_epochs"] * steps_pre_epoch,
    )
    eval_spec = tf.estimator.EvalSpec(lambda: self.val.input_fn())

    # initialise a wrapper to do training and evaluation; this also handles
    # exporting checkpoints/tensorboard info
    tf.estimator.train_and_evaluate(estimator, train_spec, eval_spec)

    # after training, export the final model for use in tensorflow serving
    self._export_model(estimator, self.config["export_path"])

    # get results after training and exporting the model
    self._predict(estimator, self.pred.input_fn)
Example #23
Source File: data_loader.py From Distributed-Tensorflow-Template with MIT License | 5 votes |
def input_fn(self) -> tf.data.Dataset:
    """
    Create a tf.Dataset using tfrecords as inputs, using parallel loading and
    augmentation on the CPU to reduce bottlenecking of operations on the GPU.

    :return: a Dataset function
    """
    dataset = tf.data.TFRecordDataset(self.file_names)

    # create a parallel parsing function based on the number of cpu cores
    dataset = dataset.map(
        map_func=self._parse_example,
        num_parallel_calls=multiprocessing.cpu_count()
    )

    # only shuffle training data
    if self.mode == "train":
        # shuffle and repeat the Dataset, returning a new permutation for each
        # epoch, with serialised compatibility
        dataset = dataset.apply(
            tf.contrib.data.shuffle_and_repeat(
                buffer_size=len(self) // self.config["train_batch_size"]
            )
        )
    else:
        dataset = dataset.repeat(self.config["num_epochs"])

    # create batches of data
    dataset = dataset.batch(batch_size=self.batch_size)
    return dataset
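
The _parse_example callback referenced above is not shown in this snippet; it would usually wrap tf.io.parse_single_example with the project's feature spec. A hypothetical sketch (keys, shapes, and image format are made up):

def _parse_example(self, serialized_example):
    # Hypothetical feature layout; the real spec depends on how the
    # tfrecords were written.
    features = {
        "image": tf.io.FixedLenFeature([], tf.string),
        "label": tf.io.FixedLenFeature([], tf.int64),
    }
    parsed = tf.io.parse_single_example(serialized_example, features)
    image = tf.io.decode_jpeg(parsed["image"], channels=3)
    image = tf.image.convert_image_dtype(image, tf.float32)
    return image, parsed["label"]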
Example #24
Source File: data_provider.py From multilabel-image-classification-tensorflow with MIT License | 5 votes |
def _provide_custom_dataset(image_file_pattern,
                            batch_size,
                            shuffle=True,
                            num_threads=1,
                            patch_size=128):
    """Provides batches of custom image data.

    Args:
        image_file_pattern: A string of glob pattern of image files.
        batch_size: The number of images in each batch.
        shuffle: Whether to shuffle the read images. Defaults to True.
        num_threads: Number of mapping threads. Defaults to 1.
        patch_size: Size of the patch to extract from the image. Defaults to 128.

    Returns:
        A tf.data.Dataset with Tensors of shape
        [batch_size, patch_size, patch_size, 3] representing a batch of images.

    Raises:
        ValueError: If no files match `image_file_pattern`.
    """
    if not tf.gfile.Glob(image_file_pattern):
        raise ValueError('No file patterns found.')

    filenames_ds = tf.data.Dataset.list_files(image_file_pattern)
    bytes_ds = filenames_ds.map(tf.io.read_file, num_parallel_calls=num_threads)
    images_ds = bytes_ds.map(
        tf.image.decode_image, num_parallel_calls=num_threads)
    patches_ds = images_ds.map(
        lambda img: full_image_to_patch(img, patch_size),
        num_parallel_calls=num_threads)
    patches_ds = patches_ds.repeat()

    if shuffle:
        patches_ds = patches_ds.shuffle(5 * batch_size)

    patches_ds = patches_ds.prefetch(5 * batch_size)
    patches_ds = patches_ds.batch(batch_size)
    return patches_ds
Example #25
Source File: data_provider.py From multilabel-image-classification-tensorflow with MIT License | 5 votes |
def provide_custom_data(image_file_patterns, batch_size, shuffle=True,
                        num_threads=1, patch_size=128):
    """Provides multiple batches of custom image data.

    Args:
        image_file_patterns: A list of glob patterns of image files.
        batch_size: The number of images in each batch.
        shuffle: Whether to shuffle the read images. Defaults to True.
        num_threads: Number of prefetching threads. Defaults to 1.
        patch_size: Size of the patch to extract from the image. Defaults to 128.

    Returns:
        A list of float `Tensor`s with the same size of `image_file_patterns`.
        Each of the `Tensor` in the list has a shape of
        [batch_size, patch_size, patch_size, 3] representing a batch of images.
        As a side effect, the tf.Dataset initializer is added to the
        tf.GraphKeys.TABLE_INITIALIZERS collection.

    Raises:
        ValueError: If image_file_patterns is not a list or tuple.
    """
    datasets = provide_custom_datasets(
        image_file_patterns, batch_size, shuffle, num_threads, patch_size)

    tensors = []
    for ds in datasets:
        iterator = ds.make_initializable_iterator()
        tf.add_to_collection(tf.GraphKeys.TABLE_INITIALIZERS, iterator.initializer)
        tensors.append(iterator.get_next())

    return tensors
Example #26
Source File: datasets.py From deepchem with MIT License | 5 votes |
def to_dataframe(self):
    """Construct a pandas DataFrame containing the data from this Dataset.

    Returns
    -------
    pandas dataframe. If there is only a single feature per datapoint, will
    have column "X" else will have columns "X1,X2,..." for features. If there
    is only a single label per datapoint, will have column "y" else will have
    columns "y1,y2,..." for labels. If there is only a single weight per
    datapoint will have column "w" else will have columns "w1,w2,...". Will
    have column "ids" for identifiers.
    """
    X = self.X
    y = self.y
    w = self.w
    ids = self.ids
    if len(X.shape) == 1 or X.shape[1] == 1:
        columns = ['X']
    else:
        columns = [f'X{i+1}' for i in range(X.shape[1])]
    X_df = pd.DataFrame(X, columns=columns)
    if len(y.shape) == 1 or y.shape[1] == 1:
        columns = ['y']
    else:
        columns = [f'y{i+1}' for i in range(y.shape[1])]
    y_df = pd.DataFrame(y, columns=columns)
    if len(w.shape) == 1 or w.shape[1] == 1:
        columns = ['w']
    else:
        columns = [f'w{i+1}' for i in range(w.shape[1])]
    w_df = pd.DataFrame(w, columns=columns)
    ids_df = pd.DataFrame(ids, columns=['ids'])
    return pd.concat([X_df, y_df, w_df, ids_df], axis=1, sort=False)
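
Usage is straightforward once a dataset object exists; for instance, with deepchem's NumpyDataset (array contents here are illustrative):

import numpy as np
import deepchem as dc

X = np.random.rand(4, 3)
y = np.random.rand(4, 1)
ds = dc.data.NumpyDataset(X, y)

df = ds.to_dataframe()
print(df.columns)  # e.g. X1, X2, X3, y, w, ids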
Example #27
Source File: datasets.py From deepchem with MIT License | 5 votes |
def transform(self, fn, **args):
    """Construct a new dataset by applying a transformation to every sample in this dataset.

    The argument is a function that can be called as follows:
    >> newx, newy, neww = fn(x, y, w)

    It might be called only once with the whole dataset, or multiple times
    with different subsets of the data. Each time it is called, it should
    transform the samples and return the transformed data.

    Parameters
    ----------
    fn: function
        A function to apply to each sample in the dataset
    out_dir: string
        The directory to save the new dataset in. If this is omitted, a
        temporary directory is created automatically

    Returns
    -------
    a newly constructed Dataset object
    """
    if 'out_dir' in args:
        out_dir = args['out_dir']
    else:
        out_dir = tempfile.mkdtemp()
    tasks = self.get_task_names()

    def generator():
        for shard_num, row in self.metadata_df.iterrows():
            X, y, w, ids = self.get_shard(shard_num)
            newx, newy, neww = fn(X, y, w)
            yield (newx, newy, neww, ids)

    return DiskDataset.create_dataset(
        generator(), data_dir=out_dir, tasks=tasks)
Example #28
Source File: computations_test.py From federated with Apache License 2.0 | 5 votes |
def test_with_sequence_of_pairs(self):
    pairs = tf.data.Dataset.from_tensor_slices(
        (list(range(5)), list(range(5, 10))))

    @computations.tf_computation
    def process_pairs(ds):
        return ds.reduce(0, lambda state, pair: state + pair[0] + pair[1])

    self.assertEqual(process_pairs(pairs), 45)
Example #29
Source File: computations_test.py From federated with Apache License 2.0 | 5 votes |
def test_produce_and_consume_infinite_tf_dataset(self):

    @computations.tf_computation(computation_types.SequenceType(tf.int64))
    def consume(ds):
        # Consume the first 10 elements of the dataset.
        return ds.take(10).reduce(np.int64(0), lambda x, y: x + y)

    @computations.tf_computation
    def produce():
        # Produce an infinite dataset.
        return tf.data.Dataset.range(10).repeat()

    self.assertEqual(consume(produce()), 45)
Example #30
Source File: computations_test.py From federated with Apache License 2.0 | 5 votes |
def test_consume_infinite_tf_dataset(self):

    @computations.tf_computation(computation_types.SequenceType(tf.int64))
    def consume(ds):
        # Consume the first 10 elements of the dataset.
        return ds.take(10).reduce(np.int64(0), lambda x, y: x + y)

    self.assertEqual(consume(tf.data.Dataset.range(10).repeat()), 45)

# TODO(b/131363314): The reference executor should support generating and
# returning infinite datasets