Python tensorflow.python.ops.array_ops.one_hot() Examples
The following are 30 code examples of `tensorflow.python.ops.array_ops.one_hot()`.
You can vote up the ones you like or vote down the ones you don't like,
and go to the original project or source file by following the links above each example.
You may also want to check out all available functions/classes of the module `tensorflow.python.ops.array_ops`, or try the search function.
Example #1
Source File: attention_wrapper_mod.py From NQG_ASs2s with MIT License | 6 votes |
def hardmax(logits, name=None):
    """Return batched one-hot vectors marking the largest logit.

    Along the last axis, the position holding the `1` is the index of the
    maximum logit value.

    Args:
      logits: A batch tensor of logit values.
      name: Name to use when creating ops.

    Returns:
      A batched one-hot tensor with the same dtype as `logits`.
    """
    with ops.name_scope(name, "Hardmax", [logits]):
        logits = ops.convert_to_tensor(logits, name="logits")
        # Use the statically known size of the last axis when available,
        # otherwise read it from the dynamic shape at run time.
        static_depth = logits.get_shape()[-1].value
        depth = (static_depth if static_depth is not None
                 else array_ops.shape(logits)[-1])
        winners = math_ops.argmax(logits, -1)
        return array_ops.one_hot(winners, depth, dtype=logits.dtype)
Example #2
Source File: dataset_ops.py From lambda-packs with MIT License | 6 votes |
def _estimate_data_distribution(c, num_examples_per_class_seen):
    """Estimate the class distribution from the labels seen so far.

    Args:
      c: The class labels. Type `int32`, shape `[batch_size]`.
      num_examples_per_class_seen: A `ResourceVariable` containing counts.
        Type `int64`, shape `[num_classes]`.

    Returns:
      dist: The updated distribution. Type `float32`, shape `[num_classes]`.
    """
    num_classes = num_examples_per_class_seen.get_shape()[0].value
    # Per-class counts for this batch: one-hot each label, then sum over
    # the batch axis.
    batch_counts = math_ops.reduce_sum(
        array_ops.one_hot(c, num_classes, dtype=dtypes.int64), 0)
    # Update the class-count based on what labels are seen in the batch.
    # Done asynchronously to avoid a cross-device round-trip; the cached
    # value is used.
    updated_counts = num_examples_per_class_seen.assign_add(batch_counts)
    probabilities = math_ops.truediv(
        updated_counts, math_ops.reduce_sum(updated_counts))
    return math_ops.cast(probabilities, dtypes.float32)
Example #3
Source File: copy_attention_wrapper.py From question-generation with MIT License | 6 votes |
def hardmax(logits, name=None):
    """Return batched one-hot vectors marking the largest logit.

    Along the last axis, the position holding the `1` is the index of the
    maximum logit value.

    Args:
      logits: A batch tensor of logit values.
      name: Name to use when creating ops.

    Returns:
      A batched one-hot tensor with the same dtype as `logits`.
    """
    with ops.name_scope(name, "Hardmax", [logits]):
        logits = ops.convert_to_tensor(logits, name="logits")
        # Use the statically known size of the last axis when available,
        # otherwise read it from the dynamic shape at run time.
        static_depth = logits.get_shape()[-1].value
        depth = (static_depth if static_depth is not None
                 else array_ops.shape(logits)[-1])
        winners = math_ops.argmax(logits, -1)
        return array_ops.one_hot(winners, depth, dtype=logits.dtype)
Example #4
Source File: copy_attention_wrapper.py From question-generation with MIT License | 6 votes |
def initial_alignments(self, batch_size, dtype):
    """Create the initial alignments for monotonic attention.

    Every batch entry starts as a Dirac distribution on the first memory
    position, i.e. [1, 0, 0, ...memory length..., 0].

    Args:
      batch_size: `int32` scalar, the batch_size.
      dtype: The `dtype`.

    Returns:
      A `dtype` tensor shaped `[batch_size, alignments_size]`
      (`alignments_size` is the values' `max_time`).
    """
    depth = self._alignments_size
    # Index 0 for every row puts the single 1 in the first column.
    first_positions = array_ops.zeros((batch_size,), dtype=dtypes.int32)
    return array_ops.one_hot(first_positions, depth, dtype=dtype)
Example #5
Source File: attention_wrapper.py From lambda-packs with MIT License | 6 votes |
def hardmax(logits, name=None):
    """Return batched one-hot vectors marking the largest logit.

    Along the last axis, the position holding the `1` is the index of the
    maximum logit value.

    Args:
      logits: A batch tensor of logit values.
      name: Name to use when creating ops.

    Returns:
      A batched one-hot tensor with the same dtype as `logits`.
    """
    with ops.name_scope(name, "Hardmax", [logits]):
        logits = ops.convert_to_tensor(logits, name="logits")
        # Use the statically known size of the last axis when available,
        # otherwise read it from the dynamic shape at run time.
        static_depth = logits.get_shape()[-1].value
        depth = (static_depth if static_depth is not None
                 else array_ops.shape(logits)[-1])
        winners = math_ops.argmax(logits, -1)
        return array_ops.one_hot(winners, depth, dtype=logits.dtype)
Example #6
Source File: backend.py From lambda-packs with MIT License | 6 votes |
def one_hot(indices, num_classes):
    """Computes the one-hot representation of an integer tensor.

    Arguments:
        indices: nD integer tensor of shape
            `(batch_size, dim1, dim2, ... dim(n-1))`
        num_classes: Integer, number of classes to consider.

    Returns:
        (n + 1)D one-hot tensor with shape
        `(batch_size, dim1, dim2, ... dim(n-1), num_classes)`.
    """
    # axis=-1 appends the new class axis as the last dimension.
    return array_ops.one_hot(indices, depth=num_classes, axis=-1)
Example #7
Source File: multinomial.py From keras-lambda with MIT License | 6 votes |
def _sample_n(self, n, seed=None):
    """Draw `n` samples of per-class counts.

    Args:
      n: Number of samples to draw.
      seed: Optional seed forwarded to `random_ops.multinomial`.

    Returns:
      Tensor of shape `[n] + batch_shape + [k]` giving, for each sample, how
      many of the `self.n` draws fell into each of the `k` classes, cast to
      `self.dtype`.

    Raises:
      NotImplementedError: if `self.n` is statically known to be non-scalar.
    """
    n_draws = math_ops.cast(self.n, dtype=dtypes.int32)
    if self.n.get_shape().ndims is not None:
        if self.n.get_shape().ndims != 0:
            raise NotImplementedError(
                "Sample only supported for scalar number of draws.")
    elif self.validate_args:
        # Rank is unknown statically; check it when the graph runs instead.
        is_scalar = check_ops.assert_rank(
            n_draws, 0,
            message="Sample only supported for scalar number of draws.")
        n_draws = control_flow_ops.with_dependencies([is_scalar], n_draws)
    k = self.event_shape()[0]
    # Flatten batch dims so logits has shape [B, k],
    # where B = reduce_prod(self.batch_shape()).
    logits = array_ops.reshape(self.logits, [-1, k])
    draws = random_ops.multinomial(logits=logits,
                                   num_samples=n * n_draws,
                                   seed=seed)
    draws = array_ops.reshape(draws, shape=[-1, n, n_draws])
    # Summing one-hot vectors over the draw axis yields per-class counts.
    x = math_ops.reduce_sum(array_ops.one_hot(draws, depth=k),
                            reduction_indices=-2)  # shape: [B, n, k]
    x = array_ops.transpose(x, perm=[1, 0, 2])
    final_shape = array_ops.concat([[n], self.batch_shape(), [k]], 0)
    x = array_ops.reshape(x, final_shape)
    # one_hot defaults to float32; cast so samples carry the distribution's
    # dtype (e.g. float64) rather than always float32.
    return math_ops.cast(x, self.dtype)
Example #8
Source File: multinomial.py From Serverless-Deep-Learning-with-TensorFlow-and-AWS-Lambda with MIT License | 6 votes |
def _sample_n(self, n, seed=None):
    """Draw `n` samples of per-class counts.

    Args:
      n: Number of samples to draw.
      seed: Optional seed forwarded to `random_ops.multinomial`.

    Returns:
      Tensor of shape `[n] + batch_shape + [k]`, cast to `self.dtype`, giving
      per-class counts over `self.total_count` draws.

    Raises:
      NotImplementedError: if `self.total_count` is statically non-scalar.
    """
    n_draws = math_ops.cast(self.total_count, dtype=dtypes.int32)
    count_rank = self.total_count.get_shape().ndims
    if count_rank is not None and count_rank != 0:
        raise NotImplementedError(
            "Sample only supported for scalar number of draws.")
    if count_rank is None and self.validate_args:
        # Rank unknown at graph-build time: verify it when the graph runs.
        rank_check = check_ops.assert_rank(
            n_draws, 0,
            message="Sample only supported for scalar number of draws.")
        n_draws = control_flow_ops.with_dependencies([rank_check], n_draws)

    num_classes = self.event_shape_tensor()[0]
    # Collapse every batch dim so the logits are [B, k] with
    # B = reduce_prod(self.batch_shape_tensor()).
    flat_logits = array_ops.reshape(self.logits, [-1, num_classes])
    class_ids = random_ops.multinomial(
        logits=flat_logits, num_samples=n * n_draws, seed=seed)
    class_ids = array_ops.reshape(class_ids, shape=[-1, n, n_draws])
    # Sum one-hot vectors over the draw axis -> counts shaped [B, n, k].
    counts = math_ops.reduce_sum(
        array_ops.one_hot(class_ids, depth=num_classes), axis=-2)
    # Bring the sample axis to the front: [n, B, k].
    counts = array_ops.transpose(counts, perm=[1, 0, 2])
    final_shape = array_ops.concat(
        [[n], self.batch_shape_tensor(), [num_classes]], 0)
    return math_ops.cast(array_ops.reshape(counts, final_shape), self.dtype)
Example #9
Source File: dirichlet_multinomial.py From Serverless-Deep-Learning-with-TensorFlow-and-AWS-Lambda with MIT License | 6 votes |
def _sample_n(self, n, seed=None):
    """Draw `n` Dirichlet-multinomial samples of per-class counts.

    Args:
      n: Number of samples to draw.
      seed: Optional seed; a derived seed is used for the multinomial step.

    Returns:
      Tensor of shape `[n] + batch_shape + [k]`, cast to `self.dtype`.
    """
    n_draws = math_ops.cast(self.total_count, dtype=dtypes.int32)
    num_classes = self.event_shape_tensor()[0]
    # log of Gamma(concentration) draws used as unnormalized logits: their
    # softmax is the normalized Gamma vector, i.e. a Dirichlet sample.
    gamma_draws = random_ops.random_gamma(
        shape=[n], alpha=self.concentration, dtype=self.dtype, seed=seed)
    flat_logits = array_ops.reshape(math_ops.log(gamma_draws),
                                    shape=[-1, num_classes])
    class_ids = random_ops.multinomial(
        logits=flat_logits,
        num_samples=n_draws,
        seed=distribution_util.gen_new_seed(seed, salt="dirichlet_multinomial"))
    # Count how many of the draws fell into each class.
    counts = math_ops.reduce_sum(
        array_ops.one_hot(class_ids, depth=num_classes), -2)
    final_shape = array_ops.concat(
        [[n], self.batch_shape_tensor(), [num_classes]], 0)
    return math_ops.cast(array_ops.reshape(counts, final_shape), self.dtype)
Example #10
Source File: backend.py From Serverless-Deep-Learning-with-TensorFlow-and-AWS-Lambda with MIT License | 6 votes |
def one_hot(indices, num_classes):
    """Computes the one-hot representation of an integer tensor.

    Arguments:
        indices: nD integer tensor of shape
            `(batch_size, dim1, dim2, ... dim(n-1))`
        num_classes: Integer, number of classes to consider.

    Returns:
        (n + 1)D one-hot tensor with shape
        `(batch_size, dim1, dim2, ... dim(n-1), num_classes)`.
    """
    # axis=-1 appends the new class axis as the last dimension.
    return array_ops.one_hot(indices, depth=num_classes, axis=-1)
Example #11
Source File: multinomial.py From auto-alt-text-lambda-api with MIT License | 6 votes |
def _sample_n(self, n, seed=None):
    """Draw `n` samples of per-class counts.

    Args:
      n: Number of samples to draw.
      seed: Optional seed forwarded to `random_ops.multinomial`.

    Returns:
      Tensor of shape `[n] + batch_shape + [k]` giving, for each sample, how
      many of the `self.n` draws fell into each of the `k` classes, cast to
      `self.dtype`.

    Raises:
      NotImplementedError: if `self.n` is statically known to be non-scalar.
    """
    n_draws = math_ops.cast(self.n, dtype=dtypes.int32)
    if self.n.get_shape().ndims is not None:
        if self.n.get_shape().ndims != 0:
            raise NotImplementedError(
                "Sample only supported for scalar number of draws.")
    elif self.validate_args:
        # Rank is unknown statically; check it when the graph runs instead.
        is_scalar = check_ops.assert_rank(
            n_draws, 0,
            message="Sample only supported for scalar number of draws.")
        n_draws = control_flow_ops.with_dependencies([is_scalar], n_draws)
    k = self.event_shape()[0]
    # Flatten batch dims so logits has shape [B, k],
    # where B = reduce_prod(self.batch_shape()).
    logits = array_ops.reshape(self.logits, [-1, k])
    draws = random_ops.multinomial(logits=logits,
                                   num_samples=n * n_draws,
                                   seed=seed)
    draws = array_ops.reshape(draws, shape=[-1, n, n_draws])
    # Summing one-hot vectors over the draw axis yields per-class counts.
    x = math_ops.reduce_sum(array_ops.one_hot(draws, depth=k),
                            reduction_indices=-2)  # shape: [B, n, k]
    x = array_ops.transpose(x, perm=[1, 0, 2])
    final_shape = array_ops.concat([[n], self.batch_shape(), [k]], 0)
    x = array_ops.reshape(x, final_shape)
    # one_hot defaults to float32; cast so samples carry the distribution's
    # dtype (e.g. float64) rather than always float32.
    return math_ops.cast(x, self.dtype)
Example #12
Source File: attention_wrapper.py From CommonSenseMultiHopQA with MIT License | 6 votes |
def hardmax(logits, name=None):
    """Return batched one-hot vectors marking the largest logit.

    Along the last axis, the position holding the `1` is the index of the
    maximum logit value.

    Args:
      logits: A batch tensor of logit values.
      name: Name to use when creating ops.

    Returns:
      A batched one-hot tensor with the same dtype as `logits`.
    """
    with ops.name_scope(name, "Hardmax", [logits]):
        logits = ops.convert_to_tensor(logits, name="logits")
        # Use the statically known size of the last axis when available,
        # otherwise read it from the dynamic shape at run time.
        static_depth = logits.get_shape()[-1].value
        depth = (static_depth if static_depth is not None
                 else array_ops.shape(logits)[-1])
        winners = math_ops.argmax(logits, -1)
        return array_ops.one_hot(winners, depth, dtype=logits.dtype)
Example #13
Source File: attention_wrapper.py From CommonSenseMultiHopQA with MIT License | 6 votes |
def initial_alignments(self, batch_size, dtype):
    """Create the initial alignments for monotonic attention.

    Every batch entry starts as a Dirac distribution on the first memory
    position, i.e. [1, 0, 0, ...memory length..., 0].

    Args:
      batch_size: `int32` scalar, the batch_size.
      dtype: The `dtype`.

    Returns:
      A `dtype` tensor shaped `[batch_size, alignments_size]`
      (`alignments_size` is the values' `max_time`).
    """
    depth = self._alignments_size
    # Index 0 for every row puts the single 1 in the first column.
    first_positions = array_ops.zeros((batch_size,), dtype=dtypes.int32)
    return array_ops.one_hot(first_positions, depth, dtype=dtype)
Example #14
Source File: array_ops.py From deep_image_model with Apache License 2.0 | 6 votes |
def one_hot_matrix(tensor_in, num_classes, on_value=1.0, off_value=0.0):
    """Encode the indices in `tensor_in` as a one-hot tensor.

    TODO(ilblackdragon): Ideally implementation should be part of
    TensorFlow with Eigen-native operation.

    Args:
      tensor_in: Input tensor of shape [N1, N2].
      num_classes: Number of classes to expand index into.
      on_value: Tensor or float, value to fill-in given index.
      off_value: Tensor or float, value to fill-in everything else.

    Returns:
      Tensor of shape [N1, N2, num_classes] holding `on_value` at each id
      from the original tensor and `off_value` everywhere else.
    """
    # Indices are cast to int64 before being one-hot encoded.
    indices = math_ops.cast(tensor_in, dtypes.int64)
    return array_ops_.one_hot(indices, num_classes, on_value, off_value)
Example #15
Source File: attention_wrapper.py From OpenSeq2Seq with Apache License 2.0 | 6 votes |
def hardmax(logits, name=None):
    """Return batched one-hot vectors marking the largest logit.

    Along the last axis, the position holding the `1` is the index of the
    maximum logit value.

    Args:
      logits: A batch tensor of logit values.
      name: Name to use when creating ops.

    Returns:
      A batched one-hot tensor with the same dtype as `logits`.
    """
    with ops.name_scope(name, "Hardmax", [logits]):
        logits = ops.convert_to_tensor(logits, name="logits")
        # Use the statically known size of the last axis when available,
        # otherwise read it from the dynamic shape at run time.
        static_depth = logits.get_shape()[-1].value
        depth = (static_depth if static_depth is not None
                 else array_ops.shape(logits)[-1])
        winners = math_ops.argmax(logits, -1)
        return array_ops.one_hot(winners, depth, dtype=logits.dtype)
Example #16
Source File: attention_wrapper.py From OpenSeq2Seq with Apache License 2.0 | 6 votes |
def initial_alignments(self, batch_size, dtype):
    """Create the initial alignments for monotonic attention.

    Every batch entry starts as a Dirac distribution on the first memory
    position, i.e. [1, 0, 0, ...memory length..., 0].

    Args:
      batch_size: `int32` scalar, the batch_size.
      dtype: The `dtype`.

    Returns:
      A `dtype` tensor shaped `[batch_size, alignments_size]`
      (`alignments_size` is the values' `max_time`).
    """
    depth = self._alignments_size
    # Index 0 for every row puts the single 1 in the first column.
    first_positions = array_ops.zeros((batch_size,), dtype=dtypes.int32)
    return array_ops.one_hot(first_positions, depth, dtype=dtype)
Example #17
Source File: attention_wrapper.py From QGforQA with MIT License | 6 votes |
def hardmax(logits, name=None):
    """Return batched one-hot vectors marking the largest logit.

    Along the last axis, the position holding the `1` is the index of the
    maximum logit value.

    Args:
      logits: A batch tensor of logit values.
      name: Name to use when creating ops.

    Returns:
      A batched one-hot tensor with the same dtype as `logits`.
    """
    with ops.name_scope(name, "Hardmax", [logits]):
        logits = ops.convert_to_tensor(logits, name="logits")
        # Use the statically known size of the last axis when available,
        # otherwise read it from the dynamic shape at run time.
        static_depth = logits.get_shape()[-1].value
        depth = (static_depth if static_depth is not None
                 else array_ops.shape(logits)[-1])
        winners = math_ops.argmax(logits, -1)
        return array_ops.one_hot(winners, depth, dtype=logits.dtype)
Example #18
Source File: rf3.py From deep-learning with MIT License | 6 votes |
def _get_eval_ops(self, features, targets, metrics):
    """Build evaluation metric ops.

    Args:
      features: Features, parsed with `data_ops.ParseDataTensorOrDict`.
      targets: Labels, parsed with `data_ops.ParseLabelTensorOrDict`.
      metrics: Optional dict mapping names to metric functions, each called
        as `metric(probabilities, labels)`. When `None`, only
        `self.accuracy_metric` is computed.

    Returns:
      Dict mapping each metric name to the value its metric function returns.
    """
    features, spec = data_ops.ParseDataTensorOrDict(features)
    labels = data_ops.ParseLabelTensorOrDict(targets)

    graph_builder = self.graph_builder_class(
        self.params, device_assigner=self.device_assigner, training=False,
        **self.construction_args)

    probabilities = graph_builder.inference_graph(features, data_spec=spec)

    # For classification, one-hot the labels (on=1, off=0) as int64.
    # NOTE(review): squeeze has no axis, so a batch of one label collapses
    # to a scalar and the one-hot loses its batch dim -- confirm intended.
    if not self.params.regression:
        class_ids = math_ops.to_int64(array_ops.squeeze(labels))
        labels = math_ops.to_int64(
            array_ops.one_hot(class_ids, self.params.num_classes, 1, 0))

    if metrics is None:
        metrics = {self.accuracy_metric:
                   eval_metrics.get_metric(self.accuracy_metric)}

    return {metric_name: metric_fn(probabilities, labels)
            for metric_name, metric_fn in six.iteritems(metrics)}
Example #19
Source File: attention_wrapper.py From tf-var-attention with MIT License | 6 votes |
def initial_alignments(self, batch_size, dtype):
    """Create the initial alignments for monotonic attention.

    Every batch entry starts as a Dirac distribution on the first memory
    position, i.e. [1, 0, 0, ...memory length..., 0].

    Args:
      batch_size: `int32` scalar, the batch_size.
      dtype: The `dtype`.

    Returns:
      A `dtype` tensor shaped `[batch_size, alignments_size]`
      (`alignments_size` is the values' `max_time`).
    """
    depth = self._alignments_size
    # Index 0 for every row puts the single 1 in the first column.
    first_positions = array_ops.zeros((batch_size,), dtype=dtypes.int32)
    return array_ops.one_hot(first_positions, depth, dtype=dtype)
Example #20
Source File: attention_wrapper.py From tf-var-attention with MIT License | 6 votes |
def hardmax(logits, name=None):
    """Return batched one-hot vectors marking the largest logit.

    Along the last axis, the position holding the `1` is the index of the
    maximum logit value.

    Args:
      logits: A batch tensor of logit values.
      name: Name to use when creating ops.

    Returns:
      A batched one-hot tensor with the same dtype as `logits`.
    """
    with ops.name_scope(name, "Hardmax", [logits]):
        logits = ops.convert_to_tensor(logits, name="logits")
        # Use the statically known size of the last axis when available,
        # otherwise read it from the dynamic shape at run time.
        static_depth = logits.get_shape()[-1].value
        depth = (static_depth if static_depth is not None
                 else array_ops.shape(logits)[-1])
        winners = math_ops.argmax(logits, -1)
        return array_ops.one_hot(winners, depth, dtype=logits.dtype)
Example #21
Source File: multinomial.py From lambda-packs with MIT License | 6 votes |
def _sample_n(self, n, seed=None):
    """Draw `n` samples of per-class counts.

    Args:
      n: Number of samples to draw.
      seed: Optional seed forwarded to `random_ops.multinomial`.

    Returns:
      Tensor of shape `[n] + batch_shape + [k]` giving per-class counts over
      `self.total_count` draws, cast to `self.dtype`.

    Raises:
      NotImplementedError: if `self.total_count` is statically non-scalar.
    """
    n_draws = math_ops.cast(self.total_count, dtype=dtypes.int32)
    if self.total_count.get_shape().ndims is not None:
        if self.total_count.get_shape().ndims != 0:
            raise NotImplementedError(
                "Sample only supported for scalar number of draws.")
    elif self.validate_args:
        # Rank is unknown statically; check it when the graph runs instead.
        is_scalar = check_ops.assert_rank(
            n_draws, 0,
            message="Sample only supported for scalar number of draws.")
        n_draws = control_flow_ops.with_dependencies([is_scalar], n_draws)
    k = self.event_shape_tensor()[0]
    # Flatten batch dims so logits has shape [B, k],
    # where B = reduce_prod(self.batch_shape_tensor()).
    draws = random_ops.multinomial(
        logits=array_ops.reshape(self.logits, [-1, k]),
        num_samples=n * n_draws,
        seed=seed)
    draws = array_ops.reshape(draws, shape=[-1, n, n_draws])
    # Summing one-hot vectors over the draw axis yields per-class counts.
    x = math_ops.reduce_sum(array_ops.one_hot(draws, depth=k),
                            axis=-2)  # shape: [B, n, k]
    x = array_ops.transpose(x, perm=[1, 0, 2])
    final_shape = array_ops.concat([[n], self.batch_shape_tensor(), [k]], 0)
    x = array_ops.reshape(x, final_shape)
    # one_hot defaults to float32; cast so samples carry the distribution's
    # dtype (e.g. float64) rather than always float32.
    return math_ops.cast(x, self.dtype)
Example #22
Source File: devel.py From avsr-tf1 with GNU General Public License v3.0 | 6 votes |
def focal_loss(labels, logits, gamma=2.0):
    r"""Multi-class focal loss, after https://arxiv.org/abs/1708.02002

    Each class is scored as an independent binary term on the softmax
    probability p (clipped to [1e-7, 1 - 1e-7]): the true class contributes
    -(1 - p)^gamma * log(p), every other class contributes
    -p^gamma * log(1 - p). The terms are summed over classes.

    :param labels: [batch_size, ] - Tensor of the correct class ids
    :param logits: [batch_size, num_classes] - Unscaled logits
    :param gamma: focusing exponent applied to the probability factors
    :return: [batch_size, ] - Tensor of per-example costs, summed (not
        averaged) over classes
    """
    num_classes = array_ops.shape(logits)[1]
    onehot_labels = array_ops.one_hot(labels, num_classes, dtype=logits.dtype)
    p = nn_ops.softmax(logits)
    # Keep p strictly inside (0, 1) so neither log(p) nor log(1 - p) is -inf.
    p = clip_ops.clip_by_value(p, 1e-7, 1.0 - 1e-7)
    f_loss = (- onehot_labels * math_ops.pow(1.0 - p, gamma) * math_ops.log(p)
              - (1 - onehot_labels) * math_ops.pow(p, gamma)
              * math_ops.log(1.0 - p))
    cost = math_ops.reduce_sum(f_loss, axis=1)
    return cost
Example #23
Source File: devel.py From avsr-tf1 with GNU General Public License v3.0 | 6 votes |
def mc_loss(labels, logits):
    r"""Multi-class loss built from per-class binary cross-entropy terms.

    On the softmax probability p (clipped to [1e-7, 1 - 1e-7]) the true class
    contributes -log(p) and every other class contributes -log(1 - p); the
    terms are summed over classes.

    :param labels: [batch_size, ] - Tensor of the correct class ids
    :param logits: [batch_size, num_classes] - Unscaled logits
    :return: [batch_size, ] - Tensor of per-example costs, summed (not
        averaged) over classes
    """
    num_classes = array_ops.shape(logits)[1]
    onehot_labels = array_ops.one_hot(labels, num_classes, dtype=logits.dtype)
    p = nn_ops.softmax(logits)
    # Keep p strictly inside (0, 1) so neither log(p) nor log(1 - p) is -inf.
    p = clip_ops.clip_by_value(p, 1e-7, 1.0 - 1e-7)
    ce_loss = (- onehot_labels * math_ops.log(p)
               - (1 - onehot_labels) * math_ops.log(1.0 - p))
    cost = math_ops.reduce_sum(ce_loss, axis=1)
    return cost
Example #24
Source File: attention_wrapper.py From QGforQA with MIT License | 6 votes |
def initial_alignments(self, batch_size, dtype):
    """Create the initial alignments for monotonic attention.

    Every batch entry starts as a Dirac distribution on the first memory
    position, i.e. [1, 0, 0, ...memory length..., 0].

    Args:
      batch_size: `int32` scalar, the batch_size.
      dtype: The `dtype`.

    Returns:
      A `dtype` tensor shaped `[batch_size, alignments_size]`
      (`alignments_size` is the values' `max_time`).
    """
    depth = self._alignments_size
    # Index 0 for every row puts the single 1 in the first column.
    first_positions = array_ops.zeros((batch_size,), dtype=dtypes.int32)
    return array_ops.one_hot(first_positions, depth, dtype=dtype)
Example #25
Source File: attention_wrapper_mod.py From NQG_ASs2s with MIT License | 6 votes |
def initial_alignments(self, batch_size, dtype):
    """Create the initial alignments for monotonic attention.

    Every batch entry starts as a Dirac distribution on the first memory
    position, i.e. [1, 0, 0, ...memory length..., 0].

    Args:
      batch_size: `int32` scalar, the batch_size.
      dtype: The `dtype`.

    Returns:
      A `dtype` tensor shaped `[batch_size, alignments_size]`
      (`alignments_size` is the values' `max_time`).
    """
    depth = self._alignments_size
    # Index 0 for every row puts the single 1 in the first column.
    first_positions = array_ops.zeros((batch_size,), dtype=dtypes.int32)
    return array_ops.one_hot(first_positions, depth, dtype=dtype)
Example #26
Source File: dirichlet_multinomial.py From lambda-packs with MIT License | 6 votes |
def _sample_n(self, n, seed=None):
    """Draw `n` Dirichlet-multinomial samples of per-class counts.

    Args:
      n: Number of samples to draw.
      seed: Optional seed; a derived seed is used for the multinomial step.

    Returns:
      Tensor of shape `[n] + batch_shape + [k]`, cast to `self.dtype`.
    """
    n_draws = math_ops.cast(self.total_count, dtype=dtypes.int32)
    k = self.event_shape_tensor()[0]
    # log of Gamma(concentration) draws used as unnormalized logits: their
    # softmax is the normalized Gamma vector, i.e. a Dirichlet sample.
    unnormalized_logits = array_ops.reshape(
        math_ops.log(random_ops.random_gamma(
            shape=[n],
            alpha=self.concentration,
            dtype=self.dtype,
            seed=seed)),
        shape=[-1, k])
    draws = random_ops.multinomial(
        logits=unnormalized_logits,
        num_samples=n_draws,
        seed=distribution_util.gen_new_seed(seed, salt="dirichlet_multinomial"))
    # Summing one-hot vectors over the draw axis yields per-class counts.
    x = math_ops.reduce_sum(array_ops.one_hot(draws, depth=k), -2)
    final_shape = array_ops.concat([[n], self.batch_shape_tensor(), [k]], 0)
    x = array_ops.reshape(x, final_shape)
    # one_hot defaults to float32; cast so samples carry the distribution's
    # dtype (e.g. float64) rather than always float32.
    return math_ops.cast(x, self.dtype)
Example #27
Source File: onehot_categorical.py From keras-lambda with MIT License | 5 votes |
def _mode(self):
    """Return the mode as a one-hot tensor.

    Takes the argmax of `self.logits` along axis `self._batch_rank`
    (presumably the class axis that follows the batch dims -- confirm) and
    one-hot encodes it with `self.num_classes` classes in `self.dtype`. The
    result gets the static shape of `self.logits`.
    """
    most_likely = math_ops.argmax(self.logits, axis=self._batch_rank)
    mode = array_ops.one_hot(most_likely, self.num_classes, dtype=self.dtype)
    mode.set_shape(self.logits.get_shape())
    return mode
Example #28
Source File: onehot_categorical.py From keras-lambda with MIT License | 5 votes |
def _sample_n(self, n, seed=None):
    """Draw `n` one-hot samples.

    Args:
      n: Number of samples to draw.
      seed: Optional seed forwarded to `random_ops.multinomial`.

    Returns:
      Tensor of shape `[n] + shape(self.logits)` in `self.dtype`.
    """
    sample_shape = array_ops.concat(([n], array_ops.shape(self.logits)), 0)
    logits = self.logits
    # multinomial takes rank-2 logits, so flatten any extra batch dims.
    logits_2d = (logits if logits.get_shape().ndims == 2
                 else array_ops.reshape(logits, [-1, self.num_classes]))
    # multinomial yields [B, n]; transpose so the sample axis leads.
    draws = array_ops.transpose(
        random_ops.multinomial(logits_2d, n, seed=seed))
    one_hots = array_ops.one_hot(draws, self.num_classes, dtype=self.dtype)
    return array_ops.reshape(one_hots, sample_shape)
Example #29
Source File: von_mises_fisher.py From s-vae-tf with MIT License | 5 votes |
def __init__(self, loc, scale, validate_args=False, allow_nan_stats=True, name="von-Mises-Fisher"):
    """Construct von-Mises-Fisher distributions with mean and concentration `loc` and `scale`.

    Args:
      loc: Floating point tensor; the mean of the distribution(s). When
        `validate_args` is `True`, its norm along the last axis is asserted
        to be within 1e-7 of 1.
      scale: Floating point tensor; the concentration of the distribution(s).
        Must contain only positive values (checked with `assert_positive`
        when `validate_args` is `True`).
      validate_args: Python `bool`, default `False`. When `True` distribution
        parameters are checked for validity despite possibly degrading runtime
        performance. When `False` invalid inputs may silently render incorrect
        outputs.
      allow_nan_stats: Python `bool`, default `True`. When `True`, statistics
        (e.g., mean, mode, variance) use the value "`NaN`" to indicate the
        result is undefined. When `False`, an exception is raised if one or
        more of the statistic's batch members are undefined.
      name: Python `str` name prefixed to Ops created by this class.

    Raises:
      TypeError: if `loc` and `scale` have different `dtype`.
        NOTE(review): `check_ops.assert_same_float_dtype` is documented to
        raise `ValueError` on a dtype mismatch -- confirm which type callers
        should expect.
    """
    # Captured first so the dict holds only `self` and the constructor
    # arguments; it is passed to the base class as `parameters`.
    parameters = locals()
    with ops.name_scope(name, values=[loc, scale]):
        # The validity assertions become control dependencies of the
        # identity ops only when validate_args is set.
        with ops.control_dependencies([check_ops.assert_positive(scale),
                                       check_ops.assert_near(linalg_ops.norm(loc, axis=-1), 1, atol=1e-7)]
                                      if validate_args else []):
            self._loc = array_ops.identity(loc, name="loc")
            self._scale = array_ops.identity(scale, name="scale")
            check_ops.assert_same_float_dtype([self._loc, self._scale])
    super(VonMisesFisher, self).__init__(
        dtype=self._scale.dtype,
        reparameterization_type=distribution.FULLY_REPARAMETERIZED,
        validate_args=validate_args,
        allow_nan_stats=allow_nan_stats,
        parameters=parameters,
        graph_parents=[self._loc, self._scale],
        name=name)
    # Size of the last axis of `loc`, as an int32 tensor.
    self.__m = math_ops.cast(self._loc.shape[-1], dtypes.int32)
    # The same size in the distribution's float dtype.
    self.__mf = math_ops.cast(self.__m, dtype=self.dtype)
    # one_hot([0], m): a [1, m] row vector [1, 0, ..., 0].
    self.__e1 = array_ops.one_hot([0], self.__m, dtype=self.dtype)
Example #30
Source File: sampling_ops.py From keras-lambda with MIT License | 5 votes |
def _estimate_data_distribution(labels, num_classes, smoothing_constant=10): """Estimate data distribution as labels are seen.""" # Variable to track running count of classes. Smooth by a nonzero value to # avoid division-by-zero. Higher values provide more stability at the cost of # slower convergence. if smoothing_constant <= 0: raise ValueError('smoothing_constant must be nonzero.') num_examples_per_class_seen = variables.Variable( initial_value=[smoothing_constant] * num_classes, trainable=False, name='class_count', dtype=dtypes.int64) # Update the class-count based on what labels are seen in batch. num_examples_per_class_seen = num_examples_per_class_seen.assign_add( math_ops.reduce_sum( array_ops.one_hot( labels, num_classes, dtype=dtypes.int64), 0)) # Normalize count into a probability. # NOTE: Without the `+= 0` line below, the test # `testMultiThreadedEstimateDataDistribution` fails. The reason is that # before this line, `num_examples_per_class_seen` is a Tensor that shares a # buffer with an underlying `ref` object. When the `ref` is changed by another # thread, `num_examples_per_class_seen` changes as well. Since this can happen # in the middle of the normalization computation, we get probabilities that # are very far from summing to one. Adding `+= 0` copies the contents of the # tensor to a new buffer, which will be consistent from the start to the end # of the normalization computation. num_examples_per_class_seen += 0 init_prob_estimate = math_ops.truediv( num_examples_per_class_seen, math_ops.reduce_sum(num_examples_per_class_seen)) # Must return float32 (not float64) to agree with downstream `_verify_input` # checks. return math_ops.cast(init_prob_estimate, dtypes.float32)