Python tensorflow.compat.v2.bool() Examples
The following are 30 code examples of tensorflow.compat.v2.bool(). You can go to the original project or source file by following the link above each example, or browse the other available functions and classes of the tensorflow.compat.v2 module.
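Before the project-specific examples, here is a minimal standalone sketch (not taken from any of the projects below) of the most common tf.bool patterns: casting, logical ops, and masking with tf.where. The tensor values are illustrative.

import tensorflow.compat.v2 as tf

# Cast a numeric tensor to booleans: zero becomes False, non-zero becomes True.
flags = tf.cast([0, 1, 2, 0], tf.bool)          # [False, True, True, False]

# Combine boolean tensors with logical ops rather than Python's bitwise operators.
both = tf.logical_and(flags, tf.constant([True, True, False, True]))

# Use a bool tensor as a mask: pick elements from `x` where the mask is True.
x = tf.constant([10.0, 20.0, 30.0, 40.0])
masked = tf.where(flags, x, tf.zeros_like(x))   # [0.0, 20.0, 30.0, 0.0]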
Example #1
Source File: array_ops.py From trax with Apache License 2.0 | 6 votes |
def any(a, axis=None, keepdims=None):  # pylint: disable=redefined-builtin
  """Whether any element in the entire array or in an axis evaluates to true.

  Casts the array to bool type if it is not already and uses `tf.reduce_any` to
  compute the result.

  Args:
    a: array_like. Could be an ndarray, a Tensor or any object that can be
      converted to a Tensor using `tf.convert_to_tensor`.
    axis: Optional. Could be an int or a tuple of integers. If not specified,
      the reduction is performed over all array indices.
    keepdims: If true, retains reduced dimensions with length 1.

  Returns:
    An ndarray. Note that unlike NumPy this does not return a scalar bool if
    `axis` is None.
  """
  a = asarray(a, dtype=bool)
  return utils.tensor_to_ndarray(
      tf.reduce_any(input_tensor=a.data, axis=axis, keepdims=keepdims))
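The helper above wraps two plain TensorFlow calls; a minimal sketch of the same cast-then-reduce pattern outside trax (the tensor values are illustrative):

import tensorflow.compat.v2 as tf

values = tf.constant([[0, 0, 3], [0, 0, 0]])
as_bool = tf.cast(values, tf.bool)      # non-zero entries become True
print(tf.reduce_any(as_bool))           # True: at least one element is non-zero
print(tf.reduce_any(as_bool, axis=1))   # [True, False], one result per row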
Example #2
Source File: wider_face.py From datasets with Apache License 2.0 | 6 votes |
def _info(self):
  features = {
      'image': tfds.features.Image(encoding_format='jpeg'),
      'image/filename': tfds.features.Text(),
      'faces': tfds.features.Sequence({
          'bbox': tfds.features.BBoxFeature(),
          'blur': tf.uint8,
          'expression': tf.bool,
          'illumination': tf.bool,
          'occlusion': tf.uint8,
          'pose': tf.bool,
          'invalid': tf.bool,
      }),
  }
  return tfds.core.DatasetInfo(
      builder=self,
      description=_DESCRIPTION,
      features=tfds.features.FeaturesDict(features),
      homepage=_PROJECT_URL,
      citation=_CITATION,
  )
Example #3
Source File: blimp.py From datasets with Apache License 2.0 | 6 votes |
def _info(self):
  return tfds.core.DatasetInfo(
      builder=self,
      description=_DESCRIPTION,
      features=tfds.features.FeaturesDict({
          'sentence_good': tfds.features.Text(),
          'sentence_bad': tfds.features.Text(),
          'field': tfds.features.Text(),
          'linguistics_term': tfds.features.Text(),
          'UID': tfds.features.Text(),
          'simple_LM_method': tf.bool,
          'one_prefix_method': tf.bool,
          'two_prefix_method': tf.bool,
          'lexically_identical': tf.bool,
          'pair_id': tf.int32,
      }),
      supervised_keys=None,
      # Homepage of the dataset for documentation
      homepage=_PROJECT_URL,
      citation=_CITATION,
  )
Example #4
Source File: gap.py From datasets with Apache License 2.0 | 6 votes |
def _info(self):
  return tfds.core.DatasetInfo(
      builder=self,
      description=_DESCRIPTION,
      features=tfds.features.FeaturesDict({
          'ID': tfds.features.Text(),
          'Text': tfds.features.Text(),
          'Pronoun': tfds.features.Text(),
          'Pronoun-offset': tf.int32,
          'A': tfds.features.Text(),
          'A-offset': tf.int32,
          'A-coref': tf.bool,
          'B': tfds.features.Text(),
          'B-offset': tf.int32,
          'B-coref': tf.bool,
          'URL': tfds.features.Text()
      }),
      supervised_keys=None,
      homepage='https://github.com/google-research-datasets/gap-coreference',
      citation=_CITATION,
  )
Example #5
Source File: unbounded_holiday_calendar.py From tf-quant-finance with Apache License 2.0 | 6 votes |
def __init__(self, weekend_mask=None, holidays=None):
  """Initializer.

  Args:
    weekend_mask: Boolean `Tensor` of 7 elements one for each day of the week
      starting with Monday at index 0. A `True` value indicates the day is
      considered a weekend day and a `False` value implies a week day.
      Default value: None which means no weekends are applied.
    holidays: Defines the holidays that are added to the weekends defined by
      `weekend_mask`. An instance of `dates.DateTensor` or an object
      convertible to `DateTensor`.
      Default value: None which means no holidays other than those implied by
        the weekends (if any).
  """
  if weekend_mask is not None:
    weekend_mask = tf.cast(weekend_mask, dtype=tf.bool)
  if holidays is not None:
    holidays = dt.convert_to_date_tensor(holidays).ordinal()
  self._to_biz_space, self._from_biz_space = hol.business_day_mappers(
      weekend_mask=weekend_mask, holidays=holidays)
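Following the docstring above (Monday at index 0, True marking a weekend day), a Saturday/Sunday weekend would be expressed as seven flags and cast to tf.bool; a short illustrative sketch:

import tensorflow.compat.v2 as tf

# Monday..Sunday; True marks a weekend day (Saturday and Sunday here).
weekend_mask = tf.cast([0, 0, 0, 0, 0, 1, 1], dtype=tf.bool)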
Example #6
Source File: root_search.py From tf-quant-finance with Apache License 2.0 | 6 votes |
def _should_stop(state, stopping_policy_fn):
  """Indicates whether the overall Brent search should continue.

  Args:
    state: A Python `_BrentSearchState` namedtuple.
    stopping_policy_fn: Python `callable` controlling the algorithm
      termination.

  Returns:
    A boolean value indicating whether the overall search should continue.
  """
  return tf.convert_to_tensor(
      stopping_policy_fn(state.finished), name="should_stop", dtype=tf.bool)


# This is a direct translation of the Brent root-finding method.
# Each operation is guarded by a call to `tf.where` to avoid performing
# unnecessary calculations.
Example #7
Source File: network.py From ranking with Apache License 2.0 | 6 votes |
def call(self, inputs=None, training=None, mask=None):
  """Defines the forward pass for ranking model.

  Args:
    inputs: (dict) with a mix of context (2D) and example features (3D).
    training: (bool) whether in train or inference mode.
    mask: (tf.Tensor) Mask is a tensor of shape [batch_size, list_size], which
      is True for a valid example and False for invalid one.

  Returns:
    (tf.Tensor) A score tensor of shape [batch_size, list_size].
  """
  context_features, example_features = self.transform(
      features=inputs, training=training, mask=mask)
  logits = self.compute_logits(
      context_features=context_features,
      example_features=example_features,
      training=training,
      mask=mask)
  return logits
Example #8
Source File: network.py From ranking with Apache License 2.0 | 6 votes |
def compute_logits(self,
                   context_features=None,
                   example_features=None,
                   training=None,
                   mask=None):
  """Scores context and examples to return a score per document.

  Args:
    context_features: (dict) context feature names to 2D tensors of shape
      [batch_size, feature_dims].
    example_features: (dict) example feature names to 3D tensors of shape
      [batch_size, list_size, feature_dims].
    training: (bool) whether in train or inference mode.
    mask: (tf.Tensor) Mask is a tensor of shape [batch_size, list_size], which
      is True for a valid example and False for invalid one. If mask is None,
      all entries are valid.

  Returns:
    (tf.Tensor) A score tensor of shape [batch_size, list_size].
  """
  raise NotImplementedError('Calling an abstract method, '
                            'tfr.keras.RankingModel.compute_logits().')
Example #9
Source File: network.py From ranking with Apache License 2.0 | 6 votes |
def transform(self, features=None, training=None, mask=None):
  """Transforms the features into dense context features and example features.

  The user can overwrite this function for custom transformations.
  Mask is provided as an argument so that inherited models can have access
  to it for custom feature transformations, without modifying `call`
  explicitly.

  Args:
    features: (dict) with a mix of context (2D) and example features (3D).
    training: (bool) whether in train or inference mode.
    mask: (tf.Tensor) Mask is a tensor of shape [batch_size, list_size], which
      is True for a valid example and False for invalid one.

  Returns:
    context_features: (dict) context feature names to dense 2D tensors of
      shape [batch_size, feature_dims].
    example_features: (dict) example feature names to dense 3D tensors of
      shape [batch_size, list_size, feature_dims].
  """
  del mask
  context_features, example_features = self._listwise_dense_layer(
      inputs=features, training=training)
  return context_features, example_features
Example #10
Source File: testing_utils.py From valan with Apache License 2.0 | 6 votes |
def __init__(self, state_space_size, unroll_length=1):
  self._state_space_size = state_space_size
  # Creates simple dynamics (T stands for transition):
  #   states = [0, 1, ... len(state_space_size - 1)] + [STOP]
  #   actions = [-1, 1]
  #   T(s, a) = s + a  iff (s + a) is a valid state
  #           = STOP   otherwise
  self._action_space = [-1, 1]
  self._current_state = None
  self._env_spec = common.EnvOutput(
      reward=tf.TensorSpec(shape=[unroll_length + 1], dtype=tf.float32),
      done=tf.TensorSpec(shape=[unroll_length + 1], dtype=tf.bool),
      observation={
          'f1': tf.TensorSpec(
              shape=[unroll_length + 1, 4, 10], dtype=tf.float32),
          'f2': tf.TensorSpec(
              shape=[unroll_length + 1, 7, 10, 2], dtype=tf.float32)
      },
      info=tf.TensorSpec(shape=[unroll_length + 1], dtype=tf.string))
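A short sketch of how a tf.bool TensorSpec like the `done` field above can be checked against concrete tensors (the spec and values here are illustrative, not taken from valan):

import tensorflow.compat.v2 as tf

done_spec = tf.TensorSpec(shape=[3], dtype=tf.bool)
print(done_spec.is_compatible_with(tf.constant([False, False, True])))  # True
print(done_spec.is_compatible_with(tf.constant([0, 0, 1])))             # False: dtype is int32, not bool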
Example #11
Source File: pixelcnn.py From alibi-detect with Apache License 2.0 | 6 votes |
def __init__(self, shift, validate_args=False, name='shift'):
    """Instantiates the `Shift` bijector which computes `Y = g(X; shift) = X + shift`
    where `shift` is a numeric `Tensor`.

    Args:
        shift: Floating-point `Tensor`.
        validate_args: Python `bool` indicating whether arguments should be
            checked for correctness.
        name: Python `str` name given to ops managed by this object.
    """
    with tf.name_scope(name) as name:
        dtype = dtype_util.common_dtype([shift], dtype_hint=tf.float32)
        self._shift = tensor_util.convert_nonref_to_tensor(shift, dtype=dtype, name='shift')
        super(Shift, self).__init__(
            forward_min_event_ndims=0,
            is_constant_jacobian=True,
            dtype=dtype,
            validate_args=validate_args,
            name=name
        )
Example #12
Source File: math_ops.py From trax with Apache License 2.0 | 6 votes |
def diff(a, n=1, axis=-1):
  def f(a):
    nd = a.shape.rank
    if (axis + nd if axis < 0 else axis) >= nd:
      raise ValueError("axis %s is out of bounds for array of dimension %s" %
                       (axis, nd))
    if n < 0:
      raise ValueError("order must be non-negative but got %s" % n)
    slice1 = [slice(None)] * nd
    slice2 = [slice(None)] * nd
    slice1[axis] = slice(1, None)
    slice2[axis] = slice(None, -1)
    slice1 = tuple(slice1)
    slice2 = tuple(slice2)
    op = tf.not_equal if a.dtype == tf.bool else tf.subtract
    for _ in range(n):
      a = op(a[slice1], a[slice2])
    return a
  return _scalar(f, a)
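For bool inputs the example substitutes tf.not_equal for subtraction, since subtraction is not defined for tf.bool. A minimal plain-TensorFlow illustration of that branch (values are illustrative):

import tensorflow.compat.v2 as tf

a = tf.constant([True, True, False, False, True])
# One-step boolean "diff": True wherever adjacent elements differ.
changes = tf.not_equal(a[1:], a[:-1])   # [False, True, False, True]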
Example #13
Source File: array_ops.py From trax with Apache License 2.0 | 6 votes |
def all(a, axis=None, keepdims=None):  # pylint: disable=redefined-builtin
  """Whether all array elements or those along an axis evaluate to true.

  Casts the array to bool type if it is not already and uses `tf.reduce_all` to
  compute the result.

  Args:
    a: array_like. Could be an ndarray, a Tensor or any object that can be
      converted to a Tensor using `tf.convert_to_tensor`.
    axis: Optional. Could be an int or a tuple of integers. If not specified,
      the reduction is performed over all array indices.
    keepdims: If true, retains reduced dimensions with length 1.

  Returns:
    An ndarray. Note that unlike NumPy this does not return a scalar bool if
    `axis` is None.
  """
  a = asarray(a, dtype=bool)
  return utils.tensor_to_ndarray(
      tf.reduce_all(input_tensor=a.data, axis=axis, keepdims=keepdims))
Example #14
Source File: array_ops.py From trax with Apache License 2.0 | 6 votes |
def tril(m, k=0):  # pylint: disable=missing-docstring
  m = asarray(m).data
  m_shape = m.shape.as_list()

  if len(m_shape) < 2:
    raise ValueError('Argument to tril must have rank at least 2')

  if m_shape[-1] is None or m_shape[-2] is None:
    raise ValueError('Currently, the last two dimensions of the input array '
                     'need to be known.')

  z = tf.constant(0, m.dtype)
  mask = tri(*m_shape[-2:], k=k, dtype=bool)
  return utils.tensor_to_ndarray(
      tf.where(tf.broadcast_to(mask, tf.shape(m)), m, z))
Example #15
Source File: array_ops.py From trax with Apache License 2.0 | 6 votes |
def triu(m, k=0):  # pylint: disable=missing-docstring
  m = asarray(m).data
  m_shape = m.shape.as_list()

  if len(m_shape) < 2:
    raise ValueError('Argument to triu must have rank at least 2')

  if m_shape[-1] is None or m_shape[-2] is None:
    raise ValueError('Currently, the last two dimensions of the input array '
                     'need to be known.')

  z = tf.constant(0, m.dtype)
  mask = tri(*m_shape[-2:], k=k - 1, dtype=bool)
  return utils.tensor_to_ndarray(
      tf.where(tf.broadcast_to(mask, tf.shape(m)), z, m))
Example #16
Source File: root_search.py From tf-quant-finance with Apache License 2.0 | 5 votes |
def _swap_where(condition, x, y):
  """Swaps the elements of `x` and `y` based on `condition`.

  Args:
    condition: A `Tensor` of dtype bool.
    x: A `Tensor` with the same shape as `condition`.
    y: A `Tensor` with the same shape and dtype as `x`.

  Returns:
    Two `Tensors` with the same shape as `x` and `y`.
  """
  return tf.where(condition, y, x), tf.where(condition, x, y)
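A quick usage sketch of the swap pattern with concrete tensors (illustrative values, not from the original project):

import tensorflow.compat.v2 as tf

condition = tf.constant([True, False, True])
x = tf.constant([1, 2, 3])
y = tf.constant([10, 20, 30])
# Wherever `condition` is True, the two outputs exchange their elements.
swapped_x = tf.where(condition, y, x)   # [10, 2, 30]
swapped_y = tf.where(condition, x, y)   # [1, 20, 3]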
Example #17
Source File: nsynth.py From datasets with Apache License 2.0 | 5 votes |
def __init__(self,
             gansynth_subset=False,
             estimate_f0_and_loudness=False,
             **kwargs):
  """Constructs a NsynthConfig.

  Args:
    gansynth_subset: bool, whether to use the subset of the dataset introduced
      in the ICLR 2019 GANSynth paper (Engel, et al. 2018). This subset uses
      acoustic-only instrument sources and limits the pitches to the interval
      [24, 84]. The train and test splits are also modified so that
      instruments (but not specific notes) overlap between them. See
      https://arxiv.org/abs/1902.08710 for more details.
    estimate_f0_and_loudness: bool, whether to estimate fundamental frequency
      (F0) and loudness for the audio (at 250 Hz) and add them to the set of
      features.
    **kwargs: keyword arguments forwarded to super.
  """
  name_parts = []
  if gansynth_subset:
    name_parts.append("gansynth_subset")
  else:
    name_parts.append("full")
  if estimate_f0_and_loudness:
    name_parts.append("f0_and_loudness")
  v230 = tfds.core.Version(
      "2.3.0", "New `loudness_db` feature in decibels (unormalized).")
  v231 = tfds.core.Version(
      "2.3.1", "F0 computed with normalization fix in CREPE.")
  v232 = tfds.core.Version(
      "2.3.2", "Use Audio feature.")
  super(NsynthConfig, self).__init__(
      name=".".join(name_parts),
      version=v232,
      supported_versions=[v231, v230],
      **kwargs)
  self.gansynth_subset = gansynth_subset
  self.estimate_f0_and_loudness = estimate_f0_and_loudness
Example #18
Source File: open_images_challenge2019.py From datasets with Apache License 2.0 | 5 votes |
def _info(self):
  label = tfds.features.ClassLabel(num_classes=_NUM_CLASSES)
  return tfds.core.DatasetInfo(
      builder=self,
      description=_DESCRIPTION + "\n\n" + _DESCRIPTION_DETECTION,
      features=tfds.features.FeaturesDict({
          "id": tfds.features.Text(),
          "image": tfds.features.Image(),
          # A sequence of image-level labels.
          "objects": tfds.features.Sequence({
              "label": label,
              # All labels have been verified by humans.
              # - If confidence is 1.0, the object IS in the image.
              # - If confidence is 0.0, the object is NOT in the image.
              "confidence": tf.float32,
              "source": tfds.features.Text(),
          }),
          # A sequence of bounding boxes.
          "bobjects": tfds.features.Sequence({
              "label": label,
              "bbox": tfds.features.BBoxFeature(),
              "is_group_of": tf.bool,
          }),
      }),
      homepage=_URL,
  )
Example #19
Source File: scicite.py From datasets with Apache License 2.0 | 5 votes |
def _generate_examples(self, path=None):
  """Yields examples."""
  with tf.io.gfile.GFile(path) as f:
    unique_ids = {}
    for line in f:
      d = json.loads(line)
      unique_id = str(d["unique_id"])
      if unique_id in unique_ids:
        continue
      unique_ids[unique_id] = True
      yield unique_id, {
          "string": d["string"],
          "label": str(d["label"]),
          "sectionName": str(d["sectionName"]),
          "citingPaperId": str(d["citingPaperId"]),
          "citedPaperId": str(d["citedPaperId"]),
          "excerpt_index": int(d["excerpt_index"]),
          "isKeyCitation": bool(d["isKeyCitation"]),
          "label2": str(d.get("label2", "none")),
          "citeEnd": _safe_int(d["citeEnd"]),
          "citeStart": _safe_int(d["citeStart"]),
          "source": str(d["source"]),
          "label_confidence": float(d.get("label_confidence", 0.)),
          "label2_confidence": float(d.get("label2_confidence", 0.)),
          "id": str(d["id"]),
      }
Example #20
Source File: array_ops.py From trax with Apache License 2.0 | 5 votes |
def nonzero(a):
  a = atleast_1d(a).data
  if a.shape.rank is None:
    raise ValueError("The rank of `a` is unknown, so we can't decide how many "
                     "arrays to return.")
  return tf.nest.map_structure(
      arrays_lib.tensor_to_ndarray,
      tf.unstack(tf.where(tf.cast(a, tf.bool)), a.shape.rank, axis=1))
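The core of the example is tf.where applied to a single bool tensor, which returns the coordinates of the True entries; a standalone sketch with illustrative values:

import tensorflow.compat.v2 as tf

a = tf.constant([[0, 2], [3, 0]])
coords = tf.where(tf.cast(a, tf.bool))     # [[0, 1], [1, 0]]: row/column indices of non-zeros
rows, cols = tf.unstack(coords, 2, axis=1) # per-dimension index vectors, as in the example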
Example #21
Source File: bounded_holiday_calendar.py From tf-quant-finance with Apache License 2.0 | 5 votes |
def is_business_day(self, date_tensor):
  """Returns a tensor of bools for whether given dates are business days."""
  is_bus_day_table = self._compute_is_bus_day_table()
  is_bus_day_int32 = self._gather(
      is_bus_day_table,
      date_tensor.ordinal() - self._ordinal_offset + 1)
  with tf.control_dependencies(
      self._assert_ordinals_in_bounds(date_tensor.ordinal())):
    return tf.cast(is_bus_day_int32, dtype=tf.bool)
Example #22
Source File: multi_objective_scalarizer.py From agents with Apache License 2.0 | 5 votes |
def call(self, multi_objectives: tf.Tensor) -> tf.Tensor:
  transformed_objectives = tf.maximum(
      multi_objectives * self._slopes + self._offsets, 0)
  nonzero_mask = tf.broadcast_to(
      tf.cast(tf.abs(self._direction) >= self.ALMOST_ZERO, dtype=tf.bool),
      multi_objectives.shape)
  return tf.reduce_min(
      tf.where(nonzero_mask, transformed_objectives / self._direction,
               multi_objectives.dtype.max),
      axis=1)
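The boolean mask built above routes each objective either through the scaled value or to a sentinel (the dtype's maximum) so it cannot win the reduce_min; a simplified sketch of that masking idea with made-up numbers:

import tensorflow.compat.v2 as tf

direction = tf.constant([[0.0, 0.5, 0.5]])
objectives = tf.constant([[1.0, 2.0, 4.0]])
nonzero = tf.abs(direction) >= 1e-16     # bool mask of usable direction components
scaled = tf.where(nonzero, objectives / direction, objectives.dtype.max)
print(tf.reduce_min(scaled, axis=1))     # [4.0]: the zero-direction entry is ignored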
Example #23
Source File: vector_hull_white.py From tf-quant-finance with Apache License 2.0 | 5 votes |
def _prepare_grid(times, *params):
  """Prepares grid of times for path generation.

  Args:
    times: Rank 1 `Tensor` of increasing positive real values. The times at
      which the path points are to be evaluated.
    *params: Parameters of the Heston model. Either scalar `Tensor`s of the
      same `dtype` or instances of `PiecewiseConstantFunc`.

  Returns:
    Tuple `(all_times, mask)`.
    `all_times` is a 1-D real `Tensor` containing all points from 'times`, the
    uniform grid of points between `[0, times[-1]]` with grid size equal to
    `time_step`, and jump locations of piecewise constant parameters. The
    `Tensor` is sorted in ascending order and may contain duplicates.
    `mask` is a boolean 1-D `Tensor` of the same shape as 'all_times', showing
    which elements of 'all_times' correspond to THE values from `times`.
    Guarantees that times[0]=0 and mask[0]=False.
  """
  additional_times = []
  for param in params:
    if hasattr(param, 'is_piecewise_constant'):
      if param.is_piecewise_constant:
        # Flatten all jump locations
        additional_times.append(tf.reshape(param.jump_locations(), [-1]))
  zeros = tf.constant([0], dtype=times.dtype)
  all_times = tf.concat([zeros] + [times] + additional_times, axis=0)
  additional_times_mask = [
      tf.zeros_like(times, dtype=tf.bool) for times in additional_times]
  mask = tf.concat([
      tf.cast(zeros, dtype=tf.bool),
      tf.ones_like(times, dtype=tf.bool)
  ] + additional_times_mask, axis=0)
  perm = tf.argsort(all_times, stable=True)
  all_times = tf.gather(all_times, perm)
  mask = tf.gather(mask, perm)
  return all_times, mask
Example #24
Source File: euler_sampling.py From tf-quant-finance with Apache License 2.0 | 5 votes |
def _prepare_grid(*, times, time_step, dtype):
  """Prepares grid of times for path generation.

  Args:
    times: Rank 1 `Tensor` of increasing positive real values. The times at
      which the path points are to be evaluated.
    time_step: Rank 0 real `Tensor`. Maximal distance between points in
      resulting grid.
    dtype: `tf.Dtype` of the input and output `Tensor`s.

  Returns:
    Tuple `(all_times, mask, time_points)`.
    `all_times` is a 1-D real `Tensor` containing all points from 'times` and
    the uniform grid of points between `[0, times[-1]]` with grid size equal
    to `time_step`. The `Tensor` is sorted in ascending order and may contain
    duplicates.
    `mask` is a boolean 1-D `Tensor` of the same shape as 'all_times', showing
    which elements of 'all_times' correspond to THE values from `times`.
    Guarantees that times[0]=0 and mask[0]=False.
    `time_indices`. An integer `Tensor` of the same shape as `times` indicating
    `times` indices in `all_times`.
  """
  grid = tf.range(0.0, times[-1], time_step, dtype=dtype)
  all_times = tf.concat([grid, times], axis=0)
  mask = tf.concat([
      tf.zeros_like(grid, dtype=tf.bool),
      tf.ones_like(times, dtype=tf.bool)
  ], axis=0)
  perm = tf.argsort(all_times, stable=True)
  all_times = tf.gather(all_times, perm)
  # Remove duplicate points
  all_times = tf.unique(all_times).y
  time_indices = tf.searchsorted(all_times, times)
  mask = tf.gather(mask, perm)
  return all_times, mask, time_indices
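A compact illustration of the grid-and-mask construction used above, with concrete times (the values are illustrative):

import tensorflow.compat.v2 as tf

times = tf.constant([0.3, 0.7], dtype=tf.float64)
grid = tf.range(0.0, times[-1], 0.25, dtype=tf.float64)    # [0.0, 0.25, 0.5]
all_times = tf.concat([grid, times], axis=0)
# The bool mask marks which entries of the merged grid came from `times`.
mask = tf.concat([tf.zeros_like(grid, dtype=tf.bool),
                  tf.ones_like(times, dtype=tf.bool)], axis=0)
perm = tf.argsort(all_times, stable=True)
print(tf.gather(all_times, perm))   # [0.0, 0.25, 0.3, 0.5, 0.7]
print(tf.gather(mask, perm))        # [False, False, True, False, True]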
Example #25
Source File: ito_process.py From tf-quant-finance with Apache License 2.0 | 5 votes |
def _prepare_grid(self, times, grid_step):
  """Prepares grid of times for path generation.

  Args:
    times: Rank 1 `Tensor` of increasing positive real values. The times at
      which the path points are to be evaluated.
    grid_step: Rank 0 real `Tensor`. Maximal distance between points in
      resulting grid.

  Returns:
    Tuple `(all_times, mask)`.
    `all_times` is 1-D real `Tensor` containing all points from 'times` and
    whose intervals are at most `grid_step`.
    `mask` is a boolean 1-D tensor of the same shape as 'all_times', showing
    which elements of 'all_times' correspond to values from `times`.
    Guarantees that times[0]=0 and grid_step[0]=False.
    'all_times` is sorted ascending and may contain duplicates.
  """
  grid = tf.range(0.0, times[-1], grid_step, dtype=self._dtype)
  all_times = tf.concat([grid, times], axis=0)
  mask = tf.concat([
      tf.zeros_like(grid, dtype=tf.bool),
      tf.ones_like(times, dtype=tf.bool)
  ], axis=0)
  perm = tf.argsort(all_times, stable=True)
  all_times = tf.gather(all_times, perm)
  mask = tf.gather(mask, perm)
  return all_times, mask
Example #26
Source File: math_ops.py From trax with Apache License 2.0 | 5 votes |
def maximum(x1, x2):
  def max_or_or(x1, x2):
    if x1.dtype == tf.bool:
      assert x2.dtype == tf.bool
      return tf.logical_or(x1, x2)
    return tf.math.maximum(x1, x2)
  return _bin_op(max_or_or, x1, x2)
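The bool branch falls back to logical_or because tf.math.maximum does not operate on tf.bool inputs; treating True as 1 and False as 0, the elementwise maximum of two bool tensors is exactly their logical OR. A tiny illustration:

import tensorflow.compat.v2 as tf

a = tf.constant([True, False, False])
b = tf.constant([False, False, True])
print(tf.logical_or(a, b))   # [True, False, True], the elementwise "maximum"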
Example #27
Source File: implied_vol_approximation.py From tf-quant-finance with Apache License 2.0 | 5 votes |
def _validate_args_control_deps(prices, forwards, strikes, expiries,
                                discount_factors, is_call_options):
  """Returns assertions for no-arbitrage conditions on the prices."""
  epsilon = tf.convert_to_tensor(1e-8, dtype=prices.dtype)
  forwards_positive = tf.compat.v1.debugging.assert_positive(forwards)
  strikes_positive = tf.compat.v1.debugging.assert_positive(strikes)
  expiries_positive = tf.compat.v1.debugging.assert_non_negative(expiries)
  put_lower_bounds = tf.nn.relu(strikes - forwards)
  call_lower_bounds = tf.nn.relu(forwards - strikes)
  if is_call_options is not None:
    is_call_options = tf.convert_to_tensor(is_call_options, dtype=tf.bool,
                                           name='is_call_options')
    lower_bounds = tf.where(
        is_call_options, x=call_lower_bounds, y=put_lower_bounds)
    upper_bounds = tf.where(is_call_options, x=forwards, y=strikes)
  else:
    lower_bounds = call_lower_bounds
    upper_bounds = forwards

  undiscounted_prices = prices / discount_factors
  bounds_satisfied = [
      tf.compat.v1.debugging.assert_less_equal(lower_bounds,
                                               undiscounted_prices),
      tf.compat.v1.debugging.assert_greater_equal(upper_bounds,
                                                  undiscounted_prices)
  ]
  not_too_close_to_bounds = [
      tf.compat.v1.debugging.assert_greater(
          tf.math.abs(undiscounted_prices - lower_bounds), epsilon),
      tf.compat.v1.debugging.assert_greater(
          tf.math.abs(undiscounted_prices - upper_bounds), epsilon)
  ]
  return [expiries_positive, forwards_positive, strikes_positive
         ] + bounds_satisfied + not_too_close_to_bounds
Example #28
Source File: pixelcnn.py From alibi-detect with Apache License 2.0 | 5 votes |
def __init__(self, layer, data_init=True, **kwargs):
    """Initialize WeightNorm wrapper.

    Args:
        layer: A `tf.keras.layers.Layer` instance. Supported layer types are
            `Dense`, `Conv2D`, and `Conv2DTranspose`. Layers with multiple
            inputs are not supported.
        data_init: `bool`, if `True` use data dependent variable initialization.
        **kwargs: Additional keyword args passed to `tf.keras.layers.Wrapper`.

    Raises:
        ValueError: If `layer` is not a `tf.keras.layers.Layer` instance.
    """
    if not isinstance(layer, tf.keras.layers.Layer):
        raise ValueError(
            'Please initialize `WeightNorm` layer with a `tf.keras.layers.Layer` '
            'instance. You passed: {input}'.format(input=layer)
        )

    layer_type = type(layer).__name__
    if layer_type not in ['Dense', 'Conv2D', 'Conv2DTranspose']:
        warnings.warn('`WeightNorm` is tested only for `Dense`, `Conv2D`, and '
                      '`Conv2DTranspose` layers. You passed a layer of type `{}`'
                      .format(layer_type))

    super(WeightNorm, self).__init__(layer, **kwargs)

    self.data_init = data_init
    self._track_trackable(layer, name='layer')
    self.filter_axis = -2 if layer_type == 'Conv2DTranspose' else -1
Example #29
Source File: math_ops.py From trax with Apache License 2.0 | 5 votes |
def minimum(x1, x2):
  def min_or_and(x1, x2):
    if x1.dtype == tf.bool:
      assert x2.dtype == tf.bool
      return tf.logical_and(x1, x2)
    return tf.math.minimum(x1, x2)
  return _bin_op(min_or_and, x1, x2)
Example #30
Source File: math_ops.py From trax with Apache License 2.0 | 5 votes |
def _bitwise_binary_op(tf_fn, x1, x2):
  def f(x1, x2):
    is_bool = (x1.dtype == tf.bool)
    if is_bool:
      assert x2.dtype == tf.bool
      x1 = tf.cast(x1, tf.int8)
      x2 = tf.cast(x2, tf.int8)
    r = tf_fn(x1, x2)
    if is_bool:
      r = tf.cast(r, tf.bool)
    return r
  return _bin_op(f, x1, x2)
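Because the tf.bitwise ops are not defined for tf.bool, the wrapper above round-trips through int8. A standalone sketch of that cast-compute-cast pattern (values are illustrative):

import tensorflow.compat.v2 as tf

a = tf.constant([True, True, False])
b = tf.constant([True, False, False])
# Cast to int8, apply the bitwise op, then cast the result back to bool.
xor = tf.cast(
    tf.bitwise.bitwise_xor(tf.cast(a, tf.int8), tf.cast(b, tf.int8)),
    tf.bool)   # [False, True, False]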