Python tensorflow.compat.v2.gather() Examples

The following are 21 code examples of tensorflow.compat.v2.gather(). You can vote up the ones you like or vote down the ones you don't like, and go to the original project or source file by following the links above each example. You may also want to check out all available functions/classes of the module tensorflow.compat.v2, or try the search function.
Example #1
Source File: array_ops.py    From trax with Apache License 2.0 6 votes vote down vote up
def take(a, indices, axis=None, out=None, mode='clip'):
  """Takes elements of `a` at `indices` along `axis`, numpy-style.

  The `out` argument is not supported, and the default mode is 'clip'.

  Args:
    a: array-like input, converted with `asarray`.
    indices: array-like integer indices, converted with `asarray`.
    axis: axis to take along. If None, `a` is flattened first and axis 0 of
      the flattened array is used.
    out: must be None; any other value raises.
    mode: 'clip' clamps indices into range; 'wrap' reduces them modulo the
      axis size. 'raise' passes the mode check but is then rejected.

  Returns:
    The gathered values, converted by `utils.tensor_to_ndarray`.

  Raises:
    ValueError: if `out` is given, if `mode` is unknown, or if `mode` is
      'raise'.
  """
  if out is not None:
    raise ValueError('out argument is not supported in take.')
  if mode not in {'raise', 'clip', 'wrap'}:
    raise ValueError("Invalid mode '{}' for take".format(mode))

  data = asarray(a).data
  idx = asarray(indices).data

  if axis is None:
    # Without an axis, index into the flattened array.
    data, axis = tf.reshape(data, [-1]), 0

  # Size of the gathered axis, in the same dtype as the indices so the
  # clip/floormod below operate on matching types.
  size = tf.shape(data, idx.dtype)[axis]
  if mode == 'raise':
    raise ValueError("The 'raise' mode to take is not supported.")
  if mode == 'clip':
    idx = tf.clip_by_value(idx, 0, size - 1)
  else:  # mode == 'wrap'
    idx = tf.math.floormod(idx, size)

  gathered = tf.gather(data, idx, axis=axis)
  return utils.tensor_to_ndarray(gathered)
Example #2
Source File: vector_hull_white.py    From tf-quant-finance with Apache License 2.0 6 votes vote down vote up
def _compute_yt(self, t, mr_t, sigma_t):
    """Computes y(t) as described in [1], section 10.1.6.1.

    Args:
      t: times. They are expanded on a new leading axis and repeated
        `self._dim` times, so each dimension gets its own copy.
      mr_t: mean-reversion values used with the repeated `t`.
      sigma_t: volatility values used with the repeated `t`.

    Returns:
      `exp(-2 * mr_t * t) * y(t)` evaluated on the repeated times.
    """
    times = tf.repeat(tf.expand_dims(t, axis=0), self._dim, axis=0)
    # For each time: number of jump locations strictly below it
    # (tf.searchsorted default side='left').
    knot_idx = tf.searchsorted(self._jump_locations, times)

    # y integrated between consecutive knots, then accumulated across knots.
    # `self._zero_padding` is prepended along axis 1 so the accumulated table
    # lines up with the padded knot times below.
    y_per_interval = self._y_integral(
        self._padded_knots, self._jump_locations, self._jump_values_vol,
        self._jump_values_mr)
    y_knots = tf.concat(
        [self._zero_padding, _cumsum_using_matvec(y_per_interval)], axis=1)
    knot_times = tf.concat([self._zero_padding, self._jump_locations], axis=1)

    # Integral from the last knot before each time up to that time, plus the
    # accumulated value at that knot.
    last_knot_before_t = tf.gather(knot_times, knot_idx, batch_dims=1)
    y_since_knot = self._y_integral(last_knot_before_t, times, sigma_t, mr_t)
    y_t = y_since_knot + tf.gather(y_knots, knot_idx, batch_dims=1)
    return tf.math.exp(-2 * mr_t * times) * y_t
Example #3
Source File: vector_hull_white.py    From tf-quant-finance with Apache License 2.0 6 votes vote down vote up
def _conditional_variance_x(self, t, mr_t, sigma_t):
    """Computes the variance of x(t), see [1], Eq. 10.41.

    Args:
      t: times. They are repeated `self._dim` times along a new leading axis.
      mr_t: mean-reversion values; broadcast to the repeated times' shape.
      sigma_t: volatility values used with the repeated times.

    Returns:
      Differences of the accumulated variance integral between consecutive
      times (along axis 1), each multiplied by
      `exp(-2 * mr_t * t)` evaluated at the later time.
    """
    times = tf.repeat(tf.expand_dims(t, axis=0), self._dim, axis=0)

    # Variance integral between consecutive knots, accumulated across knots
    # and prefixed with `self._zero_padding` along axis 1.
    var_per_interval = self._variance_int(self._padded_knots,
                                          self._jump_locations,
                                          self._jump_values_vol,
                                          self._jump_values_mr)
    var_knots = tf.concat(
        [self._zero_padding, _cumsum_using_matvec(var_per_interval)], axis=1)

    # Index of the last knot before each time (searchsorted side='left').
    knot_idx = tf.searchsorted(self._jump_locations, times)
    knot_times = tf.concat([self._zero_padding, self._jump_locations], axis=1)

    last_knot_before_t = tf.gather(knot_times, knot_idx, batch_dims=1)
    cumulative_var = self._variance_int(
        last_knot_before_t, times, sigma_t, mr_t)
    cumulative_var = cumulative_var + tf.gather(
        var_knots, knot_idx, batch_dims=1)

    increments = cumulative_var[:, 1:] - cumulative_var[:, :-1]
    decay = tf.math.exp(
        -2 * tf.broadcast_to(mr_t, times.shape)[:, 1:] * times[:, 1:])
    return increments * decay
Example #4
Source File: vector_hull_white.py    From tf-quant-finance with Apache License 2.0 5 votes vote down vote up
def _conditional_mean_x(self, t, mr_t, sigma_t):
    """Computes the drift term in [1], Eq. 10.39.

    Args:
      t: times. They are repeated `self._dim` times along a new leading axis.
      mr_t: mean-reversion values; broadcast to the repeated times' shape.
      sigma_t: volatility values used with the repeated times.

    Returns:
      Differences of the accumulated drift integral between consecutive
      times (along axis 1), each multiplied by `exp(-mr_t * t)` evaluated at
      the later time.
    """
    times = tf.repeat(tf.expand_dims(t, axis=0), self._dim, axis=0)
    # Index of the last knot before each time (searchsorted side='left').
    knot_idx = tf.searchsorted(self._jump_locations, times)
    knot_times = tf.concat([self._zero_padding, self._jump_locations], axis=1)

    # Accumulated y at every knot, prefixed with `self._zero_padding`.
    y_per_interval = self._y_integral(self._padded_knots,
                                      self._jump_locations,
                                      self._jump_values_vol,
                                      self._jump_values_mr)
    y_knots = tf.concat(
        [self._zero_padding, _cumsum_using_matvec(y_per_interval)], axis=1)

    # Drift integral between knots. It is fed the accumulated y table without
    # its last column, then accumulated the same way as y.
    ex_per_interval = self._ex_integral(self._padded_knots,
                                        self._jump_locations,
                                        self._jump_values_vol,
                                        self._jump_values_mr,
                                        y_knots[:, :-1])
    ex_knots = tf.concat(
        [self._zero_padding, _cumsum_using_matvec(ex_per_interval)], axis=1)

    # Integral from the last knot before each time up to that time, plus the
    # accumulated drift at that knot.
    y_at_last_knot = tf.gather(y_knots, knot_idx, batch_dims=1)
    last_knot_before_t = tf.gather(knot_times, knot_idx, batch_dims=1)
    cumulative_ex = self._ex_integral(
        last_knot_before_t, times, sigma_t, mr_t, y_at_last_knot)
    cumulative_ex = cumulative_ex + tf.gather(ex_knots, knot_idx, batch_dims=1)

    increments = cumulative_ex[:, 1:] - cumulative_ex[:, :-1]
    decay = tf.math.exp(
        -tf.broadcast_to(mr_t, times.shape)[:, 1:] * times[:, 1:])
    return increments * decay
Example #5
Source File: feature_column_v2_test.py    From hub with Apache License 2.0 5 votes vote down vote up
def __call__(self, text_tensor):
    """Embeds `text_tensor` by table lookup followed by a gather.

    Each input is mapped to an index with `self.table.lookup`, and the
    matching slices of `self.weights` along axis 0 are gathered.

    Returns:
      `{'outputs': embeddings}` when `self._returns_dict` is truthy,
      otherwise the embeddings tensor itself.
    """
    ids = self.table.lookup(text_tensor)
    embeddings = tf.gather(self.weights, ids)
    if self._returns_dict:
      return {'outputs': embeddings}
    return embeddings
Example #6
Source File: uniform_noise.py    From compression with Apache License 2.0 5 votes vote down vote up
def _quantization_offset(self):
    """Returns the quantization offset of the "peakiest" component.

    Every component's quantization offset is scored with `self.log_prob`.
    For each batch element, the offset with the highest log-probability is
    selected.
    """
    offsets = helpers.quantization_offset(self.components_distribution)
    rank = self.batch_shape.rank
    # Move axis `rank` to the front so log_prob evaluates every component's
    # offset; presumably that axis indexes mixture components (offsets of
    # shape batch_shape + [num_components]) — TODO confirm.
    components_first = tf.transpose(offsets, [rank, *range(rank)])
    best_component = tf.argmax(self.log_prob(components_first), axis=0)
    return tf.gather(offsets, best_component, axis=-1, batch_dims=rank)
Example #7
Source File: date_tensor.py    From tf-quant-finance with Apache License 2.0 5 votes vote down vote up
def _num_days_in_month(month, year):
  """Returns number of days in a given month of a given year.

  The value is read from `_DAYS_IN_MONTHS_COMBINED` at index `month`, shifted
  by 12 for leap years. This presumes the table stores the non-leap months
  followed by the leap months, aligned with 1-based `month` — TODO confirm.
  """
  table = tf.constant(_DAYS_IN_MONTHS_COMBINED, tf.int32)
  leap = tf.dtypes.cast(date_utils.is_leap_year(year), tf.int32)
  return tf.gather(table, month + 12 * leap)
Example #8
Source File: date_tensor.py    From tf-quant-finance with Apache License 2.0 5 votes vote down vote up
def day_of_year(self):
    """Calculates the number of days since the beginning of the year.

    The result is computed on first use and cached in `self._day_of_year`.

    Returns:
      Tensor of int32 type with elements in range [1, 366]. January 1st yields
      "1".

    #### Example

    ```python
    dt = tff.datetime.dates_from_tuples([(2019, 1, 25), (2020, 3, 2)])
    dt.day_of_year()  # [25, 62]
    ```
    """
    if self._day_of_year is not None:
      return self._day_of_year

    # Exclusive cumulative sums give the number of days before each month.
    before_month_regular = tf.math.cumsum(
        _DAYS_IN_MONTHS_NON_LEAP, exclusive=True)
    before_month_leap = tf.math.cumsum(_DAYS_IN_MONTHS_LEAP, exclusive=True)
    month_idx = self.month() - 1  # months are 1-based, tables are 0-based
    days_before_month = tf.where(
        date_utils.is_leap_year(self.year()),
        tf.gather(before_month_leap, month_idx),
        tf.gather(before_month_regular, month_idx))
    self._day_of_year = days_before_month + self.day()
    return self._day_of_year
Example #9
Source File: bounded_holiday_calendar.py    From tf-quant-finance with Apache License 2.0 5 votes vote down vote up
def _gather(self, table, indices):
    """Returns `tf.gather(table, indices)` after bounds checks.

    Indices must lie in `[0, self._calendar_size + 2)`; otherwise the
    debugging assertions fail with `_OUT_OF_BOUNDS_MSG`.
    """
    upper_bound = self._calendar_size + 2
    bound_checks = [
        tf.debugging.assert_greater_equal(
            indices, 0, message=_OUT_OF_BOUNDS_MSG),
        tf.debugging.assert_less(
            indices, upper_bound, message=_OUT_OF_BOUNDS_MSG),
    ]
    with tf.control_dependencies(bound_checks):
      return tf.gather(table, indices)
Example #10
Source File: bounded_holiday_calendar.py    From tf-quant-finance with Apache License 2.0 5 votes vote down vote up
def _compute_is_bus_day_table(self):
    """Computes and caches "is business day" table.

    The table is built once inside `tf.init_scope()` and stored in
    `self._table_cache.is_bus_day`; later calls return the cached tensor.

    Returns:
      A 1-D tensor of length `self._calendar_size + 2` that holds
      `1 - is_holiday` for each day of the calendar, with an extra 1
      prepended and appended. This assumes `self._weekend_mask` holds 0/1
      integer flags (required by `bitwise_or`) — TODO confirm. Under that
      assumption, business days are 1 and weekend days or holidays are 0.
    """
    if self._table_cache.is_bus_day is not None:
      return self._table_cache.is_bus_day

    with tf.init_scope():
      ordinals = tf.range(self._ordinal_offset,
                          self._ordinal_offset + self._calendar_size)
      # Apply weekend mask: look up each day's flag by (ordinal - 1) % 7.
      week_days = (ordinals - 1) % 7
      is_holiday = tf.gather(self._weekend_mask, week_days)

      # Apply holidays
      if self._holidays is not None:
        indices = self._holidays.ordinal() - self._ordinal_offset
        ones_at_indices = tf.scatter_nd(
            tf.expand_dims(indices, axis=-1), tf.ones_like(indices),
            is_holiday.shape)
        # tf.scatter_nd sums updates at duplicate indices. A holiday listed
        # more than once would otherwise give 2+ here, and after the OR and
        # `1 - is_holiday` below it would become a negative, nonzero entry
        # (i.e. "business day"). Clamp so every day is flagged 0 or 1.
        ones_at_indices = tf.minimum(ones_at_indices, 1)
        is_holiday = tf.bitwise.bitwise_or(is_holiday, ones_at_indices)

      # Add a business day at the beginning and at the end, i.e. at 31 Dec of
      # start_year-1 and at 1 Jan of end_year+1. This trick is to avoid dealing
      # with special cases on boundaries.
      # For example, for Following and Preceding conventions we'd need a special
      # value that means "unknown" in the tables. More complicated conventions
      # then combine the Following and Preceding tables, and would need special
      # treatment of the "unknown" values.
      # With these "fake" business days, all computations are automatically
      # correct, unless we land on those extra days - for this reason we add
      # assertions in all API calls before returning.
      is_bus_day_table = tf.concat([[1], 1 - is_holiday, [1]], axis=0)
      self._table_cache.is_bus_day = is_bus_day_table
    return is_bus_day_table
Example #11
Source File: vector_hull_white.py    From tf-quant-finance with Apache License 2.0 5 votes vote down vote up
def _prepare_grid(times, *params):
  """Prepares grid of times for path generation.

  Args:
    times: Rank 1 `Tensor` of increasing positive real values. The times at
      which the path points are to be evaluated.
    *params: Model parameters. Each is either a scalar `Tensor` of the same
      `dtype` as `times` or an object with an `is_piecewise_constant`
      attribute and a `jump_locations()` method. When
      `is_piecewise_constant` is truthy, the jump locations are added to the
      grid.

  Returns:
    Tuple `(all_times, mask)`.
    `all_times` is a 1-D real `Tensor` containing 0, all points from `times`,
    and the flattened jump locations of every piecewise constant parameter.
    It is sorted in ascending order and may contain duplicates.
    `mask` is a boolean 1-D `Tensor` of the same shape as `all_times`, showing
    which elements of `all_times` correspond to the values from `times`.
    Because the leading 0 is inserted first and the sort is stable,
    all_times[0]=0 and mask[0]=False, provided no jump location is negative.
  """
  jump_times = []
  for param in params:
    if hasattr(param, 'is_piecewise_constant'):
      if param.is_piecewise_constant:
        # Flatten all jump locations
        jump_times.append(tf.reshape(param.jump_locations(), [-1]))
  zeros = tf.constant([0], dtype=times.dtype)
  all_times = tf.concat([zeros, times] + jump_times, axis=0)
  # Only the entries coming from `times` are marked True.
  jump_times_mask = [
      tf.zeros_like(jumps, dtype=tf.bool) for jumps in jump_times]
  mask = tf.concat([
      tf.cast(zeros, dtype=tf.bool),
      tf.ones_like(times, dtype=tf.bool)
  ] + jump_times_mask, axis=0)
  # Stable sort keeps the inserted 0 ahead of any equal values.
  perm = tf.argsort(all_times, stable=True)
  all_times = tf.gather(all_times, perm)
  mask = tf.gather(mask, perm)
  return all_times, mask
Example #12
Source File: heston_model.py    From tf-quant-finance with Apache License 2.0 5 votes vote down vote up
def _prepare_grid(times, time_step, dtype, *params):
  """Prepares grid of times for path generation.

  Args:
    times:  Rank 1 `Tensor` of increasing positive real values. The times at
      which the path points are to be evaluated.
    time_step: Rank 0 real `Tensor`. Spacing of the uniform grid
      `tf.range(0, times[-1], time_step)` that is merged into the result.
    dtype: `tf.Dtype` of the input and output `Tensor`s.
    *params: Parameters of the Heston model. Either scalar `Tensor`s of the
      same `dtype` or instances of `PiecewiseConstantFunc`; the jump
      locations of the latter are merged into the result.

  Returns:
    Tuple `(all_times, mask)`.
    `all_times` is a 1-D real `Tensor` with the uniform grid, every point of
    `times` and every jump location, sorted ascending (duplicates kept).
    `mask` is a boolean 1-D `Tensor` aligned with `all_times` that is True
    exactly at the entries that came from `times`. When times[-1] > 0, the
    grid starts at 0 and is placed first by the stable sort, so
    all_times[0]=0 and mask[0]=False.
  """
  uniform_grid = tf.range(0.0, times[-1], time_step, dtype=dtype)
  jump_times = [
      param.jump_locations()
      for param in params
      if isinstance(param, piecewise.PiecewiseConstantFunc)
  ]
  points = tf.concat([uniform_grid, times, *jump_times], axis=0)
  from_times = tf.concat(
      [tf.zeros_like(uniform_grid, dtype=tf.bool),
       tf.ones_like(times, dtype=tf.bool),
       *[tf.zeros_like(jumps, dtype=tf.bool) for jumps in jump_times]],
      axis=0)
  order = tf.argsort(points, stable=True)
  return tf.gather(points, order), tf.gather(from_times, order)
Example #13
Source File: euler_sampling.py    From tf-quant-finance with Apache License 2.0 5 votes vote down vote up
def _prepare_grid(*, times, time_step, dtype):
  """Prepares grid of times for path generation.

  Args:
    times:  Rank 1 `Tensor` of increasing positive real values. The times at
      which the path points are to be evaluated.
    time_step: Rank 0 real `Tensor`. Maximal distance between points in
      resulting grid.
    dtype: `tf.Dtype` of the input and output `Tensor`s.

  Returns:
    Tuple `(all_times, mask, time_indices)`.
    `all_times` is a 1-D real `Tensor` containing all points from `times` and
    the uniform grid `tf.range(0, times[-1], time_step)`. It is sorted in
    ascending order with duplicate points removed.
    `mask` is a boolean 1-D `Tensor` of the same shape as `all_times`, showing
    which elements of `all_times` correspond to values from `times`. When
    times[-1] > 0, all_times[0]=0 and mask[0]=False.
    `time_indices` is an integer `Tensor` of the same shape as `times` giving
    the index of each element of `times` in `all_times`.
  """
  grid = tf.range(0.0, times[-1], time_step, dtype=dtype)
  all_times = tf.concat([grid, times], axis=0)
  mask = tf.concat([
      tf.zeros_like(grid, dtype=tf.bool),
      tf.ones_like(times, dtype=tf.bool)
  ],
                   axis=0)
  perm = tf.argsort(all_times, stable=True)
  sorted_times = tf.gather(all_times, perm)
  sorted_mask = tf.gather(mask, perm)
  # Remove duplicate points. On sorted input `tf.unique` preserves ascending
  # order, and `segment_ids` is non-decreasing: it maps each sorted point to
  # the index of its unique value.
  all_times, segment_ids = tf.unique(sorted_times)
  # Reduce the mask over the same segments so it stays aligned with the
  # deduplicated `all_times`: a unique point is marked if any copy of it
  # came from `times`.
  mask = tf.cast(
      tf.math.segment_max(tf.cast(sorted_mask, tf.int32), segment_ids),
      tf.bool)
  time_indices = tf.searchsorted(all_times, times)
  return all_times, mask, time_indices
Example #14
Source File: ito_process.py    From tf-quant-finance with Apache License 2.0 5 votes vote down vote up
def _prepare_grid(self, times, grid_step):
    """Prepares grid of times for path generation.

    Args:
      times:  Rank 1 `Tensor` of increasing positive real values. The times at
        which the path points are to be evaluated.
      grid_step: Rank 0 real `Tensor`. Maximal distance between points in
        resulting grid.

    Returns:
      Tuple `(all_times, mask)`.
      `all_times` is 1-D real `Tensor` containing all points from `times` and
      the uniform grid `tf.range(0, times[-1], grid_step)` (in
      `self._dtype`), so consecutive points are at most `grid_step` apart.
      It is sorted ascending and may contain duplicates.
      `mask` is a boolean 1-D tensor of the same shape as `all_times`, showing
      which elements of `all_times` correspond to values from `times`.
      When times[-1] > 0, all_times[0]=0 and mask[0]=False.
    """
    grid = tf.range(0.0, times[-1], grid_step, dtype=self._dtype)
    all_times = tf.concat([grid, times], axis=0)
    mask = tf.concat([
        tf.zeros_like(grid, dtype=tf.bool),
        tf.ones_like(times, dtype=tf.bool)
    ],
                     axis=0)
    # Stable sort keeps grid points ahead of equal values from `times`.
    perm = tf.argsort(all_times, stable=True)
    all_times = tf.gather(all_times, perm)
    mask = tf.gather(mask, perm)
    return all_times, mask
Example #15
Source File: stateless.py    From tf-quant-finance with Apache License 2.0 5 votes vote down vote up
def stateless_random_shuffle(input_tensor, seed, name=None):
  """Produces stateless random shuffle of the 1st dimension of an input Tensor.

  This is a stateless version of `tf.random_shuffle`. If run twice with the same
  seed, produces the same result.

  Example
  ```python
  identity_shuffle = tf.range(100)
  random_shuffle = stateless_random_shuffle(identity_shuffle, seed=(42, 2))
  ```

  Args:
    input_tensor: float32, float64, int32 or int64 1-D Tensor.
    seed: int32 or int64 Tensor of shape [2].
    name: Python `str` name prefixed to ops created by this function.

  Returns:
    A Tensor of the same shape and dtype as `input_tensor`.
  """
  with tf.compat.v1.name_scope(name,
                               default_name='stateless_random_shuffle',
                               values=[input_tensor, seed]):
    input_tensor = tf.convert_to_tensor(input_tensor, name='input_tensor')
    seed = tf.convert_to_tensor(seed, name='random_seed')
    # One seeded float64 uniform per row; argsort of these values gives a
    # permutation determined by `seed`.
    uniforms = tf.random.stateless_uniform(
        shape=[tf.shape(input_tensor)[0]], seed=seed, dtype=tf.float64)
    # Create the permutation and the gather inside the scope so all ops made
    # by this function carry the `name` prefix.
    permutation = tf.argsort(uniforms, stable=True, axis=0)
    return tf.gather(input_tensor, permutation)
Example #16
Source File: halton_impl.py    From tf-quant-finance with Apache License 2.0 5 votes vote down vote up
def _randomize(coeffs, radixes, seed, perms=None):
  """Applies the Owen (2017) randomization to the coefficients.

  Args:
    coeffs: coefficients to permute; cast to int32 for the lookup and back to
      their original dtype on return.
    radixes: radixes, cast to int32 and flattened.
    seed: seed forwarded to `_get_permutations` when `perms` is None.
    perms: optional precomputed permutation table. When None it is generated
      and flattened here; a caller-supplied table is used as-is.

  Returns:
    Tuple `(permuted_coeffs, perms)`: the permuted coefficients in the input
    dtype and the permutation table that was used.
  """
  original_dtype = coeffs.dtype
  int_coeffs = tf.cast(coeffs, dtype=tf.int32)
  num_coeffs = _NUM_COEFFS_BY_DTYPE[original_dtype]
  flat_radixes = tf.reshape(tf.cast(radixes, dtype=tf.int32), shape=[-1])
  if perms is None:
    perms = tf.reshape(
        _get_permutations(num_coeffs, flat_radixes, seed), shape=[-1])
  total_radix = tf.reduce_sum(input_tensor=flat_radixes)
  # Offset of each radix's block within the flat table, as a column vector so
  # it broadcasts against the per-coefficient stride below.
  radix_starts = tf.reshape(
      tf.cumsum(flat_radixes, exclusive=True), shape=[-1, 1])
  table_offsets = radix_starts + tf.range(num_coeffs) * total_radix
  permuted = tf.gather(perms, int_coeffs + table_offsets)
  return tf.cast(permuted, dtype=original_dtype), perms
Example #17
Source File: extensions.py    From trax with Apache License 2.0 5 votes vote down vote up
def convert_sharded_tensor_to_eager_tensor(value, *args, **kwargs):
  """Stacks the per-device tensors in `value.tensors` into one tensor.

  Any extra positional or keyword arguments are ignored.
  """
  del args, kwargs  # Unused.
  # TODO(nareshmodi): Consider a collective op to gather the tensors from the
  # various devices for performance reasons.
  shards = value.tensors
  return tf.stack(shards)
Example #18
Source File: extensions.py    From trax with Apache License 2.0 4 votes vote down vote up
def sort_key_val(keys, values, dimension=-1):
  """Sorts keys along a dimension and applies same permutation to values.

  Args:
    keys: an array. The dtype must be comparable numbers (integers and reals).
    values: an array, with the same shape of `keys`.
    dimension: an `int`. The dimension along which to sort.

  Returns:
    Permuted keys and values.

  Raises:
    ValueError: if the static rank of both `keys` and `values` is unknown.
  """
  keys = tf_np.asarray(keys)
  values = tf_np.asarray(values)
  rank = keys.data.shape.ndims
  if rank is None:
    rank = values.data.shape.ndims
  if rank is None:
    # tf.gather needs an `int` batch_dims, so the rank must be static.
    raise ValueError("The rank of either keys or values must be known, but "
                     "both are unknown (i.e. their shapes are both None).")

  # tf.gather supports batch_dims only on the leading axes, so the sort axis
  # is moved to the end first and moved back afterwards. swapaxes is its own
  # inverse, so the same helper serves both directions.
  # TODO(wangpeng): Investigate whether we should do swapaxes or moveaxis.
  sort_axis_is_last = dimension in (-1, rank - 1)

  def to_last(arr):
    return arr if sort_axis_is_last else tf_np.swapaxes(arr, dimension, -1)

  keys, values = to_last(keys), to_last(values)
  order = tf_np.argsort(keys).data

  # tf.gather rather than np.take because the former supports batch_dims.
  def reorder(arr):
    return tf_np.asarray(tf.gather(arr.data, order, batch_dims=rank - 1))

  keys, values = reorder(keys), reorder(values)
  return to_last(keys), to_last(values)


# Use int64 instead of int32 to avoid TF's "int32 problem" 
Example #19
Source File: holiday_utils.py    From tf-quant-finance with Apache License 2.0 4 votes vote down vote up
def _week_day_mappers(weekend_mask):
  """Creates functions to map from ordinals to week days and inverse.

  "Week day values" count working days only. Within each 7-day block
  (positions taken as `ordinal % 7` after rolling the mask by
  `-_DAYOFWEEK_0`), the first working day gets 0 and each later working day
  gets the next integer. Every full block adds the number of working days per
  week. A non-working day shares the value of the preceding working-day slot
  in the offset table, so several ordinals may map to the same value.

  Args:
    weekend_mask: A bool `Tensor` of length 7 or None. The weekend mask.

  Returns:
    A tuple of callables `(forward, backward)`.
      `forward`: takes a `Tensor` of ordinals and returns two `Tensor`s of
        the same shape: the int32 week day values, and a bool `Tensor`
        gathered from the rolled *weekday* mask (True for working days).
        NOTE(review): when `weekend_mask` is None, the default forward
        returns all-False for this second value instead, and the original
        docstring described it as "True where the day is a weekend day".
        These disagree; confirm against callers which meaning is intended.
      `backward`: takes an int32 `Tensor` of week day values and returns an
        int32 `Tensor` of ordinals for those week days.
  """
  if weekend_mask is None:
    # No weekend: ordinals map to themselves.
    default_forward = lambda x: (x, tf.zeros_like(x, dtype=tf.bool))
    identity = lambda x: x
    return default_forward, identity

  rolled_weekend = tf.roll(
      tf.convert_to_tensor(weekend_mask, dtype=tf.bool), -_DAYOFWEEK_0, axis=0)
  is_workday = tf.logical_not(rolled_weekend)
  # Running count of working days in the week; the last entry is the total.
  workday_rank = tf.cumsum(tf.cast(is_workday, dtype=tf.int32))
  workdays_per_week = workday_rank[-1]
  workday_rank -= 1  # The first working day maps to 0.
  # Positions (0..6) within the week that are working days.
  workday_positions = tf.convert_to_tensor(
      list(range(7)), dtype=tf.int32)[is_workday]

  def forward(ordinals):
    """Adjusts the ordinals by removing the number of weekend days so far."""
    weeks, position = ordinals // 7, ordinals % 7
    values = weeks * workdays_per_week + tf.gather(workday_rank, position)
    return values, tf.gather(is_workday, position)

  def backward(weekday_values):
    """Converts from weekend adjusted values to ordinals."""
    weeks = weekday_values // workdays_per_week
    position = tf.gather(workday_positions, weekday_values % workdays_per_week)
    return weeks * 7 + position

  return forward, backward
Example #20
Source File: date_tensor.py    From tf-quant-finance with Apache License 2.0 4 votes vote down vote up
def from_year_month_day(year, month, day, validate=True):
  """Creates DateTensor from tensors of years, months and days.

  Args:
    year: Tensor of int32 type. Elements should be positive.
    month: Tensor of int32 type of same shape as `year`. Elements should be in
      range `[1, 12]`.
    day: Tensor of int32 type of same shape as `year`. Elements should be in
      range `[1, 31]` and represent valid dates together with corresponding
      elements of `month` and `year` Tensors.
    validate: Whether to validate the dates. When True, debugging assertions
      are attached as control dependencies to the returned tensors.

  Returns:
    DateTensor object.

  #### Example

  ```python
  year = tf.constant([2015, 2017], dtype=tf.int32)
  month = tf.constant([4, 12], dtype=tf.int32)
  day = tf.constant([15, 30], dtype=tf.int32)
  date_tensor = tff.datetime.dates_from_year_month_day(year, month, day)
  ```
  """
  year = tf.convert_to_tensor(year, tf.int32)
  month = tf.convert_to_tensor(month, tf.int32)
  day = tf.convert_to_tensor(day, tf.int32)

  control_deps = []
  if validate:
    control_deps.append(
        tf.debugging.assert_positive(year, message="Year must be positive."))
    control_deps.append(
        tf.debugging.assert_greater_equal(
            month,
            constants.Month.JANUARY.value,
            message=f"Month must be >= {constants.Month.JANUARY.value}"))
    control_deps.append(
        tf.debugging.assert_less_equal(
            month,
            constants.Month.DECEMBER.value,
            message=f"Month must be <= {constants.Month.DECEMBER.value}"))
    control_deps.append(
        tf.debugging.assert_positive(day, message="Day must be positive."))
    # Maximum day for each (month, leap-year) pair; leap years are looked up
    # 12 entries further into the combined table.
    is_leap = date_utils.is_leap_year(year)
    days_in_months = tf.constant(_DAYS_IN_MONTHS_COMBINED, tf.int32)
    max_days = tf.gather(days_in_months,
                         month + 12 * tf.dtypes.cast(is_leap, tf.int32))
    control_deps.append(
        tf.debugging.assert_less_equal(
            day, max_days, message="Invalid day-month pairing."))
    with tf.compat.v1.control_dependencies(control_deps):
      # Ensure years, months, days themselves are under control_deps.
      year = tf.identity(year)
      month = tf.identity(month)
      day = tf.identity(day)

  with tf.compat.v1.control_dependencies(control_deps):
    ordinal = date_utils.year_month_day_to_ordinal(year, month, day)
    return DateTensor(ordinal, year, month, day)
Example #21
Source File: network.py    From ranking with Apache License 2.0 4 votes vote down vote up
def compute_logits(self,
                   context_features=None,
                   example_features=None,
                   training=None,
                   mask=None):
    """Scores context and examples to return a score per document.

    Args:
      context_features: (dict) context feature names to 2D tensors of shape
        [batch_size, feature_dims]. None is treated as an empty dict.
      example_features: (dict) example feature names to 3D tensors of shape
        [batch_size, list_size, feature_dims]. Must be non-empty: batch and
        list sizes are read from its first value.
      training: (bool) whether in train or inference mode.
      mask: (tf.Tensor) Mask is a tensor of shape [batch_size, list_size], which
        is True for a valid example and False for invalid one. If mask is None,
        all entries are valid.

    Returns:
      (tf.Tensor) A score tensor of shape [batch_size, list_size]. Entries
      where the padded mask is False are set to 0.
    """
    if context_features is None:
      context_features = {}
    first_example = next(six.itervalues(example_features))
    batch_size = tf.shape(first_example)[0]
    list_size = tf.shape(first_example)[1]
    if mask is None:
      mask = tf.ones(shape=[batch_size, list_size], dtype=tf.bool)
    nd_indices, nd_mask = utils.padded_nd_indices(is_valid=mask)

    # Expand query features to be of [batch_size, list_size, ...] by
    # repeating index 0 of a new axis 1 `list_size` times.
    large_batch_context_features = {}
    for name, context_tensor in six.iteritems(context_features):
      x = tf.expand_dims(input=context_tensor, axis=1)
      x = tf.gather(x, tf.zeros([list_size], tf.int32), axis=1)
      large_batch_context_features[name] = utils.reshape_first_ndims(
          x, 2, [batch_size * list_size])

    large_batch_example_features = {}
    for name, example_tensor in six.iteritems(example_features):
      # Replace invalid example features with valid ones.
      padded_tensor = tf.gather_nd(example_tensor, nd_indices)
      large_batch_example_features[name] = utils.reshape_first_ndims(
          padded_tensor, 2, [batch_size * list_size])

    # Get scores for large batch.
    scores = self.score(
        context_features=large_batch_context_features,
        example_features=large_batch_example_features,
        training=training)
    logits = tf.reshape(
        scores, shape=[batch_size, list_size])

    # Apply nd_mask to zero out invalid entries.
    logits = tf.where(nd_mask, logits, tf.zeros_like(logits))
    return logits