Python tensorflow.compat.v1.expand_dims() Examples
The following are 30 code examples of `tensorflow.compat.v1.expand_dims()`.
You can vote up the ones you like or vote down the ones you don't like,
and go to the original project or source file by following the links above each example.
You may also want to check out all available functions/classes of the module `tensorflow.compat.v1`, or try the search function.
Example #1
Source File: bytenet.py From tensor2tensor with Apache License 2.0 | 6 votes |
def bytenet_internal(inputs, targets, hparams):
  """ByteNet, main step used for training.

  The inputs are lengthened by 50% and padded to a shared length with the
  targets, then encoded with SAME-padded residual dilated convolutions. The
  decoder sees the encoder output concatenated (channel axis 3) with the
  right-shifted targets and uses LEFT padding.

  Args:
    inputs: input features, passed through common_layers.flatten4d3d.
    targets: target features.
    hparams: hyperparameters (num_block_repeat, kernel_height, kernel_width,
      hidden_size, ...).

  Returns:
    The output of the decoder's residual_dilated_conv stack.
  """
  with tf.variable_scope("bytenet"):
    # Flatten, then re-insert a singleton axis at position 2.
    flat_inputs = common_layers.flatten4d3d(inputs)
    inputs = tf.expand_dims(flat_inputs, axis=2)

    # Extra length = 50% of the dynamic length along axis 1.
    extra_length = tf.to_int32(0.5 * tf.to_float(tf.shape(inputs)[1]))
    static_shape = inputs.shape.as_list()
    inputs = tf.pad(inputs, [[0, 0], [0, extra_length], [0, 0], [0, 0]])
    # Only the padded axis has become unknown; keep the other static dims.
    static_shape[1] = None
    inputs.set_shape(static_shape)

    # Pad inputs and targets to the same length, divisible by 50.
    inputs, targets = common_layers.pad_to_same_length(
        inputs, targets, final_length_divisible_by=50)
    encoder_out = residual_dilated_conv(
        inputs, hparams.num_block_repeat, "SAME", "encoder", hparams)

    targets_shifted = common_layers.shift_right(targets)
    kernel_size = (hparams.kernel_height, hparams.kernel_width)
    decoder_in = common_layers.conv_block(
        tf.concat([encoder_out, targets_shifted], axis=3),
        hparams.hidden_size,
        [((1, 1), kernel_size)],
        padding="LEFT")
    return residual_dilated_conv(
        decoder_in, hparams.num_block_repeat, "LEFT", "decoder", hparams)
Example #2
Source File: metrics_test.py From tensor2tensor with Apache License 2.0 | 6 votes |
def testSigmoidRecallOneHot(self):
  """sigmoid_recall_one_hot should score 0.25 on this 4-example toy batch."""
  raw_logits = np.array([[-1., 1.], [1., -1.], [1., -1.], [1., -1.]])
  raw_labels = np.array([[0, 1], [0, 1], [0, 1], [0, 1]])
  # Insert two singleton axes after the batch axis: [4, 2] -> [4, 1, 1, 2].
  logits, labels = (np.expand_dims(np.expand_dims(arr, 1), 1)
                    for arr in (raw_logits, raw_labels))
  with self.test_session() as session:
    score, _ = metrics.sigmoid_recall_one_hot(logits, labels)
    # Streaming metrics keep their state in variables; initialize both kinds.
    session.run(tf.global_variables_initializer())
    session.run(tf.local_variables_initializer())
    self.assertEqual(session.run(score), 0.25)
Example #3
Source File: metrics_test.py From tensor2tensor with Apache License 2.0 | 6 votes |
def testSigmoidPrecisionOneHot(self):
  """sigmoid_precision_one_hot should score 0.25 on this 4-example batch."""
  logits = np.array([[-1., 1.], [1., -1.], [1., -1.], [1., -1.]])
  labels = np.array([[0, 1], [0, 1], [0, 1], [0, 1]])
  # [4, 2] -> [4, 1, 1, 2]
  logits = np.expand_dims(np.expand_dims(logits, 1), 1)
  labels = np.expand_dims(np.expand_dims(labels, 1), 1)
  with self.test_session() as session:
    score, _ = metrics.sigmoid_precision_one_hot(logits, labels)
    for init_op_fn in (tf.global_variables_initializer,
                       tf.local_variables_initializer):
      session.run(init_op_fn())
    precision = session.run(score)
    self.assertEqual(precision, 0.25)
Example #4
Source File: metrics_test.py From tensor2tensor with Apache License 2.0 | 6 votes |
def testSigmoidAccuracyOneHot(self):
  """sigmoid_accuracy_one_hot should score 0.5 on this 4-example batch."""
  logits = np.array([[-1., 1.], [1., -1.], [-1., 1.], [1., -1.]])
  labels = np.array([[0, 1], [1, 0], [1, 0], [0, 1]])
  # Add two singleton axes after the batch axis: [4, 2] -> [4, 1, 1, 2].
  logits = logits[:, np.newaxis, np.newaxis, :]
  labels = labels[:, np.newaxis, np.newaxis, :]
  with self.test_session() as session:
    score, _ = metrics.sigmoid_accuracy_one_hot(logits, labels)
    session.run(tf.global_variables_initializer())
    session.run(tf.local_variables_initializer())
    accuracy = session.run(score)
    self.assertEqual(accuracy, 0.5)
Example #5
Source File: metrics_test.py From tensor2tensor with Apache License 2.0 | 6 votes |
def testPrefixAccuracy(self):
  """prefix_accuracy on three 8-step sequences with one wrong token each.

  The rows first go wrong at positions 4, 5 and 5. The last row's final
  label is 0, presumably treated as padding, which is why its expected ratio
  uses a denominator of 7.
  """
  vocab_size = 10
  predicted_ids = tf.constant([[[1], [2], [3], [4], [9], [6], [7], [8]],
                               [[1], [2], [3], [4], [5], [9], [7], [8]],
                               [[1], [2], [3], [4], [5], [9], [7], [0]]])
  predictions = tf.one_hot(predicted_ids, vocab_size)
  target_ids = tf.constant([[[1], [2], [3], [4], [5], [6], [7], [8]],
                            [[1], [2], [3], [4], [5], [6], [7], [8]],
                            [[1], [2], [3], [4], [5], [6], [7], [0]]])
  labels = tf.expand_dims(target_ids, axis=-1)
  expected_accuracy = np.mean([4.0 / 8.0, 5.0 / 8.0, 5.0 / 7.0])

  accuracy, _ = metrics.prefix_accuracy(predictions, labels)
  with self.test_session() as session:
    self.assertAlmostEqual(expected_accuracy, session.run(accuracy))
Example #6
Source File: utils.py From lamb with Apache License 2.0 | 6 votes |
def mask_from_lengths(lengths, max_length=None, dtype=None, name=None):
  """Convert a vector of lengths to a matrix of binary masks.

  E.g. [2, 4, 3] will become
    [[1, 1, 0, 0],
     [1, 1, 1, 1],
     [1, 1, 1, 0]]

  Args:
    lengths: a d-dimensional vector of integers corresponding to lengths. May
      be a Tensor or anything `tf.convert_to_tensor` accepts (e.g. a list).
    max_length: an optional (default: None) scalar-like or 0-dimensional
      tensor indicating the maximum length of the masks. If not provided, the
      maximum length will be inferred from the lengths vector.
    dtype: the dtype of the returned mask, if specified. If None, the dtype
      of the lengths will be used.
    name: a name for the operation (optional).

  Returns:
    A d x max_length tensor of binary masks with dtype `dtype` (the dtype of
    `lengths` when `dtype` is None). Gradients do not flow through it.
  """
  with tf.name_scope(name, 'mask_from_lengths'):
    # `.dtype` below requires a Tensor; a Tensor input is returned unchanged.
    lengths = tf.convert_to_tensor(lengths)
    dtype = lengths.dtype if dtype is None else dtype
    max_length = tf.reduce_max(lengths) if max_length is None else max_length
    indexes = tf.range(max_length, dtype=lengths.dtype)
    # Broadcast [1, max_length] against [d, 1]: position < length -> True.
    mask = tf.less(tf.expand_dims(indexes, 0), tf.expand_dims(lengths, 1))
    cast_mask = tf.cast(mask, dtype)
    return tf.stop_gradient(cast_mask)
Example #7
Source File: tf_atari_wrappers.py From tensor2tensor with Apache License 2.0 | 6 votes |
def simulate(self, action):
  """Steps the wrapped batch env and updates the stacked observation buffer.

  Args:
    action: forwarded unchanged to `self._batch_env.simulate`.

  Returns:
    A `(reward, done)` pair of `tf.identity` ops. They are created under
    control dependencies, so evaluating them also performs the assignment to
    `self._observ`.
  """
  reward, done = self._batch_env.simulate(action)
  # Everything below must run only after the env step has produced
  # reward/done.
  with tf.control_dependencies([reward, done]):
    # Insert a new axis at position 1. Presumably this is the history axis of
    # self._observ -- confirm against where self._observ is created.
    new_observ = tf.expand_dims(self._batch_env.observ, axis=1)

    # If we shouldn't stack, i.e. self.history == 1, then just assign
    # new_observ to self._observ and return from here.
    if self.history == 1:
      with tf.control_dependencies([self._observ.assign(new_observ)]):
        return tf.identity(reward), tf.identity(done)

    # If we should stack, then do the required work: keep entries
    # 1..history-1 along axis 1, dropping index 0.
    old_observ = tf.gather(
        self._observ.read_value(),
        list(range(1, self.history)),
        axis=1)
    # Both reads must complete before the variable is overwritten.
    with tf.control_dependencies([new_observ, old_observ]):
      # Append the newest observation at the end of axis 1.
      with tf.control_dependencies([self._observ.assign(
          tf.concat([old_observ, new_observ], axis=1))]):
        return tf.identity(reward), tf.identity(done)
Example #8
Source File: metrics_test.py From tensor2tensor with Apache License 2.0 | 6 votes |
def testSigmoidCrossEntropyOneHot(self): logits = np.array([ [-1., 1.], [1., -1.], [1., -1.], [1., -1.] ]) labels = np.array([ [0, 1], [1, 0], [0, 0], [0, 1] ]) logits = np.expand_dims(np.expand_dims(logits, 1), 1) labels = np.expand_dims(np.expand_dims(labels, 1), 1) with self.test_session() as session: score, _ = metrics.sigmoid_cross_entropy_one_hot(logits, labels) session.run(tf.global_variables_initializer()) session.run(tf.local_variables_initializer()) s = session.run(score) self.assertAlmostEqual(s, 0.688, places=3)
Example #9
Source File: expert_utils.py From tensor2tensor with Apache License 2.0 | 6 votes |
def combine(self, expert_out, multiply_by_gates=True):
  """Sums the expert outputs back into per-batch-element results.

  For each batch element `b`, the result is the sum over all experts `i` of
  that expert's output for `b`. When `multiply_by_gates` is true, each
  contribution is first scaled by its gate value; otherwise the gates are
  ignored.

  Args:
    expert_out: a list of `num_experts` `Tensor`s, each with shape
      `[expert_batch_size_i, <extra_output_dims>]`.
    multiply_by_gates: a boolean

  Returns:
    a `Tensor` with shape `[batch_size, <extra_output_dims>]`.
  """
  # The concat result goes through convert_gradient_to_tensor; see that
  # helper's comments for why.
  merged = common_layers.convert_gradient_to_tensor(tf.concat(expert_out, 0))
  if multiply_by_gates:
    merged *= tf.expand_dims(self._nonzero_gates, 1)
  num_rows = tf.shape(self._gates)[0]
  return tf.unsorted_segment_sum(merged, self._batch_index, num_rows)
Example #10
Source File: expert_utils.py From tensor2tensor with Apache License 2.0 | 6 votes |
def combine(self, x):
  """Return the output from the experts.

  When one example goes to multiple experts, the outputs are summed.

  Args:
    x: a Tensor with shape [batch, num_experts, expert_capacity, depth]

  Returns:
    a `Tensor` with shape `[batch, length, depth]`
  """
  depth = tf.shape(x)[-1]
  # Weight each slot by self._nonpadding. Presumably this is a 0/1 mask that
  # zeroes padded slots -- confirm.
  weighted = x * tf.expand_dims(self._nonpadding, -1)
  summed = tf.unsorted_segment_sum(
      weighted,
      self._flat_indices,
      num_segments=self._batch * self._length)
  return tf.reshape(summed, [self._batch, self._length, depth])
Example #11
Source File: data_reader.py From tensor2tensor with Apache License 2.0 | 6 votes |
def standardize_shapes(features, batch_size=None):
  """Set the right shapes for the features.

  Any "inputs"/"targets" entry gets trailing singleton axes until its rank is
  at least 4. When `batch_size` is truthy, the leading static dimension of
  every feature is merged with `batch_size`, and each shape is then asserted
  to be fully defined.

  Args:
    features: dict of feature tensors; updated in place.
    batch_size: optional static batch size.

  Returns:
    The same `features` dict.
  """
  for fname in ("inputs", "targets"):
    if fname in features:
      feat = features[fname]
      while len(feat.get_shape()) < 4:
        feat = tf.expand_dims(feat, axis=-1)
      features[fname] = feat

  if batch_size:
    # Ensure batch size is set on all features.
    for tensor in six.itervalues(features):
      pinned = tensor.get_shape().as_list()
      pinned[0] = batch_size
      tensor.set_shape(tensor.get_shape().merge_with(pinned))
      # Assert shapes are fully known.
      tensor.get_shape().assert_is_fully_defined()

  return features
Example #12
Source File: transformer.py From tensor2tensor with Apache License 2.0 | 6 votes |
def body(self, features):
  """Runs the Transformer encoder on features["inputs"].

  Args:
    features: dict containing "inputs" and "target_space_id".

  Returns:
    The encoder output with a singleton axis inserted at position 2.
    Presumably this restores the 4-D layout removed by flatten4d3d -- confirm.
  """
  hparams = self._hparams
  flat_inputs = common_layers.flatten4d3d(features["inputs"])
  encoder_input, self_attention_bias, _ = transformer_prepare_encoder(
      flat_inputs, features["target_space_id"], hparams)
  # tf.compat.v1.nn.dropout's second positional argument is keep_prob.
  keep_prob = 1.0 - hparams.layer_prepostprocess_dropout
  encoder_input = tf.nn.dropout(encoder_input, keep_prob)
  encoded = transformer_encoder(
      encoder_input,
      self_attention_bias,
      hparams,
      nonpadding=features_to_nonpadding(features, "inputs"))
  return tf.expand_dims(encoded, 2)
Example #13
Source File: neural_stack.py From tensor2tensor with Apache License 2.0 | 6 votes |
def initialize_write_strengths(self, batch_size):
  """Initialize write strengths to write to the first memory address.

  This is exposed as its own function so that it can be overridden to provide
  alternate write addressing schemes.

  Args:
    batch_size: The size of the current batch. Must be a Python int, since it
      is used for list repetition.

  Returns:
    A tf.float32 tensor of shape [batch_size, num_write_heads, memory_size, 1]
    where, for every batch element and write head, the first memory address
    is set to 1.0 and all others are 0.0.
  """
  # A [batch_size, num_write_heads] grid of address 0. one_hot adds the
  # memory_size axis; expand_dims then appends the trailing axis of size 1.
  first_address = [[0] * self._num_write_heads] * batch_size
  return tf.expand_dims(
      tf.one_hot(first_address, depth=self._memory_size, dtype=tf.float32),
      axis=3)
Example #14
Source File: neural_stack.py From tensor2tensor with Apache License 2.0 | 6 votes |
def get_read_mask(self, read_head_index):
  """Creates a mask which allows us to attenuate subsequent read strengths.

  This is exposed as its own function so that it can be overridden to provide
  alternate read addressing schemes.

  Args:
    read_head_index: Identifies which read head we're getting the mask for.
      A stack supports only read head 0.

  Returns:
    A tf.float32 tensor of shape [1, 1, memory_size, memory_size]

  Raises:
    ValueError: if `read_head_index` is not 0.
  """
  if read_head_index != 0:
    raise ValueError("Read head index must be 0 for stack.")
  position_mask = common_layers.mask_pos_lt(self._memory_size,
                                            self._memory_size)
  return tf.expand_dims(position_mask, axis=0)
Example #15
Source File: neural_stack.py From tensor2tensor with Apache License 2.0 | 6 votes |
def initialize_write_strengths(self, batch_size):
  """Initialize write strengths which write in both directions.

  Unlike in Grefenstette et al., this writes outward from the center of the
  memory, so the entire memory does not need to be shifted forward at each
  step.

  Args:
    batch_size: The size of the current batch. Must be a Python int, since it
      is used for list repetition.

  Returns:
    A tf.float32 tensor of shape [batch_size, 2, memory_size, 1].
    - Head 0 (the deque bottom) is one-hot at memory_size // 2 - 1.
    - Head 1 (the deque top) is one-hot at memory_size // 2.
  """
  memory_center = self._memory_size // 2

  def _one_hot_at(position):
    # [batch_size, 1, memory_size] write strength concentrated at `position`.
    return tf.one_hot([[position]] * batch_size,
                      depth=self._memory_size,
                      dtype=tf.float32)

  return tf.expand_dims(
      tf.concat([
          # The write strength for the deque bottom.
          # Should be shifted back at each timestep.
          _one_hot_at(memory_center - 1),
          # The write strength for the deque top.
          # Should be shifted forward at each timestep.
          _one_hot_at(memory_center),
      ], axis=1),
      axis=3)
Example #16
Source File: transformer_nat.py From tensor2tensor with Apache License 2.0 | 6 votes |
def decode_transformer(encoder_output, encoder_decoder_attention_bias, targets,
                       hparams, name):
  """Original Transformer decoder.

  Returns:
    The decoder output reshaped to [batch, -1, 1, hparams.hidden_size].
  """
  with tf.variable_scope(name):
    flat_targets = common_layers.flatten4d3d(targets)
    decoder_input, decoder_self_bias = (
        transformer.transformer_prepare_decoder(flat_targets, hparams))
    decoder_input = tf.nn.dropout(
        decoder_input, 1.0 - hparams.layer_prepostprocess_dropout)
    decoded = transformer.transformer_decoder(
        decoder_input, encoder_output, decoder_self_bias,
        encoder_decoder_attention_bias, hparams)
    # Expand since t2t expects 4d tensors.
    decoded = tf.expand_dims(decoded, axis=2)
    batch = common_layers.shape_list(decoded)[0]
    # Pins the last static dimension to hparams.hidden_size. This assumes the
    # decoder's depth already equals hidden_size -- confirm.
    return tf.reshape(decoded, [batch, -1, 1, hparams.hidden_size])
Example #17
Source File: rl.py From tensor2tensor with Apache License 2.0 | 6 votes |
def body(self, features):
  """Feed-forward policy and value networks over the raw observations.

  Args:
    features: dict containing "inputs_raw".

  Returns:
    A dict with two entries:
    - "target_policy": clipped action logits with an extra axis at position 1.
    - "target_value": a single-unit dense output.
  """
  flat_obs = tf.layers.flatten(tf.cast(features["inputs_raw"], tf.float32))

  with tf.variable_scope("policy"):
    hidden = flat_obs
    for width in self.hparams.policy_layers:
      hidden = tf.layers.dense(hidden, width, activation=tf.nn.relu)
    action_logits = tf.expand_dims(
        tf.layers.dense(hidden, self.hparams.problem.num_actions), axis=1)

  with tf.variable_scope("value"):
    hidden = flat_obs
    for width in self.hparams.value_layers:
      hidden = tf.layers.dense(hidden, width, activation=tf.nn.relu)
    value = tf.layers.dense(hidden, 1)

  return {"target_policy": clip_logits(action_logits, self.hparams),
          "target_value": value}
Example #18
Source File: attention_lm_moe.py From tensor2tensor with Apache License 2.0 | 6 votes |
def remove_pad(x, pad_remover, mode):
  """Remove padding by concatenating all dimensions into one.

  Args:
    x (tf.Tensor): input of shape [batch_size, length, depth]
    pad_remover (obj): a PadRemover object
    mode (ModeKeys): infer, train or eval. If inference, the padding remover
      is not applied.

  Returns:
    tf.Tensor of shape [1, length_nonpad, depth] where
    length_nonpad <= batch_size * length
  """
  # Concatenate all tokens into a single sequence.
  tokens = expert_utils.flatten_all_but_last(x)
  # Padding is stripped only for training and eval. Skipping it at inference
  # is a hack: the <go> token could be detected as padding and removed. This
  # works for now because there is no padding at inference.
  if mode != ModeKeys.PREDICT:
    tokens = pad_remover.remove(tokens)
  return tf.expand_dims(tokens, axis=0)  # Now batch_size=1
Example #19
Source File: autoencoders.py From tensor2tensor with Apache License 2.0 | 6 votes |
def infer(self, features, *args, **kwargs):  # pylint: disable=arguments-differ
  """Produce predictions from the model by sampling.

  Args:
    features: dict of feature tensors, or None/empty. It is modified in place:
      - A zero "targets" tensor is added when that key is absent.
      - An "inputs" entry of rank < 4 is temporarily expanded at axis 2 and
        is always restored before returning, even if building the model
        raises.
    *args: ignored.
    **kwargs: ignored.

  Returns:
    The argmax over the last axis of the model's logits.
  """
  del args, kwargs
  # Inputs and features preparation needed to handle edge cases.
  if not features:
    features = {}
  inputs_old = None
  if "inputs" in features and len(features["inputs"].shape) < 4:
    inputs_old = features["inputs"]
    features["inputs"] = tf.expand_dims(features["inputs"], 2)

  try:
    # Sample and decode.
    num_channels = self.num_channels
    if "targets" not in features:
      features["targets"] = tf.zeros(
          [self.hparams.batch_size, 1, 1, num_channels], dtype=tf.int32)
    logits, _ = self(features)  # pylint: disable=not-callable
    samples = tf.argmax(logits, axis=-1)
  finally:
    # Restore inputs to not confuse Estimator in edge cases. Doing it in
    # `finally` keeps the caller's dict intact if model construction fails.
    if inputs_old is not None:
      features["inputs"] = inputs_old

  # Return samples.
  return samples
Example #20
Source File: attention_lm.py From tensor2tensor with Apache License 2.0 | 6 votes |
def body(self, features):
  """Attention LM decoder over features["targets"].

  Axis 2 of the targets is squeezed away before decoding and re-inserted on
  the decoder output.
  """
  hparams = self._hparams
  targets = tf.squeeze(features["targets"], 2)
  decoder_input, self_attention_bias = attention_lm_prepare_decoder(
      targets, hparams)
  # NOTE(review): dropout is applied unconditionally here (keep_prob =
  # 1 - layer_prepostprocess_dropout). Any train/eval switch presumably
  # happens through hparams -- confirm.
  decoder_input = tf.nn.dropout(
      decoder_input, 1.0 - hparams.layer_prepostprocess_dropout)
  decoded = attention_lm_decoder(decoder_input, self_attention_bias, hparams)
  return tf.expand_dims(decoded, 2)
Example #21
Source File: vqa_attention.py From tensor2tensor with Apache License 2.0 | 6 votes |
def attn(image_feat, query, hparams, name="attn"):
  """Attention on image feature with question as query.

  Returns:
    Image features pooled with softmax attention weights (hparams.num_glimps
    glimpses), reshaped to [-1, num_channels * num_glimps].
  """
  with tf.variable_scope(name, "attn", values=[image_feat, query]):
    feat_shape = common_layers.shape_list(image_feat)
    num_channels = feat_shape[-1]
    if len(feat_shape) == 4:
      image_feat = common_layers.flatten4d3d(image_feat)
    query = tf.expand_dims(query, 1)

    image_proj = common_attention.compute_attention_component(
        image_feat, hparams.attn_dim, name="image_proj")
    query_proj = common_attention.compute_attention_component(
        query, hparams.attn_dim, name="query_proj")
    hidden = tf.nn.relu(image_proj + query_proj)
    glimpse_logits = common_attention.compute_attention_component(
        hidden, hparams.num_glimps, name="h_proj")
    # Normalize over axis 1, presumably the image positions -- confirm.
    weights = tf.nn.softmax(glimpse_logits, axis=1)
    pooled = tf.matmul(image_feat, weights, transpose_a=True)
    return tf.reshape(pooled, [-1, num_channels * hparams.num_glimps])
Example #22
Source File: lstm.py From tensor2tensor with Apache License 2.0 | 6 votes |
def lstm_seq2seq_internal_attention(inputs, targets, hparams, train,
                                    inputs_length, targets_length):
  """LSTM seq2seq model with attention, main step used for training."""
  with tf.variable_scope("lstm_seq2seq_attention"):
    # Encoder: flatten, reverse each sequence within its length, run the LSTM.
    reversed_inputs = tf.reverse_sequence(
        common_layers.flatten4d3d(inputs), inputs_length, seq_axis=1)
    encoder_outputs, final_encoder_state = lstm(
        reversed_inputs, inputs_length, hparams, train, "encoder")

    # Decoder with attention over right-shifted targets. Each length grows by
    # 1 to cover the padding step that shift_right adds on the left.
    shifted_targets = common_layers.shift_right(targets)
    decoder_lengths = targets_length + 1
    decoder_outputs = lstm_attention_decoder(
        common_layers.flatten4d3d(shifted_targets), hparams, train, "decoder",
        final_encoder_state, encoder_outputs, inputs_length, decoder_lengths)
    return tf.expand_dims(decoder_outputs, axis=2)
Example #23
Source File: lstm.py From tensor2tensor with Apache License 2.0 | 6 votes |
def lstm_seq2seq_internal_attention_bid_encoder(inputs, targets, hparams,
                                                train):
  """LSTM seq2seq model with attention, main step used for training."""
  with tf.variable_scope("lstm_seq2seq_attention_bid_encoder"):
    inputs_length = common_layers.length_from_embedding(inputs)
    # Bidirectional LSTM encoder over the flattened inputs.
    encoder_outputs, final_encoder_state = lstm_bid_encoder(
        common_layers.flatten4d3d(inputs), inputs_length, hparams, train,
        "encoder")

    # LSTM decoder with attention.
    shifted_targets = common_layers.shift_right(targets)
    # Add 1 to account for the padding added to the left from shift_right.
    targets_length = common_layers.length_from_embedding(shifted_targets) + 1
    # The decoder uses twice the encoder's hidden size. Presumably this
    # matches the bidirectional encoder state -- confirm.
    decoder_hparams = copy.copy(hparams)
    decoder_hparams.hidden_size = 2 * hparams.hidden_size
    decoder_outputs = lstm_attention_decoder(
        common_layers.flatten4d3d(shifted_targets), decoder_hparams, train,
        "decoder", final_encoder_state, encoder_outputs, inputs_length,
        targets_length)
    return tf.expand_dims(decoder_outputs, axis=2)
Example #24
Source File: neural_assistant.py From tensor2tensor with Apache License 2.0 | 6 votes |
def compute_max_pool_embedding(input_embeddings, input_lengths):
  """Computes the max pool embedding over the valid (non-padded) positions.

  Args:
    input_embeddings: <tf.float32>[bs, max_seq_len, emb_dim]
    input_lengths: <tf.int64>[bs, 1]

  Returns:
    max_pool_embedding: <tf.float32>[bs, emb_dim]. A row whose length is 0
    has no valid positions, so its result is dominated by the -1e9 bias.
  """
  max_seq_len = tf.shape(input_embeddings)[1]
  # <tf.float32>[bs, 1, max_seq_len]: 1.0 at padded positions, 0.0 elsewhere.
  padding = 1.0 - tf.sequence_mask(
      input_lengths, max_seq_len, dtype=tf.float32)
  # Padded positions need a large negative bias so they can never win the
  # max. A tiny value such as -1e-6 would let padding leak into the result.
  # <tf.float32>[bs, max_seq_len]
  mask = tf.squeeze(padding * (-1e9), 1)
  # <tf.float32>[bs, max_seq_len, 1], broadcast over emb_dim.
  mask = tf.expand_dims(mask, 2)
  # <tf.float32>[bs, emb_dim]
  max_pool_embedding = tf.reduce_max(input_embeddings + mask, 1)
  return max_pool_embedding
Example #25
Source File: neural_assistant.py From tensor2tensor with Apache License 2.0 | 6 votes |
def compute_average_embedding(input_embeddings, input_lengths):
  """Computes bag-of-words embedding.

  Args:
    input_embeddings: <tf.float32>[bs, max_seq_len, emb_dim]
    input_lengths: <tf.int64>[bs, 1]

  Returns:
    bow_embedding: <tf.float32>[bs, emb_dim]
  """
  max_seq_len = tf.shape(input_embeddings)[1]
  # Sum the valid positions via matmul with a 0/1 sequence mask:
  # [bs, 1, max_seq_len] x [bs, max_seq_len, emb_dim] -> [bs, 1, emb_dim].
  valid = tf.sequence_mask(input_lengths, max_seq_len, dtype=tf.float32)
  summed = tf.matmul(valid, input_embeddings)
  # Divide by each example's length, shaped [bs, 1, 1] for broadcasting.
  # NOTE(review): a zero length yields inf/NaN here -- confirm callers never
  # pass empty sequences.
  counts = tf.to_float(tf.expand_dims(input_lengths, 2))
  return tf.squeeze(summed / counts, 1)
Example #26
Source File: nas_layers.py From tensor2tensor with Apache License 2.0 | 6 votes |
def _apply_logic(self, input_tensor, output_depth, hparams, var_scope_suffix,
                 nonpadding, mask_future, **unused_kwargs):
  """Applies conv logic to `input_tensor`.

  The 3-D input gets a singleton axis at position 2 for the convolution, and
  that axis is squeezed away afterwards.
  - With `mask_future`, the input is left-padded along axis 1 by
    (conv_width - 1) * dilation_rate and convolved with VALID padding.
  - Otherwise, SAME padding is used.
  """
  with tf.variable_scope("%s_conv_%s" % (self._conv_type, var_scope_suffix)):
    if not mask_future:
      logic_output, padding = input_tensor, "SAME"
    else:
      # Pad shift the inputs so that temporal information does not leak. This
      # must be used in tandem with VALID padding.
      pad_amount = int(self._conv_width - 1) * self._dilation_rate
      logic_output = tf.pad(
          input_tensor, paddings=[[0, 0], [pad_amount, 0], [0, 0]])
      padding = "VALID"
    logic_output = self._conv_function(
        tf.expand_dims(logic_output, 2), output_depth, padding)
    return tf.squeeze(logic_output, 2)
Example #27
Source File: preprocessing.py From benchmarks with Apache License 2.0 | 6 votes |
def _distort_image(self, image):
  """Distort one image for training a network.

  Adopts the standard data augmentation scheme that is widely used for this
  dataset:
  1. The image is zero-padded with 4 pixels on each side.
  2. It is randomly cropped back to the original size.
  3. Half of the images are then horizontally mirrored.

  Args:
    image: input image.

  Returns:
    distorted image.
  """
  padded = tf.image.resize_image_with_crop_or_pad(
      image, self.height + 8, self.width + 8)
  cropped = tf.random_crop(padded, [self.height, self.width, self.depth])
  # Randomly flip the image horizontally.
  distorted_image = tf.image.random_flip_left_right(cropped)
  if self.summary_verbosity >= 3:
    # The summary op takes a batch, so add a leading axis of size 1.
    tf.summary.image('distorted_image', tf.expand_dims(distorted_image, 0))
  return distorted_image
Example #28
Source File: vqa_self_attention.py From tensor2tensor with Apache License 2.0 | 5 votes |
def attn(image_feat,
         query,
         hparams,
         name="attn",
         save_weights_to=None,
         make_image_summary=True):
  """Attention on image feature with question as query.

  Multi-head dot-product attention. The query gets a length-1 axis at
  position 1, and that axis is squeezed away from the result.
  """
  with tf.variable_scope(name, "attn", values=[image_feat, query]):
    key_depth = hparams.attention_key_channels or hparams.hidden_size
    value_depth = hparams.attention_value_channels or hparams.hidden_size
    heads = hparams.num_heads

    q, k, v = common_attention.compute_qkv(
        tf.expand_dims(query, 1), image_feat, key_depth, value_depth)
    q, k, v = (common_attention.split_heads(t, heads) for t in (q, k, v))
    if hparams.scale_dotproduct:
      q *= (key_depth // heads)**-0.5

    # image_feat is input as v
    attended = common_attention.dot_product_attention(
        q,
        k,
        v,
        None,
        dropout_rate=hparams.attention_dropout,
        image_shapes=None,
        save_weights_to=save_weights_to,
        make_image_summary=make_image_summary)
    return tf.squeeze(common_attention.combine_heads(attended), axis=1)
Example #29
Source File: utils.py From lamb with Apache License 2.0 | 5 votes |
def expand_tile(tensor, n, name=None):
  """Returns a tensor repeated n times along a newly added first dimension.

  Args:
    tensor: a Tensor with a statically known rank.
    n: the number of repetitions. This may be a Python or NumPy integer, or an
      integer scalar Tensor (int32 or int64).
    name: optional name for the operation scope.

  Returns:
    A Tensor of shape [n] + tensor.shape. The leading static dimension is set
    when `n` is a Python or NumPy integer; otherwise it is left unknown.
  """
  import numbers  # Local import keeps this helper self-contained.

  with tf.name_scope(name, 'expand_tile'):
    # tf.concat needs a single dtype; casting lets an int64 `n` be combined
    # with the int32 ones below. The cast is a no-op for an int32 `n`.
    n_ = tf.reshape(tf.cast(n, tf.int32), [1])
    num_dims = len(tensor.get_shape().as_list())
    multiples = tf.concat([n_, tf.ones([num_dims], dtype=tf.int32)], axis=0)
    # multiples = [n, 1, 1, ..., 1]
    res = tf.tile(tf.expand_dims(tensor, 0), multiples)
    # numbers.Integral also matches NumPy integer scalars, not just int.
    first_dim = int(n) if isinstance(n, numbers.Integral) else None
    res.set_shape([first_dim] + tensor.get_shape().as_list())
    return res
Example #30
Source File: residual_shuffle_exchange.py From tensor2tensor with Apache License 2.0 | 5 votes |
def body(self, features):
  """Body of Residual Shuffle-Exchange network.

  Args:
    features: dictionary of inputs and targets

  Returns:
    the network output.
  """
  # The network runs with axis 2 of the inputs squeezed away; the same axis
  # is re-inserted on its logits.
  squeezed_inputs = tf.squeeze(features["inputs"], axis=2)
  network_logits = residual_shuffle_network(squeezed_inputs, self._hparams)
  return tf.expand_dims(network_logits, axis=2)