Python tensorflow.compat.v1.variable_scope() Examples
The following are 30 code examples of tensorflow.compat.v1.variable_scope(), drawn from open-source projects. Each example's header names the source file and project it comes from. You may also want to check out all available functions and classes of the tensorflow.compat.v1 module.
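In TF1-style code, tf.variable_scope() creates a namespace for variables made with tf.get_variable() and controls whether entering the scope again creates fresh variables or reuses existing ones. Before the examples, here is a minimal sketch of both behaviors (assuming TF2 is installed with v1 compatibility behavior enabled):

import tensorflow.compat.v1 as tf
tf.disable_v2_behavior()

with tf.variable_scope("layer"):
  w = tf.get_variable("w", shape=[4, 4])  # created as "layer/w:0"

# Re-entering the same scope with reuse=True returns the existing
# variable instead of raising a "variable already exists" error.
with tf.variable_scope("layer", reuse=True):
  w_again = tf.get_variable("w", shape=[4, 4])

assert w is w_again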
Example #1
Source File: metrics.py From tensor2tensor with Apache License 2.0 | 6 votes |
def roc_auc(logits, labels, weights_fn=None):
  """Calculate ROC AUC.

  Requires binary classes.

  Args:
    logits: Tensor of size [batch_size, 1, 1, num_classes]
    labels: Tensor of size [batch_size, 1, 1, num_classes]
    weights_fn: Function that takes in labels and weighs examples (unused)
  Returns:
    ROC AUC (scalar), weights
  """
  del weights_fn
  with tf.variable_scope("roc_auc", values=[logits, labels]):
    predictions = tf.argmax(logits, axis=-1)
    _, auc = tf.metrics.auc(labels, predictions, curve="ROC")
    return auc, tf.constant(1.0)
Example #2
Source File: transformer_nat.py From tensor2tensor with Apache License 2.0 | 6 votes |
def decode_transformer(encoder_output, encoder_decoder_attention_bias, targets,
                       hparams, name):
  """Original Transformer decoder."""
  with tf.variable_scope(name):
    targets = common_layers.flatten4d3d(targets)
    decoder_input, decoder_self_bias = (
        transformer.transformer_prepare_decoder(targets, hparams))
    decoder_input = tf.nn.dropout(decoder_input,
                                  1.0 - hparams.layer_prepostprocess_dropout)
    decoder_output = transformer.transformer_decoder(
        decoder_input, encoder_output, decoder_self_bias,
        encoder_decoder_attention_bias, hparams)
    decoder_output = tf.expand_dims(decoder_output, axis=2)
    decoder_output_shape = common_layers.shape_list(decoder_output)
    decoder_output = tf.reshape(
        decoder_output, [decoder_output_shape[0], -1, 1, hparams.hidden_size])
    # Expand since t2t expects 4d tensors.
    return decoder_output
Example #3
Source File: transformer_vae.py From tensor2tensor with Apache License 2.0 | 6 votes |
def compress(x, c, is_2d, hparams, name):
  """Compress."""
  with tf.variable_scope(name):
    # Run compression by strided convs.
    cur = x
    k1 = (3, 3) if is_2d else (3, 1)
    k2 = (2, 2) if is_2d else (2, 1)
    cur = residual_conv(cur, hparams.num_compress_steps, k1, hparams, "rc")
    if c is not None and hparams.do_attend_compress:
      cur = attend(cur, c, hparams, "compress_attend")
    for i in range(hparams.num_compress_steps):
      if hparams.do_residual_compress:
        cur = residual_conv(cur, hparams.num_compress_steps, k1, hparams,
                            "rc_%d" % i)
      cur = common_layers.conv_block(
          cur, hparams.hidden_size, [((1, 1), k2)],
          strides=k2, name="compress_%d" % i)
    return cur
Example #4
Source File: rl.py From tensor2tensor with Apache License 2.0 | 6 votes |
def body(self, features):
  observations = features["inputs_raw"]
  observations = tf.cast(observations, tf.float32)
  flat_observations = tf.layers.flatten(observations)
  with tf.variable_scope("policy"):
    x = flat_observations
    for size in self.hparams.policy_layers:
      x = tf.layers.dense(x, size, activation=tf.nn.relu)
    logits = tf.layers.dense(x, self.hparams.problem.num_actions)
    logits = tf.expand_dims(logits, axis=1)
  with tf.variable_scope("value"):
    x = flat_observations
    for size in self.hparams.value_layers:
      x = tf.layers.dense(x, size, activation=tf.nn.relu)
    value = tf.layers.dense(x, 1)
  logits = clip_logits(logits, self.hparams)
  return {"target_policy": logits, "target_value": value}
Example #5
Source File: nasnet_model.py From benchmarks with Apache License 2.0 | 6 votes |
def _build_aux_head(net, end_points, num_classes, hparams, scope):
  """Auxiliary head used for all models across all datasets."""
  with tf.variable_scope(scope):
    aux_logits = tf.identity(net)
    with tf.variable_scope('aux_logits'):
      aux_logits = slim.avg_pool2d(
          aux_logits, [5, 5], stride=3, padding='VALID')
      aux_logits = slim.conv2d(aux_logits, 128, [1, 1], scope='proj')
      aux_logits = slim.batch_norm(aux_logits, scope='aux_bn0')
      aux_logits = tf.nn.relu(aux_logits)
      # Shape of feature map before the final layer.
      shape = aux_logits.shape
      if hparams.data_format == 'NHWC':
        shape = shape[1:3]
      else:
        shape = shape[2:4]
      aux_logits = slim.conv2d(aux_logits, 768, shape, padding='VALID')
      aux_logits = slim.batch_norm(aux_logits, scope='aux_bn1')
      aux_logits = tf.nn.relu(aux_logits)
      aux_logits = contrib_layers.flatten(aux_logits)
      aux_logits = slim.fully_connected(aux_logits, num_classes)
      end_points['AuxLogits'] = aux_logits
Example #6
Source File: neural_stack.py From tensor2tensor with Apache License 2.0 | 6 votes |
def _rnn(self, inputs, name, initial_state=None, sequence_length=None):
  """A helper method to build tf.nn.dynamic_rnn.

  Args:
    inputs: The inputs to the RNN. A tensor of shape
        [batch_size, max_seq_length, embedding_size]
    name: A namespace for the RNN.
    initial_state: An optional initial state for the RNN.
    sequence_length: An optional sequence length for the RNN.

  Returns:
    A tf.nn.dynamic_rnn operator.
  """
  layers = [self.cell(layer_size)
            for layer_size in self._hparams.controller_layer_sizes]
  with tf.variable_scope(name):
    return tf.nn.dynamic_rnn(
        contrib.rnn().MultiRNNCell(layers),
        inputs,
        initial_state=initial_state,
        sequence_length=sequence_length,
        dtype=tf.float32,
        time_major=False)
Example #7
Source File: averaged.py From lamb with Apache License 2.0 | 6 votes |
def __init__(self, tensors):
  tensors = list(tensors)
  with tf.variable_scope('averaged'):
    self._num_samples = tf.Variable(0, name='num_samples', trainable=False)
    with tf.variable_scope('avg'):
      self._averages = [
          tf.get_variable(
              tensor.name.replace('/', '-').replace(':', '-'),
              tensor.get_shape(), initializer=tf.zeros_initializer(),
              trainable=False)
          for tensor in tensors]
    with tf.variable_scope('save'):
      self._saves = [
          tf.get_variable(
              tensor.name.replace('/', '-').replace(':', '-'),
              tensor.get_shape(), initializer=tf.zeros_initializer(),
              trainable=False)
          for tensor in tensors]
  self._tensors = tensors
  self._take_sample = self._make_take_sample()
  self._switch = self._make_swith_to_average()
  self._restore = self._make_restore()
  self._reset = self._make_reset()
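The nested scopes above ('averaged', then 'avg' or 'save') compose hierarchically, so each shadow variable lives under a name like averaged/avg/<tensor-name>. A minimal sketch of how nested scope names combine:

import tensorflow.compat.v1 as tf
tf.disable_v2_behavior()

with tf.variable_scope("averaged"):
  with tf.variable_scope("avg"):
    v = tf.get_variable("x", shape=[])

print(v.name)  # "averaged/avg/x:0"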
Example #8
Source File: adv_attack_utils.py From tensor2tensor with Apache License 2.0 | 6 votes |
def fprop(self, x):
  if x.name in self._logits_dict:
    return self._logits_dict[x.name]
  x = tf.map_fn(tf.image.per_image_standardization, x)
  self._additional_features['inputs'] = x
  if self._scope is None:
    scope = tf.variable_scope(tf.get_variable_scope(), reuse=tf.AUTO_REUSE)
  else:
    scope = tf.variable_scope(self._scope, reuse=tf.AUTO_REUSE)
  with scope:
    logits = self._model_fn(
        self._additional_features,
        None,
        'attack',
        params=self._params,
        config=self._config)
  self._logits_dict[x.name] = logits
  return {model.Model.O_LOGITS: tf.reshape(logits, [-1, logits.shape[-1]])}
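The reuse=tf.AUTO_REUSE argument used above creates variables on the first call and silently reuses them afterwards, which is what lets fprop be called repeatedly without duplicating the model's weights. A minimal sketch of the pattern (the scope name and shapes here are illustrative, not from the example):

import tensorflow.compat.v1 as tf
tf.disable_v2_behavior()

def dense(x):
  # First call creates "shared/w"; later calls reuse it.
  with tf.variable_scope("shared", reuse=tf.AUTO_REUSE):
    w = tf.get_variable("w", shape=[4, 8])
  return tf.matmul(x, w)

a = dense(tf.ones([2, 4]))
b = dense(tf.ones([3, 4]))  # shares weights with `a`
assert len(tf.trainable_variables()) == 1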
Example #9
Source File: utils.py From lamb with Apache License 2.0 | 6 votes |
def layer_norm(x, reduction_indices, epsilon=1e-9, gain=None, bias=None,
               per_element=True, scope=None):
  """DOC."""
  reduction_indices = ensure_list(reduction_indices)
  mean = tf.reduce_mean(x, reduction_indices, keep_dims=True)
  variance = tf.reduce_mean(tf.squared_difference(x, mean),
                            reduction_indices, keep_dims=True)
  normalized = (x - mean) / tf.sqrt(variance + epsilon)
  dtype = x.dtype
  shape = x.get_shape().as_list()
  for i in six.moves.range(len(shape)):
    if i not in reduction_indices or not per_element:
      shape[i] = 1
  with tf.variable_scope(scope or 'layer_norm'):
    if gain is None:
      gain = tf.get_variable('gain', shape=shape, dtype=dtype,
                             initializer=tf.ones_initializer())
    if bias is None:
      bias = tf.get_variable('bias', shape=shape, dtype=dtype,
                             initializer=tf.zeros_initializer())
    return gain*normalized+bias
Example #10
Source File: scheduled_sampling.py From tensor2tensor with Apache License 2.0 | 6 votes |
def loss_fn(self, targets, logits):
  """Constructs loss dict.

  Args:
    targets: [batch_size, seq_len]
    logits: [batch_size, seq_len, vocab_size]

  Returns:
    {str: Tensor of shape []}. Losses.
  """
  batch_size, seq_len, vocab_size = common_layers.shape_list(logits)
  targets = tf.reshape(targets, [batch_size, seq_len, 1, 1])
  logits = tf.reshape(logits, [batch_size, seq_len, 1, 1, vocab_size])
  features = copy.copy(self._features)
  features["targets"] = targets
  with tf.variable_scope(tf.get_variable_scope(), reuse=True):
    losses = {
        "training": self._t2tmodel.loss(logits, features),
    }
  return losses
Example #11
Source File: metrics.py From tensor2tensor with Apache License 2.0 | 6 votes |
def padded_accuracy_topk(predictions, labels, k,
                         weights_fn=common_layers.weights_nonzero):
  """Percentage of times that top-k predictions match labels on non-0s."""
  with tf.variable_scope("padded_accuracy_topk", values=[predictions, labels]):
    padded_predictions, padded_labels = common_layers.pad_with_zeros(
        predictions, labels)
    weights = weights_fn(padded_labels)
    effective_k = tf.minimum(k,
                             common_layers.shape_list(padded_predictions)[-1])
    _, outputs = tf.nn.top_k(padded_predictions, k=effective_k)
    outputs = tf.to_int32(outputs)
    padded_labels = tf.to_int32(padded_labels)
    padded_labels = tf.expand_dims(padded_labels, axis=-1)
    padded_labels += tf.zeros_like(outputs)  # Pad to same shape.
    same = tf.to_float(tf.equal(outputs, padded_labels))
    same_topk = tf.reduce_sum(same, axis=-1)
    return same_topk, weights
Example #12
Source File: metrics.py From tensor2tensor with Apache License 2.0 | 6 votes |
def set_precision(predictions, labels,
                  weights_fn=common_layers.weights_nonzero):
  """Precision of set predictions.

  Args:
    predictions: A Tensor of scores of shape [batch, nlabels].
    labels: A Tensor of int32s giving true set elements,
      of shape [batch, seq_length].
    weights_fn: A function to weight the elements.

  Returns:
    hits: A Tensor of shape [batch, nlabels].
    weights: A Tensor of shape [batch, nlabels].
  """
  with tf.variable_scope("set_precision", values=[predictions, labels]):
    labels = tf.squeeze(labels, [2, 3])
    weights = weights_fn(labels)
    labels = tf.one_hot(labels, predictions.shape[-1])
    labels = tf.reduce_max(labels, axis=1)
    labels = tf.cast(labels, tf.bool)
    return tf.to_float(tf.equal(labels, predictions)), weights
Example #13
Source File: metrics.py From tensor2tensor with Apache License 2.0 | 6 votes |
def set_recall(predictions, labels,
               weights_fn=common_layers.weights_nonzero):
  """Recall of set predictions.

  Args:
    predictions: A Tensor of scores of shape [batch, nlabels].
    labels: A Tensor of int32s giving true set elements,
      of shape [batch, seq_length].
    weights_fn: A function to weight the elements.

  Returns:
    hits: A Tensor of shape [batch, nlabels].
    weights: A Tensor of shape [batch, nlabels].
  """
  with tf.variable_scope("set_recall", values=[predictions, labels]):
    labels = tf.squeeze(labels, [2, 3])
    weights = weights_fn(labels)
    labels = tf.one_hot(labels, predictions.shape[-1])
    labels = tf.reduce_max(labels, axis=1)
    labels = tf.cast(labels, tf.bool)
    return tf.to_float(tf.equal(labels, predictions)), weights
Example #14
Source File: metrics.py From tensor2tensor with Apache License 2.0 | 6 votes |
def softmax_cross_entropy_one_hot(logits, labels, weights_fn=None):
  """Calculate softmax cross entropy given one-hot labels and logits.

  Args:
    logits: Tensor of size [batch-size, o=1, p=1, num-classes]
    labels: Tensor of size [batch-size, o=1, p=1, num-classes]
    weights_fn: Function that takes in labels and weighs examples (unused)
  Returns:
    cross-entropy (scalar), weights
  """
  with tf.variable_scope("softmax_cross_entropy_one_hot",
                         values=[logits, labels]):
    del weights_fn
    cross_entropy = tf.losses.softmax_cross_entropy(
        onehot_labels=labels, logits=logits)
    return cross_entropy, tf.constant(1.0)
Example #15
Source File: metrics.py From tensor2tensor with Apache License 2.0 | 6 votes |
def sigmoid_accuracy_one_hot(logits, labels, weights_fn=None):
  """Calculate accuracy for a set, given one-hot labels and logits.

  Args:
    logits: Tensor of size [batch-size, o=1, p=1, num-classes]
    labels: Tensor of size [batch-size, o=1, p=1, num-classes]
    weights_fn: Function that takes in labels and weighs examples (unused)
  Returns:
    accuracy (scalar), weights
  """
  with tf.variable_scope("sigmoid_accuracy_one_hot", values=[logits, labels]):
    del weights_fn
    predictions = tf.nn.sigmoid(logits)
    labels = tf.argmax(labels, -1)
    predictions = tf.argmax(predictions, -1)
    _, accuracy = tf.metrics.accuracy(labels=labels, predictions=predictions)
    return accuracy, tf.constant(1.0)
Example #16
Source File: metrics.py From tensor2tensor with Apache License 2.0 | 6 votes |
def sigmoid_precision_one_hot(logits, labels, weights_fn=None):
  """Calculate precision for a set, given one-hot labels and logits.

  Predictions are converted to one-hot,
  as predictions[example][arg-max(example)] = 1

  Args:
    logits: Tensor of size [batch-size, o=1, p=1, num-classes]
    labels: Tensor of size [batch-size, o=1, p=1, num-classes]
    weights_fn: Function that takes in labels and weighs examples (unused)
  Returns:
    precision (scalar), weights
  """
  with tf.variable_scope("sigmoid_precision_one_hot", values=[logits, labels]):
    del weights_fn
    num_classes = logits.shape[-1]
    predictions = tf.nn.sigmoid(logits)
    predictions = tf.argmax(predictions, -1)
    predictions = tf.one_hot(predictions, num_classes)
    _, precision = tf.metrics.precision(labels=labels, predictions=predictions)
    return precision, tf.constant(1.0)
Example #17
Source File: metrics.py From tensor2tensor with Apache License 2.0 | 6 votes |
def sigmoid_recall_one_hot(logits, labels, weights_fn=None):
  """Calculate recall for a set, given one-hot labels and logits.

  Predictions are converted to one-hot,
  as predictions[example][arg-max(example)] = 1

  Args:
    logits: Tensor of size [batch-size, o=1, p=1, num-classes]
    labels: Tensor of size [batch-size, o=1, p=1, num-classes]
    weights_fn: Function that takes in labels and weighs examples (unused)
  Returns:
    recall (scalar), weights
  """
  with tf.variable_scope("sigmoid_recall_one_hot", values=[logits, labels]):
    del weights_fn
    num_classes = logits.shape[-1]
    predictions = tf.nn.sigmoid(logits)
    predictions = tf.argmax(predictions, -1)
    predictions = tf.one_hot(predictions, num_classes)
    _, recall = tf.metrics.recall(labels=labels, predictions=predictions)
    return recall, tf.constant(1.0)
Example #18
Source File: metrics.py From tensor2tensor with Apache License 2.0 | 6 votes |
def sigmoid_cross_entropy_one_hot(logits, labels, weights_fn=None):
  """Calculate sigmoid cross entropy for one-hot labels and logits.

  Args:
    logits: Tensor of size [batch-size, o=1, p=1, num-classes]
    labels: Tensor of size [batch-size, o=1, p=1, num-classes]
    weights_fn: Function that takes in labels and weighs examples (unused)
  Returns:
    cross_entropy (scalar), weights
  """
  with tf.variable_scope("sigmoid_cross_entropy_one_hot",
                         values=[logits, labels]):
    del weights_fn
    cross_entropy = tf.losses.sigmoid_cross_entropy(
        multi_class_labels=labels, logits=logits)
    return cross_entropy, tf.constant(1.0)
Example #19
Source File: transformer_nat.py From tensor2tensor with Apache License 2.0 | 5 votes |
def residual_conv(x, repeat, k, hparams, name, reuse=None):
  """A stack of convolution blocks with residual connections."""
  with tf.variable_scope(name, reuse=reuse):
    dilations_and_kernels = [((1, 1), k) for _ in range(3)]
    for i in range(repeat):
      with tf.variable_scope("repeat_%d" % i):
        y = common_layers.conv_block(
            common_layers.layer_norm(x, hparams.hidden_size, name="lnorm"),
            hparams.hidden_size,
            dilations_and_kernels,
            padding="SAME",
            name="residual_conv")
        y = tf.nn.dropout(y, 1.0 - hparams.dropout)
        x += y
    return x
Example #20
Source File: t2t_model.py From tensor2tensor with Apache License 2.0 | 5 votes |
def top(self, body_output, features):
  """Computes logits given body output and features.

  Args:
    body_output: dict of str to Tensor, comprising one key-value pair for
      each target. Each value denotes the target's pre-logit activations.
      Alternatively, it may be a single Tensor denoting the pre-logits for
      that target.
    features: dict of str to Tensor. Typically it is the preprocessed data
      batch after Problem's preprocess_example().

  Returns:
    logits: dict of str to Tensor, denoting each logits for each target; or
      a single Tensor denoting the logits for that target.
      When targets are generated at training time:
        logits == {
          "self_generated_targets": <generated targets tensor>
          "logits": <original logits Tensor or dict>
        }
  """
  if isinstance(body_output, dict):
    logits = {}
    for k, v in six.iteritems(body_output):
      # TODO(aidangomez): share variables here?
      with tf.variable_scope(k) as top_vs:
        self._add_variable_scope("top_%s" % k, top_vs)
        logits[k] = self._top_single(v, k, features)
    return logits
  else:
    return self._top_single(body_output, "targets", features)
Example #21
Source File: t2t_model.py From tensor2tensor with Apache License 2.0 | 5 votes |
def model_fn(self, features):
  with tf.variable_scope(tf.get_variable_scope(), use_resource=True) as vs:
    self._add_variable_scope("model_fn", vs)
    transformed_features = self.bottom(features)

    if self.hparams.activation_dtype == "bfloat16":
      for k, v in sorted(six.iteritems(transformed_features)):
        if v.dtype == tf.float32:
          transformed_features[k] = tf.cast(v, tf.bfloat16)

    with tf.variable_scope("body") as body_vs:
      self._add_variable_scope("body", body_vs)
      log_info("Building model body")
      body_out = self.body(transformed_features)
    output, losses = self._normalize_body_output(body_out)

    if "training" in losses:
      log_info("Skipping T2TModel top and loss because training loss "
               "returned from body")
      logits = output
    else:
      logits = self.top(output, features)
      losses["training"] = 0.0
      if (self._hparams.mode != tf.estimator.ModeKeys.PREDICT and
          self._hparams.mode != "attack"):
        losses["training"] = self.loss(logits, features)

    return logits, losses
Example #22
Source File: expert_utils.py From tensor2tensor with Apache License 2.0 | 5 votes |
def add_var_scope(scope=None):
  return add_scope(scope, scope_fn=tf.variable_scope)
Example #23
Source File: transformer.py From tensor2tensor with Apache License 2.0 | 5 votes |
def _beam_decode(self, features, decode_length, beam_size, top_beams, alpha,
                 use_tpu=False):
  """Beam search decoding.

  Args:
    features: a map of string to `Tensor`
    decode_length: an integer. How many additional timesteps to decode.
    beam_size: number of beams.
    top_beams: an integer. How many of the beams to return.
    alpha: Float that controls the length penalty. The larger the alpha, the
      stronger the preference for longer translations.
    use_tpu: A bool, whether to do beam decode on TPU.

  Returns:
    A dict of decoding results {
        "outputs": integer `Tensor` of decoded ids of shape
            [batch_size, <= decode_length] if beam_size == 1 or
            [batch_size, top_beams, <= decode_length]
        "scores": decoding log probs from the beam search,
            None if using greedy decoding (beam_size=1)
    }
  """
  if (self._hparams.self_attention_type not in [
      "dot_product", "dot_product_relative"
  ]):
    # Caching is not guaranteed to work with attention types other than
    # dot_product and dot_product_relative.
    return self._beam_decode_slow(features, decode_length, beam_size,
                                  top_beams, alpha, use_tpu)
  with tf.variable_scope(self.name):
    if use_tpu:
      return self._fast_decode_tpu(features, decode_length, beam_size,
                                   top_beams, alpha)
    return self._fast_decode(features, decode_length, beam_size, top_beams,
                             alpha)
Example #24
Source File: transformer_nat.py From tensor2tensor with Apache License 2.0 | 5 votes |
def decompress_step(source, hparams, first_relu, name):
  """Decompression function."""
  with tf.variable_scope(name):
    shape = common_layers.shape_list(source)
    multiplier = 2
    kernel = (1, 1)
    thicker = common_layers.conv_block(
        source, hparams.hidden_size * multiplier, [((1, 1), kernel)],
        first_relu=first_relu, name="decompress_conv")
    return tf.reshape(thicker,
                      [shape[0], shape[1] * 2, 1, hparams.hidden_size])
Example #25
Source File: transformer_nat.py From tensor2tensor with Apache License 2.0 | 5 votes |
def encode(x, x_space, hparams, name):
  """Transformer preparations and encoder."""
  with tf.variable_scope(name):
    (encoder_input, encoder_self_attention_bias,
     ed) = transformer.transformer_prepare_encoder(x, x_space, hparams)
    encoder_input = tf.nn.dropout(encoder_input, 1.0 - hparams.dropout)
    return transformer.transformer_encoder(
        encoder_input, encoder_self_attention_bias, hparams), ed
Example #26
Source File: transformer_nat.py From tensor2tensor with Apache License 2.0 | 5 votes |
def ae_latent_sample_beam(latents_dense_in, inputs, ed, embed, hparams):
  """Sample from the latent space in the autoencoder."""

  def symbols_to_logits_fn(ids):
    """Go from ids to logits."""
    ids = tf.expand_dims(ids, axis=2)  # Ids start with added all-zeros.
    latents_discrete = tf.pad(ids[:, 1:], [[0, 0], [0, 1], [0, 0]])

    with tf.variable_scope(tf.get_variable_scope(), reuse=False):
      latents_dense = embed(
          tf.one_hot(latents_discrete, depth=2**hparams.bottleneck_bits))
      latents_pred = decode_transformer(inputs, ed, latents_dense, hparams,
                                        "extra")
      logits = tf.layers.dense(
          latents_pred, 2**hparams.bottleneck_bits, name="extra_logits")
      current_output_position = common_layers.shape_list(ids)[1] - 1
      logits = logits[:, current_output_position, :, :]
    return tf.squeeze(logits, axis=[1])

  initial_ids = tf.zeros([tf.shape(latents_dense_in)[0]], dtype=tf.int32)
  length = tf.shape(latents_dense_in)[1]
  ids, _, _ = beam_search.beam_search(
      symbols_to_logits_fn,
      initial_ids,
      beam_size=1,
      decode_length=length,
      vocab_size=2**hparams.bottleneck_bits,
      alpha=0.0,
      eos_id=-1,
      stop_early=False)

  res = tf.expand_dims(ids[:, 0, :], axis=2)  # Pick first beam.
  return res[:, 1:]  # Remove the added all-zeros from ids.
Example #27
Source File: transformer_nat.py From tensor2tensor with Apache License 2.0 | 5 votes |
def body(self, features):
  inputs = features["inputs"] if "inputs" in features else None
  reuse = "cache_raw" in features
  with tf.variable_scope(tf.get_variable_scope(), reuse=reuse):
    res, loss, _ = ae_transformer_internal(
        inputs, features["targets"], features["target_space_id"],
        self._hparams, features.get("cache_raw", None))
    return res, loss
Example #28
Source File: transformer_nat.py From tensor2tensor with Apache License 2.0 | 5 votes |
def prepare_features_for_infer(self, features):
  batch_size = self._decode_hparams.batch_size
  inputs = tf.zeros([batch_size, 1, 1, self._hparams.hidden_size])
  inputs = inputs if "inputs" in features else None
  targets = tf.zeros([batch_size, 1, 1, self._hparams.hidden_size])
  with tf.variable_scope("transformer_nat/body"):
    _, _, cache = ae_transformer_internal(
        inputs, targets, features["target_space_id"], self._hparams)
  features["cache_raw"] = cache
Example #29
Source File: rl.py From tensor2tensor with Apache License 2.0 | 5 votes |
def body(self, features):
  observations = features["inputs_raw"]
  # Axis 0    - Batch.
  # Axis 1    - Input Frames, 4 frames.
  # Axis 2, 3 - Height & Width.
  # Axis 4    - Channels RGB, 3 colours.
  x = tf.transpose(observations, [0, 2, 3, 1, 4])
  x_shape = common_layers.shape_list(x)
  x = tf.reshape(x, x_shape[:-2] + [-1])
  dropout = getattr(self.hparams, "dropout_ppo", 0.0)
  with tf.variable_scope("feed_forward_cnn_small"):
    x = tf.cast(x, tf.float32) / 255.0
    x = tf.layers.conv2d(x, 32, (5, 5), strides=(2, 2),
                         activation=tf.nn.relu, padding="same")
    x = tf.layers.conv2d(x, 32, (5, 5), strides=(2, 2),
                         activation=tf.nn.relu, padding="same")

    flat_x = tf.layers.flatten(x)
    if self.use_epochs:
      epoch = features["epoch"] + tf.zeros([x_shape[0]], dtype=tf.int32)
      # Randomly set epoch to 0 in some cases as that's the inference value.
      rand = tf.random.uniform([x_shape[0]])
      epoch = tf.where(rand < 0.1, tf.zeros_like(epoch), epoch)
      # Embed the epoch number.
      emb_epoch = common_layers.embedding(epoch, 32, 32)  # [batch, 32]
      flat_x = tf.concat([flat_x, emb_epoch], axis=1)
    flat_x = tf.layers.dropout(flat_x, rate=dropout)
    x = tf.layers.dense(flat_x, 128, activation=tf.nn.relu)

    logits = tf.layers.dense(
        x, self.hparams.problem.num_actions, name="dense2")
    logits = clip_logits(logits, self.hparams)
    logits = tf.expand_dims(logits, axis=1)
    value = tf.layers.dense(x, self.distributional_value_size)
  return {"target_policy": logits, "target_value": value}
Example #30
Source File: scheduled_sampling.py From tensor2tensor with Apache License 2.0 | 5 votes |
def infer_fn(self, partial_targets):
  """Computes logits for all timesteps.

  Args:
    partial_targets: [batch_size, seq_len]. Targets to condition on.

  Returns:
    next_token_logits: [batch_size, seq_len, vocab_size]
  """
  batch_size, seq_len = common_layers.shape_list(partial_targets)
  partial_targets = tf.reshape(partial_targets, [batch_size, seq_len, 1, 1])
  features = copy.copy(self._features)
  features["targets"] = partial_targets

  with tf.variable_scope(tf.get_variable_scope(), reuse=True):
    transformed_features = self._t2tmodel.bottom(features)

    with tf.variable_scope("body"):
      body_outputs, losses = self._t2tmodel._normalize_body_output(  # pylint: disable=protected-access
          self._t2tmodel.body(transformed_features))
      assert losses == {"extra": 0.0}, (
          "Auxiliary losses are not propagated in this code. %s" % (losses,))

    logits = self._t2tmodel.top(body_outputs, features)

  vocab_size = self._t2tmodel.problem_hparams.vocab_size["targets"]
  logits = tf.reshape(logits, [batch_size, seq_len, vocab_size])
  return logits
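Examples #10 and #30 both use the idiom tf.variable_scope(tf.get_variable_scope(), reuse=True), which re-enters the current scope with reuse enabled so that a second forward pass shares the variables built by the first. A minimal sketch of the idiom (the model and names here are illustrative, not from the examples):

import tensorflow.compat.v1 as tf
tf.disable_v2_behavior()

def model(x):
  with tf.variable_scope("proj"):
    w = tf.get_variable("w", shape=[4, 2])
  return tf.matmul(x, w)

train_logits = model(tf.ones([8, 4]))  # first pass builds "proj/w"
with tf.variable_scope(tf.get_variable_scope(), reuse=True):
  infer_logits = model(tf.ones([1, 4]))  # second pass reuses "proj/w"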