Python tensorflow.compat.v2.gather() Examples
The following are 21 code examples of `tensorflow.compat.v2.gather()` (Example #17 only mentions gathering in a comment; it calls `tf.stack` instead).
You can vote up the ones you like or vote down the ones you don't like, and go to the original project or source file by following the links above each example.
You may also want to check out all available functions/classes of the module `tensorflow.compat.v2`, or try the search function.
Example #1
Source File: array_ops.py From trax with Apache License 2.0 | 6 votes |
def take(a, indices, axis=None, out=None, mode='clip'):
  """Takes elements of `a` at positions `indices` along `axis`.

  NumPy-style `take`. The `out` argument is not supported, and the default
  mode is 'clip'.

  Args:
    a: array-like input.
    indices: array-like of integer positions.
    axis: axis to take along. If None, `a` is flattened and axis 0 is used.
    out: must be None.
    mode: 'clip' clamps indices into `[0, axis_size - 1]`; 'wrap' reduces them
      modulo the axis size. 'raise' passes the mode check but is not
      supported.

  Returns:
    The gathered values, converted with `utils.tensor_to_ndarray`.

  Raises:
    ValueError: if `out` is given, if `mode` is not one of
      'raise'/'clip'/'wrap', or if `mode == 'raise'`.
  """
  if out is not None:
    raise ValueError('out argument is not supported in take.')
  if mode not in {'raise', 'clip', 'wrap'}:
    raise ValueError("Invalid mode '{}' for take".format(mode))

  data = asarray(a).data
  idx = asarray(indices).data
  if axis is None:
    # axis=None means "operate on the flattened array".
    data, axis = tf.reshape(data, [-1]), 0

  # Axis length in the indices' dtype so the clip/mod arithmetic type-checks.
  size = tf.shape(data, idx.dtype)[axis]
  if mode == 'wrap':
    idx = tf.math.floormod(idx, size)
  elif mode == 'clip':
    idx = tf.clip_by_value(idx, 0, size - 1)
  else:
    raise ValueError("The 'raise' mode to take is not supported.")
  return utils.tensor_to_ndarray(tf.gather(data, idx, axis=axis))
Example #2
Source File: vector_hull_white.py From tf-quant-finance with Apache License 2.0 | 6 votes |
def _compute_yt(self, t, mr_t, sigma_t):
  """Computes y(t) as described in [1], section 10.1.6.1.

  The value at each requested time is assembled from two parts:
  - the accumulated value at the last knot before t, taken from
    `y_at_vol_knots`;
  - the integral from that knot up to t, computed by `self._y_integral`.

  The sum is then scaled by `exp(-2 * mr_t * t)`.

  Args:
    t: times at which to evaluate y. It is replicated `self._dim` times along
      a new leading axis before use.
    mr_t: presumably the mean-reversion rate evaluated at `t` (TODO confirm).
    sigma_t: presumably the volatility evaluated at `t` (TODO confirm).

  Returns:
    `exp(-2 * mr_t * t) * y_t`, where `t` is the replicated times tensor.
  """
  # One copy of the time grid per model dimension, stacked on axis 0.
  t = tf.repeat(tf.expand_dims(t, axis=0), self._dim, axis=0)
  # For every time, the index of its interval among the jump locations.
  # NOTE(review): assumes self._jump_locations is [dim, num_jumps] and sorted
  # per row.
  time_index = tf.searchsorted(self._jump_locations, t)
  # Per-interval contributions between consecutive knots.
  y_between_vol_knots = self._y_integral(
      self._padded_knots, self._jump_locations, self._jump_values_vol,
      self._jump_values_mr)
  # Accumulated values at the knots, with `self._zero_padding` prepended so
  # that index 0 corresponds to times before the first jump.
  # NOTE(review): _cumsum_using_matvec is presumably a cumulative sum along
  # the last axis — confirm.
  y_at_vol_knots = tf.concat(
      [self._zero_padding,
       _cumsum_using_matvec(y_between_vol_knots)], axis=1)
  # Interval start points: the padding value followed by the jump locations.
  vn = tf.concat(
      [self._zero_padding, self._jump_locations], axis=1)
  # Integral from the start of each time's interval up to that time.
  y_t = self._y_integral(
      tf.gather(vn, time_index, batch_dims=1), t, sigma_t, mr_t)
  # Add the value accumulated up to the start of that interval.
  y_t = y_t + tf.gather(y_at_vol_knots, time_index, batch_dims=1)
  return tf.math.exp(-2 * mr_t * t) * y_t
Example #3
Source File: vector_hull_white.py From tf-quant-finance with Apache License 2.0 | 6 votes |
def _conditional_variance_x(self, t, mr_t, sigma_t):
  """Computes the variance of x(t), see [1], Eq. 10.41.

  An accumulated variance integral V is evaluated at every requested time.
  Consecutive values along axis 1 are then differenced.

  Args:
    t: times at which to evaluate. It is replicated `self._dim` times along a
      new leading axis before use.
    mr_t: presumably the mean-reversion rate at `t`; broadcast to the shape of
      the replicated `t` (TODO confirm semantics).
    sigma_t: presumably the volatility at `t` (TODO confirm).

  Returns:
    A tensor whose column i is
    `(V(t_{i+1}) - V(t_i)) * exp(-2 * mr(t_{i+1}) * t_{i+1})`.
    It has one fewer entry along axis 1 than the replicated `t`.
  """
  # One copy of the time grid per model dimension, stacked on axis 0.
  t = tf.repeat(tf.expand_dims(t, axis=0), self._dim, axis=0)
  # Variance contributions between consecutive knots.
  var_x_between_vol_knots = self._variance_int(self._padded_knots,
                                               self._jump_locations,
                                               self._jump_values_vol,
                                               self._jump_values_mr)
  # Accumulated values at the knots, with `self._zero_padding` prepended so
  # that index 0 corresponds to times before the first jump.
  # NOTE(review): _cumsum_using_matvec is presumably a cumulative sum along
  # the last axis — confirm.
  varx_at_vol_knots = tf.concat(
      [self._zero_padding,
       _cumsum_using_matvec(var_x_between_vol_knots)],
      axis=1)

  # For every time, the index of its interval among the jump locations.
  time_index = tf.searchsorted(self._jump_locations, t)
  # Interval start points: the padding value followed by the jump locations.
  vn = tf.concat(
      [self._zero_padding,
       self._jump_locations], axis=1)

  # Integral from the start of each time's interval up to that time, plus the
  # value accumulated up to the start of that interval.
  var_x_t = self._variance_int(
      tf.gather(vn, time_index, batch_dims=1), t, sigma_t, mr_t)
  var_x_t = var_x_t + tf.gather(varx_at_vol_knots, time_index, batch_dims=1)

  # Difference consecutive times and scale by the exponential factor taken at
  # the later time of each pair.
  var_x_t = (var_x_t[:, 1:] - var_x_t[:, :-1]) * tf.math.exp(
      -2 * tf.broadcast_to(mr_t, t.shape)[:, 1:] * t[:, 1:])
  return var_x_t
Example #4
Source File: vector_hull_white.py From tf-quant-finance with Apache License 2.0 | 5 votes |
def _conditional_mean_x(self, t, mr_t, sigma_t):
  """Computes the drift term in [1], Eq. 10.39.

  This follows the same knot-accumulation scheme as y(t), applied to the
  `_ex_integral` terms. Each evaluation of `_ex_integral` also receives the
  accumulated y value at the start of its interval.

  Args:
    t: times at which to evaluate. It is replicated `self._dim` times along a
      new leading axis before use.
    mr_t: presumably the mean-reversion rate at `t`; broadcast to the shape of
      the replicated `t` (TODO confirm semantics).
    sigma_t: presumably the volatility at `t` (TODO confirm).

  Returns:
    A tensor whose column i is `(E(t_{i+1}) - E(t_i)) * exp(-mr(t_{i+1}) *
    t_{i+1})`, where E is the accumulated integral. It has one fewer entry
    along axis 1 than the replicated `t`.
  """
  # One copy of the time grid per model dimension, stacked on axis 0.
  t = tf.repeat(tf.expand_dims(t, axis=0), self._dim, axis=0)
  # For every time, the index of its interval among the jump locations.
  time_index = tf.searchsorted(self._jump_locations, t)
  # Interval start points: the padding value followed by the jump locations.
  vn = tf.concat([self._zero_padding, self._jump_locations], axis=1)
  # y contributions between knots and their accumulated values at the knots.
  y_between_vol_knots = self._y_integral(self._padded_knots,
                                         self._jump_locations,
                                         self._jump_values_vol,
                                         self._jump_values_mr)

  y_at_vol_knots = tf.concat(
      [self._zero_padding,
       _cumsum_using_matvec(y_between_vol_knots)], axis=1)

  # Drift contributions between knots. Each interval receives the y value
  # accumulated at its start; the last knot's value is dropped.
  ex_between_vol_knots = self._ex_integral(self._padded_knots,
                                           self._jump_locations,
                                           self._jump_values_vol,
                                           self._jump_values_mr,
                                           y_at_vol_knots[:, :-1])

  ex_at_vol_knots = tf.concat(
      [self._zero_padding,
       _cumsum_using_matvec(ex_between_vol_knots)], axis=1)

  # y accumulated up to the start of each time's interval.
  c = tf.gather(y_at_vol_knots, time_index, batch_dims=1)
  # Integral from the interval start to t, plus the accumulated value there.
  exp_x_t = self._ex_integral(
      tf.gather(vn, time_index, batch_dims=1), t, sigma_t, mr_t, c)
  exp_x_t = exp_x_t + tf.gather(ex_at_vol_knots, time_index, batch_dims=1)
  # Difference consecutive times and scale by exp(-mr * t) at the later time.
  exp_x_t = (exp_x_t[:, 1:] - exp_x_t[:, :-1]) * tf.math.exp(
      -tf.broadcast_to(mr_t, t.shape)[:, 1:] * t[:, 1:])
  return exp_x_t
Example #5
Source File: feature_column_v2_test.py From hub with Apache License 2.0 | 5 votes |
def __call__(self, text_tensor):
  """Embeds `text_tensor` by table lookup followed by a row gather.

  Args:
    text_tensor: tensor of keys to look up in `self.table`.

  Returns:
    The rows of `self.weights` selected by the looked-up ids. They are wrapped
    as `{'outputs': ...}` when `self._returns_dict` is truthy, otherwise
    returned bare.
  """
  row_ids = self.table.lookup(text_tensor)
  embeddings = tf.gather(self.weights, row_ids)
  if self._returns_dict:
    return {'outputs': embeddings}
  return embeddings
Example #6
Source File: uniform_noise.py From compression with Apache License 2.0 | 5 votes |
def _quantization_offset(self):
  """Returns the quantization offset of the "peakiest" mixture component.

  For each batch element, `self.log_prob` is evaluated at every component's
  quantization offset. The offset whose log-probability is largest is kept.
  """
  # NOTE(review): assumes the offsets have rank batch_rank + 1 with the
  # component axis last (implied by the transpose and gather below).
  component_offsets = helpers.quantization_offset(self.components_distribution)
  batch_rank = self.batch_shape.rank
  # Bring the component axis to the front so each leading slice is one
  # candidate offset per batch element.
  candidates = tf.transpose(component_offsets,
                            [batch_rank, *range(batch_rank)])
  best = tf.argmax(self.log_prob(candidates), axis=0)
  return tf.gather(component_offsets, best, axis=-1, batch_dims=batch_rank)
Example #7
Source File: date_tensor.py From tf-quant-finance with Apache License 2.0 | 5 votes |
def _num_days_in_month(month, year):
  """Returns the number of days in `month` of `year`, element-wise."""
  table = tf.constant(_DAYS_IN_MONTHS_COMBINED, tf.int32)
  # Leap years index 12 entries further into the combined table.
  leap_shift = 12 * tf.dtypes.cast(date_utils.is_leap_year(year), tf.int32)
  return tf.gather(table, month + leap_shift)
Example #8
Source File: date_tensor.py From tf-quant-finance with Apache License 2.0 | 5 votes |
def day_of_year(self):
  """Calculates the number of days since the beginning of the year.

  The result is computed once and cached in `self._day_of_year`.

  Returns:
    Tensor of int32 type with elements in range [1, 366]. January 1st yields
    "1".

  #### Example

  ```python
  dt = tff.datetime.dates_from_tuples([(2019, 1, 25), (2020, 3, 2)])
  dt.day_of_year()  # [25, 62]
  ```
  """
  if self._day_of_year is not None:
    return self._day_of_year

  # Days elapsed before the 1st of each month (exclusive prefix sums).
  before_month_regular = tf.math.cumsum(
      _DAYS_IN_MONTHS_NON_LEAP, exclusive=True)
  before_month_leap = tf.math.cumsum(_DAYS_IN_MONTHS_LEAP, exclusive=True)
  month_index = self.month() - 1  # Months are 1-based.
  before_month = tf.where(
      date_utils.is_leap_year(self.year()),
      tf.gather(before_month_leap, month_index),
      tf.gather(before_month_regular, month_index))
  self._day_of_year = before_month + self.day()
  return self._day_of_year
Example #9
Source File: bounded_holiday_calendar.py From tf-quant-finance with Apache License 2.0 | 5 votes |
def _gather(self, table, indices):
  """Returns `tf.gather(table, indices)` after bounds-checking `indices`.

  Valid indices lie in `[0, self._calendar_size + 2)`. Violations trigger
  `tf.debugging` assertions carrying `_OUT_OF_BOUNDS_MSG`.
  """
  upper_bound = self._calendar_size + 2
  checks = [
      tf.debugging.assert_greater_equal(
          indices, 0, message=_OUT_OF_BOUNDS_MSG),
      tf.debugging.assert_less(
          indices, upper_bound, message=_OUT_OF_BOUNDS_MSG),
  ]
  with tf.control_dependencies(checks):
    return tf.gather(table, indices)
Example #10
Source File: bounded_holiday_calendar.py From tf-quant-finance with Apache License 2.0 | 5 votes |
def _compute_is_bus_day_table(self):
  """Computes and caches the "is business day" table.

  Returns:
    A 1-D tensor of `1 - is_holiday` over the calendar days, with one extra 1
    prepended and one appended. The result is stored in
    `self._table_cache.is_bus_day` and returned directly on later calls.
  """
  cached = self._table_cache.is_bus_day
  if cached is not None:
    return cached

  with tf.init_scope():
    start = self._ordinal_offset
    ordinals = tf.range(start, start + self._calendar_size)
    # Weekend days: look each ordinal's weekday up in the weekend mask.
    is_holiday = tf.gather(self._weekend_mask, (ordinals - 1) % 7)

    # Explicit holidays: OR a 1 into each holiday's position.
    if self._holidays is not None:
      positions = self._holidays.ordinal() - start
      holiday_flags = tf.scatter_nd(
          tf.expand_dims(positions, axis=-1), tf.ones_like(positions),
          is_holiday.shape)
      is_holiday = tf.bitwise.bitwise_or(is_holiday, holiday_flags)

    # Pad with a business day on each side (31 Dec of start_year-1 and 1 Jan
    # of end_year+1) so the boundaries need no special casing. Without them,
    # the Following and Preceding conventions would need an "unknown" table
    # value, and conventions combining those tables would have to handle it
    # too. Every computation is correct with these fake business days unless
    # a result lands on one of them, which is why the API calls assert before
    # returning.
    table = tf.concat([[1], 1 - is_holiday, [1]], axis=0)
    self._table_cache.is_bus_day = table
  return table
Example #11
Source File: vector_hull_white.py From tf-quant-finance with Apache License 2.0 | 5 votes |
def _prepare_grid(times, *params):
  """Prepares grid of times for path generation.

  Args:
    times: Rank 1 `Tensor` of increasing positive real values. The times at
      which the path points are to be evaluated.
    *params: Parameters of the model. Objects with a truthy
      `is_piecewise_constant` attribute contribute their flattened
      `jump_locations()` to the grid. Anything else (e.g. scalar `Tensor`s) is
      ignored.

  Returns:
    Tuple `(all_times, mask)`.
    `all_times` is a 1-D real `Tensor` containing `0`, all points from
    `times`, and the jump locations of the piecewise constant parameters. The
    `Tensor` is sorted in ascending order and may contain duplicates.
    `mask` is a boolean 1-D `Tensor` of the same shape as `all_times`, showing
    which elements of `all_times` correspond to the values from `times`.
    Guarantees that all_times[0]=0 and mask[0]=False, provided no jump
    location is negative.
  """
  additional_times = []
  for param in params:
    if (hasattr(param, 'is_piecewise_constant')
        and param.is_piecewise_constant):
      # Flatten all jump locations
      additional_times.append(tf.reshape(param.jump_locations(), [-1]))
  zeros = tf.constant([0], dtype=times.dtype)
  all_times = tf.concat([zeros, times] + additional_times, axis=0)
  # Jump locations are not requested times, so their mask entries are False.
  # Use a distinct loop name so the `times` argument is not shadowed.
  additional_times_mask = [
      tf.zeros_like(jumps, dtype=tf.bool) for jumps in additional_times]
  mask = tf.concat([
      tf.cast(zeros, dtype=tf.bool),
      tf.ones_like(times, dtype=tf.bool)
  ] + additional_times_mask, axis=0)
  # A stable sort keeps the prepended zero ahead of any other zero value.
  perm = tf.argsort(all_times, stable=True)
  all_times = tf.gather(all_times, perm)
  mask = tf.gather(mask, perm)
  return all_times, mask
Example #12
Source File: heston_model.py From tf-quant-finance with Apache License 2.0 | 5 votes |
def _prepare_grid(times, time_step, dtype, *params):
  """Prepares grid of times for path generation.

  Args:
    times: Rank 1 `Tensor` of increasing positive real values. The times at
      which the path points are to be evaluated.
    time_step: Rank 0 real `Tensor`. Maximal distance between points in
      resulting grid.
    dtype: `tf.Dtype` of the input and output `Tensor`s.
    *params: Parameters of the Heston model. Either scalar `Tensor`s of the
      same `dtype` or instances of `PiecewiseConstantFunc`.

  Returns:
    Tuple `(all_times, mask)`.
    `all_times` is a 1-D real `Tensor` containing all points from `times`, the
    uniform grid of points between `[0, times[-1]]` with grid size equal to
    `time_step`, and jump locations of piecewise constant parameters.
    The `Tensor` is sorted in ascending order and may contain duplicates.
    `mask` is a boolean 1-D `Tensor` of the same shape as `all_times`, showing
    which elements of `all_times` correspond to the values from `times`.
    Guarantees that all_times[0]=0 and mask[0]=False.
  """
  grid = tf.range(0.0, times[-1], time_step, dtype=dtype)
  additional_times = []
  for param in params:
    if isinstance(param, piecewise.PiecewiseConstantFunc):
      additional_times.append(param.jump_locations())
  all_times = tf.concat([grid, times] + additional_times, axis=0)
  # Jump locations are not requested times, so their mask entries are False.
  # Use a distinct loop name so the `times` argument is not shadowed.
  additional_times_mask = [
      tf.zeros_like(jumps, dtype=tf.bool) for jumps in additional_times]
  mask = tf.concat([
      tf.zeros_like(grid, dtype=tf.bool),
      tf.ones_like(times, dtype=tf.bool)
  ] + additional_times_mask, axis=0)
  # Stable sort: on ties, the grid point (mask False) stays ahead of the equal
  # requested time because `grid` precedes `times` in the concatenation.
  perm = tf.argsort(all_times, stable=True)
  all_times = tf.gather(all_times, perm)
  mask = tf.gather(mask, perm)
  return all_times, mask
Example #13
Source File: euler_sampling.py From tf-quant-finance with Apache License 2.0 | 5 votes |
def _prepare_grid(*, times, time_step, dtype):
  """Prepares grid of times for path generation.

  Args:
    times: Rank 1 `Tensor` of increasing positive real values. The times at
      which the path points are to be evaluated.
    time_step: Rank 0 real `Tensor`. Maximal distance between points in
      resulting grid.
    dtype: `tf.Dtype` of the input and output `Tensor`s.

  Returns:
    Tuple `(all_times, mask, time_indices)`.
    `all_times` is a 1-D real `Tensor` containing all points from `times` and
    the uniform grid of points between `[0, times[-1]]` with grid size equal to
    `time_step`. The `Tensor` is sorted in ascending order and contains no
    duplicates: a point present both in the grid and in `times` appears once.
    `mask` is a boolean 1-D `Tensor` of the same shape as `all_times`, showing
    which elements of `all_times` correspond to the values from `times`.
    Guarantees that all_times[0]=0 and mask[0]=False.
    `time_indices` is an integer `Tensor` of the same shape as `times` giving
    the position of each element of `times` in `all_times`.
  """
  grid = tf.range(0.0, times[-1], time_step, dtype=dtype)
  all_times = tf.concat([grid, times], axis=0)
  # Sort, then remove duplicate points. `tf.unique` keeps first-occurrence
  # order, so the result stays sorted.
  all_times = tf.unique(tf.sort(all_times)).y
  time_indices = tf.searchsorted(all_times, times)
  # Build the mask on the de-duplicated grid. Gathering a mask built before
  # de-duplication would make it longer than `all_times` whenever a grid point
  # coincides with a requested time.
  mask = tf.scatter_nd(
      tf.expand_dims(time_indices, axis=-1),
      tf.ones_like(time_indices),
      tf.shape(all_times)) > 0
  return all_times, mask, time_indices
Example #14
Source File: ito_process.py From tf-quant-finance with Apache License 2.0 | 5 votes |
def _prepare_grid(self, times, grid_step):
  """Prepares grid of times for path generation.

  Args:
    times: Rank 1 `Tensor` of increasing positive real values. The times at
      which the path points are to be evaluated.
    grid_step: Rank 0 real `Tensor`. Maximal distance between points in
      resulting grid.

  Returns:
    Tuple `(all_times, mask)`.
    `all_times` is a 1-D real `Tensor` containing all points from `times` and
    whose intervals are at most `grid_step`.
    `mask` is a boolean 1-D `Tensor` of the same shape as `all_times`, showing
    which elements of `all_times` correspond to values from `times`.
    Guarantees that all_times[0]=0 and mask[0]=False.
    `all_times` is sorted ascending and may contain duplicates.
  """
  grid = tf.range(0.0, times[-1], grid_step, dtype=self._dtype)
  all_times = tf.concat([grid, times], axis=0)
  mask = tf.concat([
      tf.zeros_like(grid, dtype=tf.bool),
      tf.ones_like(times, dtype=tf.bool)
  ], axis=0)
  # Stable sort: on ties, the grid point (mask False) stays ahead of the equal
  # requested time because `grid` precedes `times` in the concatenation.
  perm = tf.argsort(all_times, stable=True)
  all_times = tf.gather(all_times, perm)
  mask = tf.gather(mask, perm)
  return all_times, mask
Example #15
Source File: stateless.py From tf-quant-finance with Apache License 2.0 | 5 votes |
def stateless_random_shuffle(input_tensor, seed, name=None):
  """Produces stateless random shuffle of the 1st dimension of an input Tensor.

  This is a stateless version of `tf.random_shuffle`: two calls with the same
  seed produce the same permutation.

  Example

  ```python
  identity_shuffle = tf.range(100)
  random_shuffle = stateless_random_shuffle(identity_shuffle, seed=(42, 2))
  ```

  Args:
    input_tensor: float32, float64, int32 or int64 1-D Tensor.
    seed: int32 or int64 Tensor of shape [2].
    name: Python `str` name prefixed to ops created by this function.

  Returns:
    A Tensor of the same shape and dtype as `input_tensor`.
  """
  with tf.compat.v1.name_scope(name,
                               default_name='stateless_random_shuffle',
                               values=[input_tensor, seed]):
    input_tensor = tf.convert_to_tensor(input_tensor, name='input_tensor')
    seed = tf.convert_to_tensor(seed, name='random_seed')
    # Draw one seeded uniform key per row; the keys' sort order is the
    # permutation.
    num_rows = tf.shape(input_tensor)[0]
    keys = tf.random.stateless_uniform(
        shape=[num_rows], seed=seed, dtype=tf.float64)
    permutation = tf.argsort(keys, stable=True, axis=0)
    return tf.gather(input_tensor, permutation)
Example #16
Source File: halton_impl.py From tf-quant-finance with Apache License 2.0 | 5 votes |
def _randomize(coeffs, radixes, seed, perms=None):
  """Applies the Owen (2017) randomization to the coefficients.

  Args:
    coeffs: digit coefficients. Their dtype selects the number of coefficients
      via `_NUM_COEFFS_BY_DTYPE`.
    radixes: the radixes; flattened to 1-D int32.
    seed: seed passed to `_get_permutations` when `perms` is None.
    perms: optional precomputed permutation table, used as given.

  Returns:
    Tuple `(permuted_coeffs, perms)`. `permuted_coeffs` holds the coefficients
    looked up in the permutation table and cast back to the input dtype.
    `perms` is the table used; when it was generated here, it is flattened.
  """
  input_dtype = coeffs.dtype
  digits = tf.cast(coeffs, dtype=tf.int32)
  num_coeffs = _NUM_COEFFS_BY_DTYPE[input_dtype]
  flat_radixes = tf.reshape(tf.cast(radixes, dtype=tf.int32), shape=[-1])
  if perms is None:
    perms = tf.reshape(
        _get_permutations(num_coeffs, flat_radixes, seed), shape=[-1])
  # Offset of each entry in the flat table: the radix's start (an exclusive
  # cumsum of the radixes) plus the coefficient index times the total radix
  # sum. NOTE(review): presumably the table is laid out as one block of
  # `sum(radixes)` entries per coefficient position — confirm against
  # _get_permutations.
  total = tf.reduce_sum(input_tensor=flat_radixes)
  per_radix_start = tf.reshape(
      tf.cumsum(flat_radixes, exclusive=True), shape=[-1, 1])
  offsets = per_radix_start + tf.range(num_coeffs) * total
  permuted = tf.gather(perms, digits + offsets)
  return tf.cast(permuted, dtype=input_dtype), perms
Example #17
Source File: extensions.py From trax with Apache License 2.0 | 5 votes |
def convert_sharded_tensor_to_eager_tensor(value, *args, **kwargs):
  """Stacks the shards of `value` into a single tensor.

  Extra positional and keyword arguments are accepted and ignored.

  Returns:
    `value.tensors` stacked along a new leading axis.
  """
  del args, kwargs  # Unused.
  # TODO(nareshmodi): Consider a collective op to gather the tensors from the
  # various devices for performance reasons.
  shards = value.tensors
  return tf.stack(shards)
Example #18
Source File: extensions.py From trax with Apache License 2.0 | 4 votes |
def sort_key_val(keys, values, dimension=-1):
  """Sorts keys along a dimension and applies same permutation to values.

  Args:
    keys: an array. The dtype must be comparable numbers (integers and reals).
    values: an array, with the same shape of `keys`.
    dimension: an `int`. The dimension along which to sort.

  Returns:
    Permuted keys and values.

  Raises:
    ValueError: if the rank of neither `keys` nor `values` is known.
  """
  keys = tf_np.asarray(keys)
  values = tf_np.asarray(values)
  rank = keys.data.shape.ndims
  if rank is None:
    rank = values.data.shape.ndims
  if rank is None:
    # We need to know the rank because tf.gather requires batch_dims to be `int`
    raise ValueError("The rank of either keys or values must be known, but "
                     "both are unknown (i.e. their shapes are both None).")

  # tf.gather (and tf.gather_nd) support batch_dims on the left but not on the
  # right, so the sort dimension is swapped to the last axis and back again.
  # Swapping is its own inverse, so one helper serves both directions.
  # TODO(wangpeng): Investigate whether we should do swapaxes or moveaxis.
  needs_swap = dimension not in (-1, rank - 1)

  def to_last(a):
    return tf_np.swapaxes(a, dimension, -1) if needs_swap else a

  keys = to_last(keys)
  values = to_last(values)
  order = tf_np.argsort(keys).data

  # Using tf.gather rather than np.take because the former supports batch_dims
  def take_sorted(a):
    return tf_np.asarray(tf.gather(a.data, order, batch_dims=rank - 1))

  sorted_keys = to_last(take_sorted(keys))
  sorted_values = to_last(take_sorted(values))
  return sorted_keys, sorted_values


# Use int64 instead of int32 to avoid TF's "int32 problem"
Example #19
Source File: holiday_utils.py From tf-quant-finance with Apache License 2.0 | 4 votes |
def _week_day_mappers(weekend_mask):
  """Creates functions to map from ordinals to week days and inverse.

  Creates functions to map from ordinal space (i.e. days since 31 Dec 0) to
  week days. The function assigns the value of 0 to the first non weekend day
  in the week running from Sunday, 31 Dec 0 through to Saturday, 6 Jan 1. The
  value assigned to each successive work day is incremented by 1. For a day
  that is not a week day, the count is not incremented from the previous week
  day (hence, multiple ordinal days may have the same week day value).

  Args:
    weekend_mask: A bool `Tensor` of length 7 or None. The weekend mask.

  Returns:
    A tuple of callables.
      `forward`: Takes one `Tensor` argument containing ordinals and returns a
        tuple of two `Tensor`s of the same shape as the input. The first
        `Tensor` is of type `int32` and contains the week day value. The second
        is a bool `Tensor` gathered from the inverted weekend mask, i.e. True
        where the day is a week day. When `weekend_mask` is None, `forward`
        returns the ordinals unchanged and an all-False bool `Tensor`.
      `backward`: Takes one int32 `Tensor` argument containing week day values
        and returns an int32 `Tensor` containing ordinals for those week days.
        When `weekend_mask` is None, it is the identity.

  NOTE(review): the two branches disagree on the meaning of `forward`'s second
  output. With a mask it is True on week days. With `weekend_mask=None`, where
  every day is a week day, it is all False. An earlier docstring described it
  as True on weekend days. Confirm which convention callers rely on before
  changing either branch.
  """
  if weekend_mask is None:
    def default_forward(x):
      return x, tf.zeros_like(x, dtype=tf.bool)

    def identity(x):
      return x

    return default_forward, identity

  weekend_mask = tf.convert_to_tensor(weekend_mask, dtype=tf.bool)
  # Rotate the mask so that it can be indexed by `ordinal % 7` (see
  # `forward`). NOTE(review): assumes `_DAYOFWEEK_0` is the mask position of
  # ordinal 0's weekday — confirm against its definition.
  weekend_mask = tf.roll(weekend_mask, -_DAYOFWEEK_0, axis=0)
  weekday_mask = tf.logical_not(weekend_mask)
  # Running count of week days within the 7-day cycle. The last entry is the
  # number of week days per week.
  weekday_offsets = tf.cumsum(tf.cast(weekday_mask, dtype=tf.int32))
  num_workdays = weekday_offsets[-1]
  weekday_offsets -= 1  # Adjust the first workday to index 0.
  # In-week positions (0..6) of the week days, in order. `backward` uses this
  # to invert `forward`.
  ordinal_offsets = tf.boolean_mask(
      tf.convert_to_tensor([0, 1, 2, 3, 4, 5, 6], dtype=tf.int32),
      weekday_mask)

  def forward(ordinals):
    """Adjusts the ordinals by removing the number of weekend days so far."""
    mod, remainder = ordinals // 7, ordinals % 7
    weekday_values = mod * num_workdays + tf.gather(weekday_offsets, remainder)
    is_weekday = tf.gather(weekday_mask, remainder)
    return weekday_values, is_weekday

  def backward(weekday_values):
    """Converts from weekend adjusted values to ordinals."""
    return ((weekday_values // num_workdays) * 7 +
            tf.gather(ordinal_offsets, weekday_values % num_workdays))

  return forward, backward
Example #20
Source File: date_tensor.py From tf-quant-finance with Apache License 2.0 | 4 votes |
def from_year_month_day(year, month, day, validate=True):
  """Creates DateTensor from tensors of years, months and days.

  Args:
    year: Tensor of int32 type. Elements should be positive.
    month: Tensor of int32 type of same shape as `year`. Elements should be in
      range `[1, 12]`.
    day: Tensor of int32 type of same shape as `year`. Elements should be in
      range `[1, 31]` and represent valid dates together with corresponding
      elements of `month` and `year` Tensors.
    validate: Whether to validate the dates. When True, `tf.debugging`
      assertions check the year, month and day ranges and the day-month
      pairing. Those assertions are attached as control dependencies.

  Returns:
    DateTensor object.

  #### Example

  ```python
  year = tf.constant([2015, 2017], dtype=tf.int32)
  month = tf.constant([4, 12], dtype=tf.int32)
  day = tf.constant([15, 30], dtype=tf.int32)
  date_tensor = tff.datetime.dates_from_year_month_day(year, month, day)
  ```
  """
  year = tf.convert_to_tensor(year, tf.int32)
  month = tf.convert_to_tensor(month, tf.int32)
  day = tf.convert_to_tensor(day, tf.int32)

  control_deps = []
  if validate:
    control_deps.append(
        tf.debugging.assert_positive(year, message="Year must be positive."))
    control_deps.append(
        tf.debugging.assert_greater_equal(
            month,
            constants.Month.JANUARY.value,
            message=f"Month must be >= {constants.Month.JANUARY.value}"))
    # The upper bound is December. This message was previously a plain
    # (non-f) string that also referenced JANUARY.
    control_deps.append(
        tf.debugging.assert_less_equal(
            month,
            constants.Month.DECEMBER.value,
            message=f"Month must be <= {constants.Month.DECEMBER.value}"))
    control_deps.append(
        tf.debugging.assert_positive(day, message="Day must be positive."))
    # Leap years index 12 entries further into the combined table.
    is_leap = date_utils.is_leap_year(year)
    days_in_months = tf.constant(_DAYS_IN_MONTHS_COMBINED, tf.int32)
    max_days = tf.gather(days_in_months,
                         month + 12 * tf.dtypes.cast(is_leap, tf.int32))
    control_deps.append(
        tf.debugging.assert_less_equal(
            day, max_days, message="Invalid day-month pairing."))
    with tf.compat.v1.control_dependencies(control_deps):
      # Ensure years, months, days themselves are under control_deps.
      year = tf.identity(year)
      month = tf.identity(month)
      day = tf.identity(day)

  with tf.compat.v1.control_dependencies(control_deps):
    ordinal = date_utils.year_month_day_to_ordinal(year, month, day)
    return DateTensor(ordinal, year, month, day)
Example #21
Source File: network.py From ranking with Apache License 2.0 | 4 votes |
def compute_logits(self,
                   context_features=None,
                   example_features=None,
                   training=None,
                   mask=None):
  """Scores context and examples to return a score per document.

  Args:
    context_features: (dict) context feature names to 2D tensors of shape
      [batch_size, feature_dims]. None is treated as no context features.
    example_features: (dict) example feature names to 3D tensors of shape
      [batch_size, list_size, feature_dims].
    training: (bool) whether in train or inference mode.
    mask: (tf.Tensor) Mask is a tensor of shape [batch_size, list_size], which
      is True for a valid example and False for invalid one. If mask is None,
      all entries are valid.

  Returns:
    (tf.Tensor) A score tensor of shape [batch_size, list_size]. Entries for
    invalid examples are zero.
  """
  # `context_features` defaults to None; iterating it below would fail, so
  # treat None as an empty feature dict.
  if context_features is None:
    context_features = {}

  # batch_size and list_size are read from any one example feature; per the
  # Args above, they all share the leading [batch_size, list_size] dims.
  first_example = next(six.itervalues(example_features))
  batch_size = tf.shape(first_example)[0]
  list_size = tf.shape(first_example)[1]
  if mask is None:
    mask = tf.ones(shape=[batch_size, list_size], dtype=tf.bool)
  nd_indices, nd_mask = utils.padded_nd_indices(is_valid=mask)

  # Expand query features to be of [batch_size, list_size, ...].
  large_batch_context_features = {}
  for name, context_tensor in six.iteritems(context_features):
    x = tf.expand_dims(input=context_tensor, axis=1)
    # Repeat the single inserted row list_size times along axis 1.
    x = tf.gather(x, tf.zeros([list_size], tf.int32), axis=1)
    large_batch_context_features[name] = utils.reshape_first_ndims(
        x, 2, [batch_size * list_size])

  large_batch_example_features = {}
  for name, example_tensor in six.iteritems(example_features):
    # Replace invalid example features with valid ones.
    padded_tensor = tf.gather_nd(example_tensor, nd_indices)
    large_batch_example_features[name] = utils.reshape_first_ndims(
        padded_tensor, 2, [batch_size * list_size])

  # Get scores for large batch.
  scores = self.score(
      context_features=large_batch_context_features,
      example_features=large_batch_example_features,
      training=training)
  logits = tf.reshape(scores, shape=[batch_size, list_size])

  # Apply nd_mask to zero out invalid entries.
  logits = tf.where(nd_mask, logits, tf.zeros_like(logits))
  return logits