Python tensorflow.compat.v2.cast() Examples
The following are 30 code examples of tensorflow.compat.v2.cast().
You can vote up the ones you like or vote down the ones you don't like,
and go to the original project or source file by following the links above each example.
You may also want to check out all available functions/classes of the module tensorflow.compat.v2, or try the search function.
Example #1
Source File: euler_sampling.py From tf-quant-finance with Apache License 2.0 | 6 votes |
def _euler_step(*, i, written_count, current_state, result, drift_fn,
                volatility_fn, wiener_mean, num_samples, times, dt, sqrt_dt,
                keep_mask, random_type, seed, normal_draws):
  """Advances the Euler scheme by a single time step.

  Args:
    i: Index of the current step; the new state is evaluated at
      `times[i + 1]`.
    written_count: Number of states already written to `result`. It is cast
      to `tf.int32` and used as the write position along axis 1.
    current_state: State at the current step.
    result: Accumulator into which `next_state` is written along axis 1 when
      `keep_mask[i + 1]` is set.
    drift_fn: Callable `(t, state)` whose output is scaled by `dt[i]`.
    volatility_fn: Callable `(t, state)` whose output is applied to the
      Wiener increment via `tf.linalg.matvec`.
    wiener_mean: `mean` passed to `random.mv_normal_sample`.
    num_samples: Sample count used when drawing fresh normals.
    times: Time grid indexed by step.
    dt: Time increments; `dt[i]` scales the drift.
    sqrt_dt: Per-step scale for the normal draw (presumably `sqrt(dt)`).
    keep_mask: Per-step flags; `keep_mask[i + 1]` decides whether the new
      state is recorded in `result`.
    random_type: Passed through to `random.mv_normal_sample`.
    seed: Passed through to `random.mv_normal_sample`.
    normal_draws: Optional pre-generated draws; when not None,
      `normal_draws[i]` is used instead of sampling.

  Returns:
    `(i + 1, written_count, next_state, result)`, where `written_count` is
    incremented by one if the new state was recorded.
  """
  next_idx = i + 1
  t_next = times[next_idx]
  count = tf.cast(written_count, tf.int32)
  if normal_draws is None:
    brownian = random.mv_normal_sample(
        (num_samples,), mean=wiener_mean, random_type=random_type, seed=seed)
  else:
    brownian = normal_draws[i]
  brownian = brownian * sqrt_dt[i]
  drift_term = dt[i] * drift_fn(t_next, current_state)  # pylint: disable=not-callable
  diffusion_term = tf.linalg.matvec(
      volatility_fn(t_next, current_state), brownian)  # pylint: disable=not-callable
  next_state = current_state + drift_term + diffusion_term
  record_flag = keep_mask[next_idx]
  result = utils.maybe_update_along_axis(
      tensor=result,
      do_update=record_flag,
      ind=count,
      axis=1,
      new_tensor=tf.expand_dims(next_state, axis=1))
  count += tf.cast(record_flag, dtype=tf.int32)
  return next_idx, count, next_state, result
Example #2
Source File: extensions.py From trax with Apache License 2.0 | 6 votes |
def _key2seed(a):
  """Converts an RNG key to an RNG seed.

  The 64-bit key is split into its low and high 32-bit words, which are
  returned (as int32, low word first) in a length-2 tensor.

  Args:
    a: an RNG key, an ndarray of shape [] and dtype `np.int64`.

  Returns:
    an RNG seed, a tensor of shape [2] and dtype `tf.int32`.
  """
  key = tf.cast(a.data, tf.uint64)
  # Truncating to uint32 keeps the low word; shifting by 32 first keeps the
  # high word.
  low_word = tf.cast(key, tf.uint32)
  high_word = tf.cast(
      tf.bitwise.right_shift(key, tf.constant(32, tf.uint64)), tf.uint32)
  return tf.stack([tf.cast(low_word, tf.int32), tf.cast(high_word, tf.int32)])
Example #3
Source File: extensions.py From trax with Apache License 2.0 | 6 votes |
def _seed2key(a):
  """Converts an RNG seed to an RNG key.

  This is the inverse of splitting an int64 key into two int32 words: `a[0]`
  is the low 32 bits and `a[1]` the high 32 bits of the resulting key.

  Args:
    a: an RNG seed, a tensor of shape [2] and dtype `tf.int32`.

  Returns:
    an RNG key, an ndarray of shape [] and dtype `np.int64`.
  """
  def int32s_to_int64(a):
    """Converts an int32 tensor of shape [2] to an int64 tensor of shape []."""
    # Reinterpret each word as uint32 before widening. Casting a negative
    # int32 straight to uint64 sign-extends it (e.g. -1 -> 0xFFFF...FFFF),
    # which would set all upper 32 bits and clobber the high word OR-ed in
    # below. Going through uint32 zero-extends instead.
    low = tf.cast(tf.cast(a[0], tf.uint32), tf.uint64)
    high = tf.cast(tf.cast(a[1], tf.uint32), tf.uint64)
    a = tf.bitwise.bitwise_or(
        low, tf.bitwise.left_shift(high, tf.constant(32, tf.uint64)))
    return tf.cast(a, tf.int64)
  return tf_np.asarray(int32s_to_int64(a))
Example #4
Source File: math_ops.py From trax with Apache License 2.0 | 6 votes |
def true_divide(x1, x2):
  """Divides `x1` by `x2` element-wise with `tf.math.truediv`.

  Boolean operands (both must be bool) are first cast to
  `dtypes.default_float_type()`. When `dtypes.is_allow_float64()` is False
  and both operands share an int32/int64 dtype, they are cast to float32 so
  `tf.math.truediv` does not produce float64.
  """
  def _divide(x1, x2):
    if x1.dtype == tf.bool:
      assert x2.dtype == tf.bool
      float_ = dtypes.default_float_type()
      x1 = tf.cast(x1, float_)
      x2 = tf.cast(x2, float_)
    # tf.math.truediv in Python3 produces float64 when both inputs are int32
    # or int64; avoid that unless float64 is allowed.
    if (not dtypes.is_allow_float64() and x1.dtype == x2.dtype and
        x1.dtype in (tf.int32, tf.int64)):
      x1 = tf.cast(x1, dtype=tf.float32)
      x2 = tf.cast(x2, dtype=tf.float32)
    return tf.math.truediv(x1, x2)

  return _bin_op(_divide, x1, x2)
Example #5
Source File: exporter_lib_v2.py From models with Apache License 2.0 | 6 votes |
def _run_inference_on_images(self, image):
  """Cast image to float and run inference.

  Runs the model's preprocess -> predict -> postprocess pipeline, shifts the
  detection classes by `label_id_offset` (1), and casts every detection
  output to float32.

  Args:
    image: uint8 Tensor of shape [1, None, None, 3]

  Returns:
    Tensor dictionary holding detections.
  """
  label_id_offset = 1
  model = self._model
  preprocessed, true_shapes = model.preprocess(tf.cast(image, tf.float32))
  raw_predictions = model.predict(preprocessed, true_shapes)
  detections = model.postprocess(raw_predictions, true_shapes)
  # Presumably converts 0-based model class ids to 1-based label-map ids.
  classes_key = fields.DetectionResultFields.detection_classes
  detections[classes_key] = (
      tf.cast(detections[classes_key], tf.float32) + label_id_offset)
  for output_name in list(detections.keys()):
    detections[output_name] = tf.cast(detections[output_name], tf.float32)
  return detections
Example #6
Source File: array_ops.py From trax with Apache License 2.0 | 6 votes |
def tri(N, M=None, k=0, dtype=None):  # pylint: disable=invalid-name
  """Returns an N x M array with ones at and below the k-th diagonal.

  Entry (i, j) is one when `j - i <= k` and zero otherwise.

  Args:
    N: Number of rows.
    M: Number of columns; defaults to `N`.
    k: Diagonal offset; 0 is the main diagonal, negative values move it below.
    dtype: Result dtype; defaults to `dtypes.default_float_type()`.

  Returns:
    An ndarray of shape [N, M].
  """
  if M is None:
    M = N
  dtype = (utils.result_type(dtype) if dtype is not None
           else dtypes.default_float_type())
  shape = [N, M]
  if k >= 0:
    ones = tf.ones(shape, dtype)
    # band_part(-1, k) keeps everything below the diagonal plus k superdiagonals.
    return utils.tensor_to_ndarray(
        ones if k > M else tf.linalg.band_part(ones, -1, k))
  lower = -k - 1
  if lower > N:
    return utils.tensor_to_ndarray(tf.zeros(shape, dtype))
  # Keep as tf bool: build the complementary upper band and invert it.
  upper = tf.linalg.band_part(tf.ones(shape, dtype=tf.bool), lower, -1)
  return utils.tensor_to_ndarray(tf.cast(tf.math.logical_not(upper), dtype))
Example #7
Source File: grad_utils.py From models with Apache License 2.0 | 6 votes |
def _filter_and_allreduce_gradients(grads_and_vars,
                                    allreduce_precision="float32"):
  """Filters out None gradients, then all-reduces the rest.

  Use this when you intend to all-reduce gradients explicitly and customize
  gradient operations before and after the all-reduce. Pass the returned
  gradients to `optimizer.apply_gradients(
  experimental_aggregate_gradients=False)`.

  Args:
    grads_and_vars: gradients and variables pairs.
    allreduce_precision: "float16" to cast gradients to float16 for the
      all-reduce and back to float32 afterwards; any other value reduces them
      unchanged.

  Returns:
    pairs of allreduced non-None gradients and variables.
  """
  use_half_precision = allreduce_precision == "float16"
  grads, variables = zip(*_filter_grads(grads_and_vars))
  if use_half_precision:
    grads = [tf.cast(g, "float16") for g in grads]
  replica_ctx = tf.distribute.get_replica_context()
  reduced = replica_ctx.all_reduce(tf.distribute.ReduceOp.SUM, grads)
  if use_half_precision:
    reduced = [tf.cast(g, "float32") for g in reduced]
  return reduced, variables
Example #8
Source File: array_ops.py From trax with Apache License 2.0 | 6 votes |
def around(a, decimals=0):
  """Rounds `a` to `decimals` decimal places, keeping `a`'s dtype.

  Computes `round(a * 10**decimals) / 10**decimals` with `tf.round`.
  Non-inexact (e.g. integer) inputs are processed in
  `dtypes.default_float_type()` and cast back at the end.

  Args:
    a: array_like input.
    decimals: Number of decimal places; may be negative.

  Returns:
    An ndarray with the same dtype as `asarray(a)`.
  """
  arr = asarray(a)
  original_dtype = arr.dtype
  scale = math.pow(10, decimals)
  if np.issubdtype(original_dtype, np.inexact):
    work_dtype = original_dtype
  else:
    # Use float as the working dtype for exact inputs, because `decimals`
    # can be negative.
    work_dtype = dtypes.default_float_type()
    arr = arr.astype(work_dtype).data
  scale = tf.cast(scale, work_dtype)
  rounded = tf.math.divide(tf.round(tf.multiply(arr, scale)), scale)
  return utils.tensor_to_ndarray(rounded).astype(original_dtype)
Example #9
Source File: math_ops.py From trax with Apache License 2.0 | 6 votes |
def _scalar(tf_fn, x, promote_to_float=False):
  """Applies `tf_fn` to the data of `x` and wraps the result as an ndarray.

  Args:
    tf_fn: function that takes a single Tensor argument.
    x: array_like. Could be an ndarray, a Tensor or any object that can be
      converted to a Tensor using `tf.convert_to_tensor`.
    promote_to_float: whether to cast the argument to
      `dtypes.default_float_type()` when its dtype is not inexact.

  Returns:
    An ndarray holding `tf_fn` applied to the (possibly promoted) data of `x`.
  """
  arr = array_ops.asarray(x)
  needs_float = promote_to_float and not np.issubdtype(arr.dtype, np.inexact)
  if needs_float:
    arr = arr.astype(dtypes.default_float_type())
  output = tf_fn(arr.data)
  return utils.tensor_to_ndarray(output)
Example #10
Source File: baseline_agent.py From valan with Apache License 2.0 | 6 votes |
def _torso(self, observation):
  """Encodes an observation into the neck input and a timestep embedding.

  The image features are cast to float32, passed through `_dense_img` then
  `_dense_img_extra`, and flattened. They are then concatenated on axis 1 with
  the heading (given a trailing axis) and the embedded previous action.

  Returns:
    Dict with 'neck_input' and the embedded `streetview_constants.TIMESTEP`.
  """
  keys = streetview_constants
  image = tf.cast(observation[keys.IMAGE_FEATURES], tf.float32)
  image_code = tf.keras.layers.Flatten()(
      self._dense_img_extra(self._dense_img(image)))
  heading_column = tf.expand_dims(observation[keys.HEADING], -1)
  action_code = self._action_embedder(observation[keys.PREV_ACTION_IDX])
  neck_input = tf.concat([heading_column, action_code, image_code], axis=1)
  return {
      'neck_input': neck_input,
      keys.TIMESTEP: self._timestep_embedder(observation[keys.TIMESTEP]),
  }
Example #11
Source File: mt_agent.py From valan with Apache License 2.0 | 6 votes |
def _neck(self, torso_outputs, state):
  """Runs the image LSTM and text attention, updating `torso_outputs`.

  Args:
    torso_outputs: Dict containing `constants.PANO_ENC`; it is mutated in
      place with 'hidden_state', 'c_text' and 'ins_classifier_logits'.
    state: `(lstm_state, text_enc_outputs, ins_classifier_logits)`.

  Returns:
    `(torso_outputs, (lstm_state_tuples, text_enc_outputs,
    ins_classifier_logits))`.
  """
  lstm_state, text_enc_outputs, ins_classifier_logits = state
  pano_features = tf.cast(torso_outputs[constants.PANO_ENC], tf.float32)
  hidden, new_lstm_state = self._image_encoder(pano_features, lstm_state)
  hidden = tf.expand_dims(hidden, axis=1)
  # c_text has shape [batch_size, 1, self._text_attention_size]
  projected_hidden = self._text_attention_project_hidden(hidden)
  projected_text = self._text_attention_project_text(text_enc_outputs)
  c_text = self._text_attention([projected_hidden, projected_text])
  # The new LSTM states are ListWrappers; convert them to tuples so they match
  # get_initial_state.
  tuple_state = [(s[0], s[1]) for s in new_lstm_state]
  torso_outputs['hidden_state'] = hidden
  torso_outputs['c_text'] = c_text
  torso_outputs['ins_classifier_logits'] = ins_classifier_logits
  return torso_outputs, (tuple_state, text_enc_outputs, ins_classifier_logits)
Example #12
Source File: metrics.py From ranking with Apache License 2.0 | 6 votes |
def update_state(self, y_true, y_pred, sample_weight=None):
  """Accumulates metric statistics.

  Both inputs are cast to `self._dtype`, reduced to per-list values and
  weights by `self._metric.compute`, and forwarded to the parent
  `update_state`.

  Args:
    y_true: The ground truth values; same shape as `y_pred`.
    y_pred: The predicted values.
    sample_weight: Optional weighting of each example. Defaults to 1. Can be a
      `Tensor` whose rank is either 0, or the same rank as `y_true`, and must
      be broadcastable to `y_true`.

  Returns:
    Update op.
  """
  labels = tf.cast(y_true, self._dtype)
  scores = tf.cast(y_pred, self._dtype)
  values, weights = self._metric.compute(labels, scores, sample_weight)
  parent = super(_RankingMetric, self)
  return parent.update_state(values, sample_weight=weights)
Example #13
Source File: crr_binomial_tree.py From tf-quant-finance with Apache License 2.0 | 6 votes |
def _get_payoff_fn(strikes, is_call_options):
  """Constructs the payoff function for the given strikes and option types."""
  # +1 where is_call_options is true (calls), -1 otherwise (puts).
  direction = tf.cast(is_call_options, dtype=strikes.dtype) * 2 - 1

  def _payoff(spots):
    """Computes payoffs for the specified options given the spot grid.

    Args:
      spots: Tensor of shape [batch_size, grid_size, 1]. The spot values at
        some time.

    Returns:
      `relu((spots - strikes) * direction)`: the exercise payoff at the
      specified strikes.
    """
    return tf.nn.relu((spots - strikes) * direction)

  return _payoff
Example #14
Source File: inner_reshape.py From agents with Apache License 2.0 | 6 votes |
def _reshape_inner_dims(
    tensor: tf.Tensor,
    shape: tf.TensorShape,
    new_shape: tf.TensorShape) -> tf.Tensor:
  """Reshapes tensor to: shape(tensor)[:-len(shape)] + new_shape.

  The trailing `shape.rank` dimensions of `tensor` are replaced by
  `new_shape`; the leading (outer) dimensions are kept. `None` entries in
  `new_shape` become -1 in the dynamic reshape.

  Args:
    tensor: Tensor to reshape.
    shape: Static shape of the trailing dimensions being replaced. May be
      rank 0, in which case `new_shape` is appended to the full shape of
      `tensor`.
    new_shape: Shape that replaces those trailing dimensions.

  Returns:
    The reshaped tensor, with its static shape set to
    `tensor.shape[:-len(shape)] + new_shape`.

  Raises:
    ValueError: If the trailing dimensions of `tensor` are statically
      incompatible with `shape`.
  """
  tensor_shape = tf.shape(tensor)
  ndims = shape.rank
  if ndims == 0:
    # `x[-0:]` is the whole of `x` and `x[:-0]` is empty, so the generic
    # slicing below would check the full shape against `[]` and drop every
    # outer dimension. With no inner dimensions, all dimensions are outer.
    inner_static_shape = tf.TensorShape([])
    outer_static_shape = tensor.shape
    outer_dynamic_shape = tensor_shape
  else:
    inner_static_shape = tensor.shape[-ndims:]
    outer_static_shape = tensor.shape[:-ndims]
    outer_dynamic_shape = tensor_shape[:-ndims]
  inner_static_shape.assert_is_compatible_with(shape)
  new_shape_inner_tensor = tf.cast(
      [-1 if d is None else d for d in new_shape.as_list()], tf.int64)
  new_shape_outer_tensor = tf.cast(outer_dynamic_shape, tf.int64)
  full_new_shape = tf.concat(
      (new_shape_outer_tensor, new_shape_inner_tensor), axis=0)
  new_tensor = tf.reshape(tensor, full_new_shape)
  new_tensor.set_shape(outer_static_shape + new_shape)
  return new_tensor
Example #15
Source File: custom_loops_test.py From tf-quant-finance with Apache License 2.0 | 6 votes |
def test_multiple_state_vars(self):
  """Checks `for_loop` gradients w.r.t. watched parameters over three states."""
  initial_state = [
      tf.constant([3.0, 4.0]),
      tf.constant([5.0, 6.0]),
      tf.constant([7.0, 8.0]),
  ]
  alpha = tf.constant(2.0)
  beta = tf.constant(1.0)
  with tf.GradientTape(persistent=True) as tape:
    tape.watch([alpha, beta])

    def step(i, state):
      a, b, c = state
      multiplier = tf.cast(i + 1, tf.float32)
      return [a * alpha - beta, b * multiplier * alpha * beta, c * beta + a]

    out = for_loop(step, initial_state, [alpha, beta], 3)
  expectations = (
      ("independent_vars", 1, alpha, 792),
      ("dependent_vars", 2, beta, 63),
  )
  for label, out_idx, source, expected in expectations:
    with self.subTest(label):
      self.assertAllEqual(expected, tape.gradient(out[out_idx], source))
Example #16
Source File: custom_loops_test.py From tf-quant-finance with Apache License 2.0 | 6 votes |
def test_batching(self):
  """Checks `for_loop` gradients when each state carries a batch dimension."""
  initial_state = [
      tf.constant([[3.0, 4.0], [30.0, 40.0]]),
      tf.constant([[5.0, 6.0], [50.0, 60.0]]),
      tf.constant([[7.0, 8.0], [70.0, 80.0]]),
  ]
  alpha = tf.constant(2.0)
  beta = tf.constant(1.0)
  with tf.GradientTape(persistent=True) as tape:
    tape.watch([alpha, beta])

    def step(i, state):
      a, b, c = state
      multiplier = tf.cast(i + 1, tf.float32)
      return [a * alpha - beta, b * multiplier * alpha * beta, c * beta + a]

    out = for_loop(step, initial_state, [alpha, beta], 3)
  expectations = (
      ("independent_vars", 1, alpha, 8712),
      ("dependent_vars", 2, beta, 783),
  )
  for label, out_idx, source, expected in expectations:
    with self.subTest(label):
      self.assertAllEqual(expected, tape.gradient(out[out_idx], source))
Example #17
Source File: util.py From language with Apache License 2.0 | 6 votes |
def labels_of_top_ranked_predictions_in_batch(labels, predictions):
  """Applying tf.metrics.mean to this gives precision at 1.

  Args:
    labels: minibatch of dense 0/1 labels, shape [batch_size rows, num_classes]
    predictions: minibatch of predictions of the same shape

  Returns:
    one-dimension tensor top_labels, where top_labels[i]=1.0 iff the
    top-scoring prediction for batch element i has label 1.0
  """
  top_columns = tf.cast(tf.argmax(input=predictions, axis=1), tf.int32)
  num_rows = tf.reduce_sum(input_tensor=tf.ones_like(top_columns))
  rows = tf.range(num_rows)
  # Any positive label counts as 1.0, everything else as 0.0.
  binary_labels = tf.where(labels > 0.0, tf.ones_like(labels),
                           tf.zeros_like(labels))
  # One (row, top_column) coordinate pair per batch element.
  coords = tf.stack([rows, top_columns], axis=1)
  return tf.gather_nd(binary_labels, coords)
Example #18
Source File: imagenet2012_corrupted.py From datasets with Apache License 2.0 | 6 votes |
def _decode_and_center_crop(image_bytes):
  """Crops to center of image with padding then scales image size.

  The square crop side is `_IMAGE_SIZE / (_IMAGE_SIZE + _CROP_PADDING)` times
  the shorter image side, truncated to int32. The crop is decoded with 3
  channels, bicubically resized to `[_IMAGE_SIZE, _IMAGE_SIZE]`, and cast to
  int32.
  """
  jpeg_shape = tf.image.extract_jpeg_shape(image_bytes)
  height, width = jpeg_shape[0], jpeg_shape[1]
  crop_fraction = _IMAGE_SIZE / (_IMAGE_SIZE + _CROP_PADDING)
  shorter_side = tf.cast(tf.minimum(height, width), tf.float32)
  crop_size = tf.cast(crop_fraction * shorter_side, tf.int32)
  top = ((height - crop_size) + 1) // 2
  left = ((width - crop_size) + 1) // 2
  window = tf.stack([top, left, crop_size, crop_size])
  cropped = tf.image.decode_and_crop_jpeg(image_bytes, window, channels=3)
  resized = tf.image.resize(
      [cropped], [_IMAGE_SIZE, _IMAGE_SIZE],
      method=tf.image.ResizeMethod.BICUBIC)[0]
  return tf.cast(resized, tf.int32)
Example #19
Source File: unbounded_holiday_calendar.py From tf-quant-finance with Apache License 2.0 | 6 votes |
def __init__(self, weekend_mask=None, holidays=None):
  """Initializer.

  Args:
    weekend_mask: Boolean `Tensor` of 7 elements one for each day of the week
      starting with Monday at index 0. A `True` value indicates the day is
      considered a weekend day and a `False` value implies a week day.
      Default value: None which means no weekends are applied.
    holidays: Defines the holidays that are added to the weekends defined by
      `weekend_mask`. An instance of `dates.DateTensor` or an object
      convertible to `DateTensor`.
      Default value: None which means no holidays other than those implied by
      the weekends (if any).
  """
  mask = None if weekend_mask is None else tf.cast(weekend_mask, dtype=tf.bool)
  holiday_ordinals = (
      None if holidays is None
      else dt.convert_to_date_tensor(holidays).ordinal())
  mappers = hol.business_day_mappers(
      weekend_mask=mask, holidays=holiday_ordinals)
  self._to_biz_space, self._from_biz_space = mappers
Example #20
Source File: agent.py From valan with Apache License 2.0 | 6 votes |
def _neck(self, torso_outputs, state):
  """Runs the image LSTM and text attention, updating `torso_outputs`.

  Args:
    torso_outputs: Dict containing `constants.PANO_ENC`; it is mutated in
      place with 'hidden_state' and 'c_text'.
    state: `(lstm_state, text_enc_outputs)`.

  Returns:
    `(torso_outputs, (lstm_state_tuples, text_enc_outputs))`.
  """
  lstm_state, text_enc_outputs = state
  pano_features = tf.cast(torso_outputs[constants.PANO_ENC], tf.float32)
  hidden, new_lstm_state = self._image_encoder(pano_features, lstm_state)
  hidden = tf.expand_dims(hidden, axis=1)
  # c_text has shape [batch_size, 1, self._text_attention_size]
  projected_hidden = self._text_attention_project_hidden(hidden)
  projected_text = self._text_attention_project_text(text_enc_outputs)
  c_text = self._text_attention([projected_hidden, projected_text])
  # The new LSTM states are ListWrappers; convert them to tuples so they match
  # get_initial_state.
  tuple_state = [(s[0], s[1]) for s in new_lstm_state]
  torso_outputs['hidden_state'] = hidden
  torso_outputs['c_text'] = c_text
  return torso_outputs, (tuple_state, text_enc_outputs)
Example #21
Source File: lsm_v2.py From tf-quant-finance with Apache License 2.0 | 6 votes |
def _updated_cashflow(num_times, exercise_index, exercise_value,
                      expected_continuation, cashflow):
  """Revises the cashflow tensor where options will be exercised earlier.

  Where `exercise_value > expected_continuation`, the exercise value is placed
  at time step `exercise_index - 1`. The existing cashflow from
  `exercise_index` onwards is kept only for samples that are not exercised
  now.

  Returns:
    The revised cashflow tensor.
  """
  exercise_now = exercise_value > expected_continuation
  exercise_now_float = tf.cast(exercise_now, exercise_value.dtype)
  time_steps = tf.range(0, num_times)
  zeros = tf.zeros_like(cashflow)

  # Newly exercised payoffs, zero elsewhere. Shape [num_samples, payoff_dim],
  # then [num_samples, payoff_dim, 1] after expand_dims.
  exercised_payoff = tf.expand_dims(
      tf.where(exercise_now, exercise_value, tf.zeros_like(exercise_value)),
      axis=2)
  # Nonzero only at the current time step.
  new_cash = tf.where(tf.equal(time_steps, exercise_index - 1),
                      exercised_payoff, zeros)

  # Has shape [num_samples, payoff_dim, 1]
  keep_old = tf.expand_dims(1 - exercise_now_float, axis=2)
  keep_old = tf.where(time_steps >= exercise_index, keep_old, zeros)
  # Shape [num_samples, payoff_dim, num_times]
  old_cash = keep_old * cashflow
  return new_cash + old_cash
Example #22
Source File: model.py From trax with Apache License 2.0 | 6 votes |
def __init__(self, hidden_layers, input_size=784, num_classes=10):
  """Initializes the neural network.

  For every consecutive pair of layer sizes (fan_in, fan_out), a random
  weight matrix of shape (fan_out, fan_in) and a bias vector of length fan_out
  are drawn with `np.random.randn`, weight first.

  Args:
    hidden_layers: List of ints specifying the sizes of hidden layers. Could
      be empty.
    input_size: Length of the input array. The network receives the input
      image as a flattened 1-d array. Defaults to 784(28*28), the default image
      size for MNIST.
    num_classes: The number of output classes. Defaults to 10.
  """
  sizes = [input_size] + hidden_layers + [num_classes]
  self.weights = []
  self.biases = []
  for fan_in, fan_out in zip(sizes[:-1], sizes[1:]):
    # TODO(srbs): This is manually cast to float32 to avoid the cast in
    # np.dot since backprop fails for tf.cast op.
    weight = np.random.randn(fan_out, fan_in)
    self.weights.append(np.array(weight, copy=False, dtype=float32))
    bias = np.random.randn(fan_out)
    self.biases.append(np.array(bias, copy=False, dtype=float32))
Example #23
Source File: bounded_holiday_calendar.py From tf-quant-finance with Apache License 2.0 | 5 votes |
def _compute_bus_day_ordinals_table(self):
  """Computes and caches rolled business day ordinals table.

  Returns the cached table if present. Otherwise it takes the indices of true
  entries in the is-business-day table, shifts them by
  `self._ordinal_offset - 1`, caches the result, and returns it.
  """
  cached = self._table_cache.bus_day_ordinals
  if cached is not None:
    return cached
  bus_day_flags = self._compute_is_bus_day_table()
  with tf.init_scope():
    positions = tf.cast(tf.where(bus_day_flags)[:, 0], tf.int32)
    ordinals_table = positions + self._ordinal_offset - 1
    self._table_cache.bus_day_ordinals = ordinals_table
  return ordinals_table
Example #24
Source File: daycounts.py From tf-quant-finance with Apache License 2.0 | 5 votes |
def actual_365_fixed(*,
                     start_date,
                     end_date,
                     schedule_info=None,
                     dtype=None,
                     name=None):
  """Computes the year fraction between the specified dates.

  The actual/365 convention specifies the year fraction between the start and
  end date as the actual number of days between the two dates divided by 365.

  Note that the schedule info is not needed for this convention and is ignored
  if supplied.

  For more details see:
  https://en.wikipedia.org/wiki/Day_count_convention#Actual/365_Fixed

  Args:
    start_date: A `DateTensor` object of any shape.
    end_date: A `DateTensor` object of compatible shape with `start_date`.
    schedule_info: The schedule info. Ignored for this convention.
    dtype: The dtype of the result. Either `tf.float32` or `tf.float64`. If
      not supplied, `tf.float32` is returned.
    name: Python `str` name prefixed to ops created by this function. If not
      supplied, `actual_365_fixed` is used.

  Returns:
    A real `Tensor` of supplied `dtype` and shape of `start_date`. The year
    fraction between the start and end date as computed by Actual/365 fixed
    convention.
  """
  del schedule_info  # Unused by this convention.
  with tf.name_scope(name or 'actual_365_fixed'):
    end = dt.convert_to_date_tensor(end_date)
    start = dt.convert_to_date_tensor(start_date)
    result_dtype = dtype or tf.constant(0.).dtype
    day_count = start.days_until(end)
    return tf.cast(day_count, dtype=result_dtype) / 365
Example #25
Source File: date_tensor.py From tf-quant-finance with Apache License 2.0 | 5 votes |
def _num_days_in_month(month, year):
  """Returns number of days in a given month of a given year.

  Looks up `_DAYS_IN_MONTHS_COMBINED` at `month`, offset by 12 entries when
  `year` is a leap year. Presumably the table stores non-leap lengths followed
  by leap lengths; TODO confirm against its definition.
  """
  lookup = tf.constant(_DAYS_IN_MONTHS_COMBINED, tf.int32)
  leap_offset = 12 * tf.dtypes.cast(date_utils.is_leap_year(year), tf.int32)
  return tf.gather(lookup, month + leap_offset)
Example #26
Source File: imagenet_adversarial.py From armory with MIT License | 5 votes |
def _generate_examples(self, path):
  """Yields `(str(index), example)` pairs read from the TFRecord at `path`.

  Each example is `{"images": {"clean": ..., "adversarial": ...},
  "label": ...}`. Both images are decoded from raw float32 bytes, cast to
  uint8, and reshaped to (height, width, 3).
  """
  clean_key = "clean"
  adversarial_key = "adversarial"
  feature_spec = {
      "height": tf.io.FixedLenFeature([], tf.int64),
      "width": tf.io.FixedLenFeature([], tf.int64),
      "label": tf.io.FixedLenFeature([], tf.int64),
      "adv-image": tf.io.FixedLenFeature([], tf.string),
      "clean-image": tf.io.FixedLenFeature([], tf.string),
  }

  def _raw_to_image(raw_bytes, height, width):
    # float values are integers in [0.0, 255.0] for clean and adversarial
    pixels = tf.cast(tf.io.decode_raw(raw_bytes, tf.float32), tf.uint8)
    return tf.reshape(pixels, (height, width, 3))

  def _parse(serialized_example):
    parsed = tf.io.parse_single_example(serialized_example, feature_spec)
    h, w = parsed["height"], parsed["width"]
    images = {
        clean_key: _raw_to_image(parsed["clean-image"], h, w),
        adversarial_key: _raw_to_image(parsed["adv-image"], h, w),
    }
    return images, parsed["label"]

  dataset = tf.data.TFRecordDataset(filenames=[path]).map(_parse)
  graph = tf.compat.v1.keras.backend.get_session().graph
  numpy_ds = tfds.as_numpy(dataset, graph=graph)
  for index, (images, label) in enumerate(numpy_ds):
    yield str(index), {
        "images": images,
        "label": label,
    }
Example #27
Source File: ncf_keras_main.py From models with Apache License 2.0 | 5 votes |
def metric_fn(logits, dup_mask, match_mlperf):
  """Returns `(in_top_k, metric_weights)` from `compute_top_k_and_ndcg`.

  `dup_mask` is cast to float32 and column 0 of `logits` is dropped before
  the call. The returned weights are cast to float32.
  """
  mask = tf.cast(dup_mask, tf.float32)
  # Keep all rows and every column from index 1 onward.
  trimmed_logits = tf.slice(logits, [0, 1], [-1, -1])
  in_top_k, _, weights, _ = neumf_model.compute_top_k_and_ndcg(
      trimmed_logits, mask, match_mlperf)
  return in_top_k, tf.cast(weights, tf.float32)
Example #28
Source File: bounded_holiday_calendar.py From tf-quant-finance with Apache License 2.0 | 5 votes |
def is_business_day(self, date_tensor):
  """Returns a tensor of bools for whether given dates are business days.

  Looks each date's ordinal up in the is-business-day table (offset by
  `1 - self._ordinal_offset`). The bool cast is made to depend on
  `self._assert_ordinals_in_bounds`.
  """
  table = self._compute_is_bus_day_table()
  ordinals = date_tensor.ordinal()
  flags = self._gather(table, ordinals - self._ordinal_offset + 1)
  bounds_checks = self._assert_ordinals_in_bounds(ordinals)
  with tf.control_dependencies(bounds_checks):
    return tf.cast(flags, dtype=tf.bool)
Example #29
Source File: learner.py From valan with Apache License 2.0 | 5 votes |
def _convert_uint8_to_bfloat16(ts: Any):
  """Casts every uint8 leaf of `ts` to bfloat16, leaving other leaves as-is.

  Args:
    ts: any tensor or nested tensor structure, such as EnvOutput.

  Returns:
    Converted structure.
  """
  def _maybe_cast(t):
    if t.dtype == tf.uint8:
      return tf.cast(t, tf.bfloat16)
    return t

  return tf.nest.map_structure(_maybe_cast, ts)
Example #30
Source File: input_pipeline.py From models with Apache License 2.0 | 5 votes |
def decode_record(record, name_to_features):
  """Decodes a record to a TensorFlow example.

  Args:
    record: Serialized `tf.Example` record.
    name_to_features: Feature spec passed to `tf.io.parse_single_example`.

  Returns:
    Dict of parsed features, with every int64 feature cast to int32.
  """
  parsed = tf.io.parse_single_example(record, name_to_features)
  # tf.Example only supports tf.int64, but the TPU only supports tf.int32,
  # so every int64 feature is narrowed to int32.
  return {
      name: tf.cast(value, tf.int32) if value.dtype == tf.int64 else value
      for name, value in parsed.items()
  }