Python tensorflow.contrib.seq2seq.AttentionWrapper() Examples
The following are 23 code examples of tensorflow.contrib.seq2seq.AttentionWrapper(), each taken from an open-source project; the source file, project, and license are noted above each example. You may also want to check out the other functions and classes available in the tensorflow.contrib.seq2seq module.
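Before the examples, here is a minimal sketch of the pattern most of them share (assuming TF 1.x with tf.contrib available; the unit sizes and placeholder shapes are illustrative, not taken from any of the projects below): build an attention mechanism over the encoder memory, wrap the decoder cell in AttentionWrapper, and use the wrapper's zero_state as the decoder's initial state.

import tensorflow as tf
from tensorflow.contrib import rnn, seq2seq

# Hypothetical encoder outputs: [batch, time, units]; shapes are for illustration only.
encoder_outputs = tf.placeholder(tf.float32, [None, None, 128])
encoder_lengths = tf.placeholder(tf.int32, [None])

# Score the encoder memory with additive (Bahdanau) attention.
attention_mechanism = seq2seq.BahdanauAttention(
    num_units=128,
    memory=encoder_outputs,
    memory_sequence_length=encoder_lengths)

# Wrap a plain decoder cell so that each decoding step attends over the memory.
decoder_cell = seq2seq.AttentionWrapper(
    rnn.LSTMCell(128),
    attention_mechanism,
    attention_layer_size=128,
    alignment_history=True,
    output_attention=True)

# The wrapper's zero_state bundles the inner cell state with the attention state.
batch_size = tf.shape(encoder_outputs)[0]
initial_state = decoder_cell.zero_state(batch_size, tf.float32)

The examples below vary this pattern: stacking several attention mechanisms, tiling the memory with tile_batch for beam search, or wrapping only the top or bottom layer of a MultiRNNCell.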
Example #1
Source File: seq2seq_decoder_estimator.py From icecaps with MIT License | 6 votes |
def build_attention_wrapper(self, final_cell):
    self.feedforward_inputs = tf.cond(
        self.beam_search_decoding,
        lambda: seq2seq.tile_batch(
            self.features["inputs"], multiplier=self.hparams.beam_width),
        lambda: self.features["inputs"])
    self.feedforward_inputs_length = tf.cond(
        self.beam_search_decoding,
        lambda: seq2seq.tile_batch(
            self.features["length"], multiplier=self.hparams.beam_width),
        lambda: self.features["length"])
    attention_mechanism = self.build_attention_mechanism()
    return AttentionWrapper(
        cell=final_cell,
        attention_mechanism=attention_mechanism,
        attention_layer_size=self.hparams.hidden_units,
        cell_input_fn=self._attention_input_feeding,
        initial_cell_state=self.initial_state[-1] if self.hparams.depth > 1 else self.initial_state)
Example #2
Source File: attention_predictor.py From aster with MIT License | 5 votes |
def _build_decoder_cell(self, feature_maps):
    attention_mechanism = self._build_attention_mechanism(feature_maps)
    wrapper_class = seq2seq.AttentionWrapper if not self._sync else sync_attention_wrapper.SyncAttentionWrapper
    attention_cell = wrapper_class(
        self._rnn_cell,
        attention_mechanism,
        output_attention=False)
    if not self._lm_rnn_cell:
        decoder_cell = attention_cell
    else:
        decoder_cell = ConcatOutputMultiRNNCell([attention_cell, self._lm_rnn_cell])
    return decoder_cell
Example #3
Source File: module.py From self-attention-tacotron with BSD 3-Clause "New" or "Revised" License | 5 votes |
def __init__(self, cell, mgc_prenets: Tuple[PreNet], lf0_prenets: Tuple[PreNet],
             attention_mechanism1, attention_mechanism2,
             trainable=True, name=None, **kwargs):
    super(DualSourceMgcLf0AttentionRNN, self).__init__(name=name, trainable=trainable, **kwargs)
    attention_cell = AttentionWrapper(
        cell,
        [attention_mechanism1, attention_mechanism2],
        alignment_history=True,
        output_attention=False)
    prenet_cell = DecoderMgcLf0PreNetWrapper(attention_cell, mgc_prenets, lf0_prenets)
    concat_cell = ConcatOutputAndAttentionWrapper(prenet_cell)
    self._cell = concat_cell
Example #4
Source File: module.py From self-attention-tacotron with BSD 3-Clause "New" or "Revised" License | 5 votes |
def __init__(self, cell, mgc_prenets: Tuple[PreNet], lf0_prenets: Tuple[PreNet],
             attention_mechanism, trainable=True, name=None, **kwargs):
    super(MgcLf0AttentionRNN, self).__init__(name=name, trainable=trainable, **kwargs)
    attention_cell = AttentionWrapper(
        cell,
        attention_mechanism,
        alignment_history=True,
        output_attention=False)
    prenet_cell = DecoderMgcLf0PreNetWrapper(attention_cell, mgc_prenets, lf0_prenets)
    concat_cell = ConcatOutputAndAttentionWrapper(prenet_cell)
    self._cell = concat_cell
Example #5
Source File: module.py From self-attention-tacotron with BSD 3-Clause "New" or "Revised" License | 5 votes |
def __init__(self, cell, prenets: Tuple[PreNet],
             attention_mechanism1, attention_mechanism2,
             trainable=True, name=None, **kwargs):
    super(DualSourceAttentionRNN, self).__init__(name=name, trainable=trainable, **kwargs)
    attention_cell = AttentionWrapper(
        cell,
        [attention_mechanism1, attention_mechanism2],
        alignment_history=True,
        output_attention=False)
    prenet_cell = DecoderPreNetWrapper(attention_cell, prenets)
    concat_cell = ConcatOutputAndAttentionWrapper(prenet_cell)
    self._cell = concat_cell
Example #6
Source File: rnn_decoders.py From texar with Apache License 2.0 | 5 votes |
def _alignments_size(self):
    # Reimplementation of the alignments_size of each of
    # AttentionWrapper.attention_mechanisms. The original implementation
    # of `_BaseAttentionMechanism._alignments_size`:
    #
    #     self._alignments_size = (self._keys.shape[1].value or
    #                              array_ops.shape(self._keys)[1])
    #
    # can be `None` when the seq length of encoder outputs is a priori
    # unknown.
    alignments_size = []
    for am in self._cell._attention_mechanisms:
        az = (am._keys.shape[1].value or tf.shape(am._keys)[1:-1])
        alignments_size.append(az)
    return self._cell._item_or_tuple(alignments_size)
Example #7
Source File: rnn_decoders.py From texar with Apache License 2.0 | 5 votes |
def _get_beam_search_cell(self, beam_width):
    """Returns the RNN cell for beam search decoding.
    """
    with tf.variable_scope(self.variable_scope, reuse=True):
        attn_kwargs = copy.copy(self._attn_kwargs)

        memory = attn_kwargs['memory']
        attn_kwargs['memory'] = tile_batch(memory, multiplier=beam_width)

        memory_seq_length = attn_kwargs['memory_sequence_length']
        if memory_seq_length is not None:
            attn_kwargs['memory_sequence_length'] = tile_batch(
                memory_seq_length, beam_width)

        attn_modules = ['tensorflow.contrib.seq2seq', 'texar.tf.custom']
        bs_attention_mechanism = utils.check_or_get_instance(
            self._hparams.attention.type, attn_kwargs, attn_modules,
            classtype=tf.contrib.seq2seq.AttentionMechanism)

        bs_attn_cell = AttentionWrapper(
            self._cell._cell,
            bs_attention_mechanism,
            cell_input_fn=self._cell_input_fn,
            **self._attn_cell_kwargs)

        self._beam_search_cell = bs_attn_cell

        return bs_attn_cell
Example #8
Source File: model_utils.py From language with Apache License 2.0 | 5 votes |
def __init__(self, attention_cell, cells, use_new_attention=False):
  """Creates a GNMTAttentionMultiCell.

  Args:
    attention_cell: An instance of AttentionWrapper.
    cells: A list of RNNCell wrapped with AttentionInputWrapper.
    use_new_attention: Whether to use the attention generated from current
      step bottom layer's output. Default is False.
  """
  cells = [attention_cell] + cells
  self.use_new_attention = use_new_attention
  super(GNMTAttentionMultiCell, self).__init__(cells, state_is_tuple=True)
Example #9
Source File: rnn_wrappers.py From tacotron2 with BSD 3-Clause "New" or "Revised" License | 5 votes |
def __init__(self, cell, prenets: Tuple[PreNet], attention_mechanism,
             trainable=True, name=None, dtype=None, **kwargs):
    super(AttentionRNN, self).__init__(trainable=trainable, name=name, dtype=dtype, **kwargs)
    attention_cell = AttentionWrapper(
        cell,
        attention_mechanism,
        alignment_history=True,
        output_attention=False)
    # prenet -> attention
    prenet_cell = DecoderPreNetWrapper(attention_cell, prenets)
    # prenet -> attention -> concat
    concat_cell = ConcatOutputAndAttentionWrapper(prenet_cell)
    self._cell = concat_cell
Example #10
Source File: seq2seq_model.py From AmusingPythonCodes with MIT License | 5 votes |
def _create_decoder_cell(self):
    enc_outputs, enc_states, enc_seq_len = self.enc_outputs, self.enc_states, self.enc_seq_len
    batch_size = self.batch_size * self.cfg.beam_size if self.use_beam_search else self.batch_size
    with tf.variable_scope("attention"):
        if self.cfg.attention == "luong":  # Luong attention mechanism
            attention_mechanism = LuongAttention(num_units=self.cfg.num_units, memory=enc_outputs,
                                                 memory_sequence_length=enc_seq_len)
        else:  # default using Bahdanau attention mechanism
            attention_mechanism = BahdanauAttention(num_units=self.cfg.num_units, memory=enc_outputs,
                                                    memory_sequence_length=enc_seq_len)

    def cell_input_fn(inputs, attention):
        # define cell input function to keep input/output dimension same
        # reference: https://www.tensorflow.org/api_docs/python/tf/contrib/seq2seq/AttentionWrapper
        if not self.cfg.use_attention_input_feeding:
            return inputs
        input_project = tf.layers.Dense(self.cfg.num_units, dtype=tf.float32, name='attn_input_feeding')
        return input_project(tf.concat([inputs, attention], axis=-1))

    if self.cfg.top_attention:  # apply attention mechanism only on the top decoder layer
        cells = [self._create_rnn_cell() for _ in range(self.cfg.num_layers)]
        cells[-1] = AttentionWrapper(cells[-1], attention_mechanism=attention_mechanism, name="Attention_Wrapper",
                                     attention_layer_size=self.cfg.num_units, initial_cell_state=enc_states[-1],
                                     cell_input_fn=cell_input_fn)
        initial_state = [state for state in enc_states]
        initial_state[-1] = cells[-1].zero_state(batch_size=batch_size, dtype=tf.float32)
        dec_init_states = tuple(initial_state)
        cells = MultiRNNCell(cells)
    else:
        cells = MultiRNNCell([self._create_rnn_cell() for _ in range(self.cfg.num_layers)])
        cells = AttentionWrapper(cells, attention_mechanism=attention_mechanism, name="Attention_Wrapper",
                                 attention_layer_size=self.cfg.num_units, initial_cell_state=enc_states,
                                 cell_input_fn=cell_input_fn)
        dec_init_states = cells.zero_state(batch_size=batch_size, dtype=tf.float32).clone(cell_state=enc_states)
    return cells, dec_init_states
Example #11
Source File: _rnn.py From DeepChatModels with MIT License | 5 votes |
def zero_state(self, batch_size, dtype):
    with tf.name_scope(type(self).__name__ + "ZeroState", values=[batch_size]):
        if self._initial_cell_state is not None:
            cell_state = self._initial_cell_state
        else:
            cell_state = self._cell.zero_state(batch_size, dtype)
        error_message = (
            "zero_state of AttentionWrapper %s: " % self._base_name +
            "Non-matching batch sizes between the memory "
            "(encoder output) and the requested batch size.")
        with tf.control_dependencies(
                [tf.assert_equal(batch_size,
                                 self._attention_mechanism.batch_size,
                                 message=error_message)]):
            cell_state = nest.map_structure(
                lambda s: tf.identity(s, name="checked_cell_state"),
                cell_state)
        alignment_history = ()
        _zero_state_tensors = rnn_cell_impl._zero_state_tensors
        return AttentionWrapperState(
            cell_state=cell_state,
            time=tf.zeros([], dtype=tf.int32),
            attention=_zero_state_tensors(self._attention_size, batch_size, dtype),
            alignments=self._attention_mechanism.initial_alignments(
                batch_size, dtype),
            alignment_history=alignment_history)
Example #12
Source File: rnn_decoders.py From Counterfactual-StoryRW with MIT License | 5 votes |
def _alignments_size(self):
    # Reimplementation of the alignments_size of each of
    # AttentionWrapper.attention_mechanisms. The original implementation
    # of `_BaseAttentionMechanism._alignments_size`:
    #
    #     self._alignments_size = (self._keys.shape[1].value or
    #                              array_ops.shape(self._keys)[1])
    #
    # can be `None` when the seq length of encoder outputs is a priori
    # unknown.
    alignments_size = []
    for am in self._cell._attention_mechanisms:
        az = (am._keys.shape[1].value or tf.shape(am._keys)[1:-1])
        alignments_size.append(az)
    return self._cell._item_or_tuple(alignments_size)
Example #13
Source File: rnn_decoders.py From Counterfactual-StoryRW with MIT License | 5 votes |
def _get_beam_search_cell(self, beam_width):
    """Returns the RNN cell for beam search decoding.
    """
    with tf.variable_scope(self.variable_scope, reuse=True):
        attn_kwargs = copy.copy(self._attn_kwargs)

        memory = attn_kwargs['memory']
        attn_kwargs['memory'] = tile_batch(memory, multiplier=beam_width)

        memory_seq_length = attn_kwargs['memory_sequence_length']
        if memory_seq_length is not None:
            attn_kwargs['memory_sequence_length'] = tile_batch(
                memory_seq_length, beam_width)

        attn_modules = ['tensorflow.contrib.seq2seq', 'texar.custom']
        bs_attention_mechanism = utils.check_or_get_instance(
            self._hparams.attention.type, attn_kwargs, attn_modules,
            classtype=tf.contrib.seq2seq.AttentionMechanism)

        bs_attn_cell = AttentionWrapper(
            self._cell._cell,
            bs_attention_mechanism,
            cell_input_fn=self._cell_input_fn,
            **self._attn_cell_kwargs)

        self._beam_search_cell = bs_attn_cell

        return bs_attn_cell
Example #14
Source File: decoder_bimodal.py From avsr-tf1 with GNU General Public License v3.0 | 4 votes |
def _build_decoder_beam_search(self):
    batch_size, _ = tf.unstack(tf.shape(self._labels))

    attention_mechanisms, layer_sizes = self._create_attention_mechanisms(beam_search=True)

    decoder_initial_state_tiled = seq2seq.tile_batch(
        self._decoder_initial_state, multiplier=self._hparams.beam_width)

    if self._hparams.enable_attention is True:
        attention_cells = seq2seq.AttentionWrapper(
            cell=self._decoder_cells,
            attention_mechanism=attention_mechanisms,
            attention_layer_size=layer_sizes,
            initial_cell_state=decoder_initial_state_tiled,
            alignment_history=self._hparams.write_attention_alignment,
            output_attention=self._output_attention)
        initial_state = attention_cells.zero_state(
            dtype=self._hparams.dtype,
            batch_size=batch_size * self._hparams.beam_width)
        initial_state = initial_state.clone(
            cell_state=decoder_initial_state_tiled)
        cells = attention_cells
    else:
        cells = self._decoder_cells
        initial_state = decoder_initial_state_tiled

    self._decoder_inference = seq2seq.BeamSearchDecoder(
        cell=cells,
        embedding=self._embedding_matrix,
        start_tokens=array_ops.fill([batch_size], self._GO_ID),
        end_token=self._EOS_ID,
        initial_state=initial_state,
        beam_width=self._hparams.beam_width,
        output_layer=self._dense_layer,
        length_penalty_weight=0.5,
    )

    outputs, states, lengths = seq2seq.dynamic_decode(
        self._decoder_inference,
        impute_finished=False,
        maximum_iterations=self._hparams.max_label_length,
        swap_memory=False)

    if self._hparams.write_attention_alignment is True:
        self.attention_summary = self._create_attention_alignments_summary(states)

    self.inference_outputs = outputs.beam_search_decoder_output
    self.inference_predicted_ids = outputs.predicted_ids[:, :, 0]  # return the first beam
    self.inference_predicted_beam = outputs.predicted_ids
    self.beam_search_output = outputs.beam_search_decoder_output
Example #15
Source File: rnn_decoders.py From Counterfactual-StoryRW with MIT License | 4 votes |
def __init__(self,
             memory,
             memory_sequence_length=None,
             cell=None,
             cell_dropout_mode=None,
             vocab_size=None,
             output_layer=None,
             #attention_layer=None, # TODO(zhiting): only valid for tf>=1.0
             cell_input_fn=None,
             hparams=None):
    RNNDecoderBase.__init__(
        self, cell, vocab_size, output_layer, cell_dropout_mode, hparams)

    attn_hparams = self._hparams['attention']
    attn_kwargs = attn_hparams['kwargs'].todict()

    # Parse the 'probability_fn' argument
    if 'probability_fn' in attn_kwargs:
        prob_fn = attn_kwargs['probability_fn']
        if prob_fn is not None and not callable(prob_fn):
            prob_fn = utils.get_function(
                prob_fn,
                ['tensorflow.nn', 'tensorflow.contrib.sparsemax',
                 'tensorflow.contrib.seq2seq'])
        attn_kwargs['probability_fn'] = prob_fn

    attn_kwargs.update({
        "memory_sequence_length": memory_sequence_length,
        "memory": memory})
    self._attn_kwargs = attn_kwargs
    attn_modules = ['tensorflow.contrib.seq2seq', 'texar.custom']
    # Use variable_scope to ensure all trainable variables created in
    # the attention mechanism are collected
    with tf.variable_scope(self.variable_scope):
        attention_mechanism = utils.check_or_get_instance(
            attn_hparams["type"], attn_kwargs, attn_modules,
            classtype=tf.contrib.seq2seq.AttentionMechanism)

    self._attn_cell_kwargs = {
        "attention_layer_size": attn_hparams["attention_layer_size"],
        "alignment_history": attn_hparams["alignment_history"],
        "output_attention": attn_hparams["output_attention"],
    }
    self._cell_input_fn = cell_input_fn
    # Use variable_scope to ensure all trainable variables created in
    # AttentionWrapper are collected
    with tf.variable_scope(self.variable_scope):
        #if attention_layer is not None:
        #    self._attn_cell_kwargs["attention_layer_size"] = None
        attn_cell = AttentionWrapper(
            self._cell,
            attention_mechanism,
            cell_input_fn=self._cell_input_fn,
            #attention_layer=attention_layer,
            **self._attn_cell_kwargs)
        self._cell = attn_cell
Example #16
Source File: decoder_bimodal.py From avsr-tf1 with GNU General Public License v3.0 | 4 votes |
def _build_decoder_greedy(self):
    batch_size, _ = tf.unstack(tf.shape(self._labels))
    self._helper_greedy = seq2seq.GreedyEmbeddingHelper(
        embedding=self._embedding_matrix,
        start_tokens=tf.tile([self._GO_ID], [batch_size]),
        end_token=self._EOS_ID)

    if self._hparams.enable_attention is True:
        attention_mechanisms, layer_sizes = self._create_attention_mechanisms()

        attention_cells = seq2seq.AttentionWrapper(
            cell=self._decoder_cells,
            attention_mechanism=attention_mechanisms,
            attention_layer_size=layer_sizes,
            initial_cell_state=self._decoder_initial_state,
            alignment_history=self._hparams.write_attention_alignment,
            output_attention=self._output_attention
        )

        attn_zero = attention_cells.zero_state(
            dtype=self._hparams.dtype, batch_size=batch_size
        )
        initial_state = attn_zero.clone(
            cell_state=self._decoder_initial_state
        )

        cells = attention_cells
    else:
        cells = self._decoder_cells
        initial_state = self._decoder_initial_state

    self._decoder_inference = seq2seq.BasicDecoder(
        cell=cells,
        helper=self._helper_greedy,
        initial_state=initial_state,
        output_layer=self._dense_layer)

    outputs, states, lengths = seq2seq.dynamic_decode(
        self._decoder_inference,
        impute_finished=True,
        swap_memory=False,
        maximum_iterations=self._hparams.max_label_length)

    # self._result = outputs, states, lengths
    self.inference_outputs = outputs.rnn_output
    self.inference_predicted_ids = outputs.sample_id

    if self._hparams.write_attention_alignment is True:
        self.attention_summary = self._create_attention_alignments_summary(states)
Example #17
Source File: decoder_bimodal.py From avsr-tf1 with GNU General Public License v3.0 | 4 votes |
def _build_decoder_train(self):
    self._labels_embedded = tf.nn.embedding_lookup(self._embedding_matrix, self._labels_padded_GO)

    self._helper_train = seq2seq.ScheduledEmbeddingTrainingHelper(
        inputs=self._labels_embedded,
        sequence_length=self._labels_len,
        embedding=self._embedding_matrix,
        sampling_probability=self._sampling_probability_outputs,
    )

    if self._hparams.enable_attention is True:
        attention_mechanisms, layer_sizes = self._create_attention_mechanisms()

        attention_cells = seq2seq.AttentionWrapper(
            cell=self._decoder_cells,
            attention_mechanism=attention_mechanisms,
            attention_layer_size=layer_sizes,
            initial_cell_state=self._decoder_initial_state,
            alignment_history=False,
            output_attention=self._output_attention,
        )

        batch_size, _ = tf.unstack(tf.shape(self._labels))

        attn_zero = attention_cells.zero_state(
            dtype=self._hparams.dtype, batch_size=batch_size
        )
        initial_state = attn_zero.clone(
            cell_state=self._decoder_initial_state
        )

        cells = attention_cells
    else:
        cells = self._decoder_cells
        initial_state = self._decoder_initial_state

    self._decoder_train = seq2seq.BasicDecoder(
        cell=cells,
        helper=self._helper_train,
        initial_state=initial_state,
        output_layer=self._dense_layer,
    )

    self._basic_decoder_train_outputs, self._final_states, self._final_seq_lens = seq2seq.dynamic_decode(
        self._decoder_train,
        output_time_major=False,
        impute_finished=True,
        swap_memory=False,
    )

    self._logits = self._basic_decoder_train_outputs.rnn_output
Example #18
Source File: model_utils.py From language with Apache License 2.0 | 4 votes |
def create_rnn_cell(unit_type,
                    num_units,
                    num_layers,
                    num_residual_layers,
                    forget_bias,
                    dropout,
                    mode,
                    attention_mechanism=None,
                    attention_num_heads=1,
                    attention_layer_size=None,
                    output_attention=False,
                    single_cell_fn=None,
                    trainable=True):
  """Returns an instance of an RNN cell.

  Args:
    unit_type: A string that specifies the type of the recurrent unit. Must be
      one of {"lstm", "gru", "lstm_norm", "nas"}.
    num_units: An integer for the number of units per layer.
    num_layers: An integer for the number of recurrent layers.
    num_residual_layers: An integer for the number of residual layers.
    forget_bias: A float for the forget bias in LSTM cells.
    dropout: A float for the recurrent dropout rate.
    mode: TRAIN | EVAL | PREDICT
    attention_mechanism: An instance of tf.contrib.seq2seq.AttentionMechanism.
    attention_num_heads: An integer for the number of attention heads.
    attention_layer_size: Optional integer for the size of the attention layer.
    output_attention: A boolean indicating whether RNN cell outputs attention.
    single_cell_fn: A function for building a single RNN cell.
    trainable: A boolean indicating whether the cell weights are trainable.

  Returns:
    An RNNCell instance.
  """
  cell_list = _cell_list(
      unit_type=unit_type,
      num_units=num_units,
      num_layers=num_layers,
      forget_bias=forget_bias,
      dropout=dropout,
      mode=mode,
      num_residual_layers=num_residual_layers,
      single_cell_fn=single_cell_fn,
      trainable=trainable)

  if len(cell_list) == 1:
    # Single layer.
    cell = cell_list[0]
  else:
    # Multiple layers.
    cell = contrib_rnn.MultiRNNCell(cell_list)

  # Wrap with attention, if necessary.
  if attention_mechanism is not None:
    cell = contrib_seq2seq.AttentionWrapper(
        cell,
        [attention_mechanism] * attention_num_heads,
        attention_layer_size=[attention_layer_size] * attention_num_heads,
        alignment_history=False,
        output_attention=output_attention,
        name="attention")

  return cell
Example #19
Source File: model_utils.py From language with Apache License 2.0 | 4 votes |
def create_gnmt_rnn_cell(unit_type,
                         num_units,
                         num_layers,
                         num_residual_layers,
                         forget_bias,
                         dropout,
                         mode,
                         attention_mechanism,
                         attention_num_heads=1,
                         attention_layer_size=None,
                         output_attention=False,
                         single_cell_fn=None):
  """Returns an instance of a GNMT-style RNN cell.

  Args:
    unit_type: A string that specifies the type of the recurrent unit. Must be
      one of {"lstm", "gru", "lstm_norm", "nas"}.
    num_units: An integer for the number of units per layer.
    num_layers: An integer for the number of recurrent layers.
    num_residual_layers: An integer for the number of residual layers.
    forget_bias: A float for the forget bias in LSTM cells.
    dropout: A float for the recurrent dropout rate.
    mode: TRAIN | EVAL | PREDICT
    attention_mechanism: An instance of tf.contrib.seq2seq.AttentionMechanism.
    attention_num_heads: An integer for the number of attention heads.
    attention_layer_size: Optional integer for the size of the attention layer.
    output_attention: A boolean indicating whether RNN cell outputs attention.
    single_cell_fn: A function for building a single RNN cell.

  Returns:
    An RNNCell instance.
  """
  cell_list = _cell_list(
      unit_type=unit_type,
      num_units=num_units,
      num_layers=num_layers,
      forget_bias=forget_bias,
      dropout=dropout,
      mode=mode,
      num_residual_layers=num_residual_layers,
      single_cell_fn=single_cell_fn,
      residual_fn=gnmt_residual_fn)

  if attention_num_heads > 1:
    attention_mechanism = [attention_mechanism] * attention_num_heads
    attention_layer_size = [attention_layer_size] * attention_num_heads

  # Only wrap the bottom layer with the attention mechanism.
  attention_cell = cell_list.pop(0)
  attention_cell = contrib_seq2seq.AttentionWrapper(
      attention_cell,
      attention_mechanism,
      attention_layer_size=attention_layer_size,
      alignment_history=False,
      output_attention=output_attention,
      name="gnmt_attention")

  cell = GNMTAttentionMultiCell(
      attention_cell, cell_list, use_new_attention=True)

  return cell
Example #20
Source File: rnn_decoders.py From texar with Apache License 2.0 | 4 votes |
def __init__(self,
             memory,
             memory_sequence_length=None,
             cell=None,
             cell_dropout_mode=None,
             vocab_size=None,
             output_layer=None,
             # attention_layer=None, # TODO(zhiting): only valid for tf>=1.0
             cell_input_fn=None,
             hparams=None):
    RNNDecoderBase.__init__(
        self, cell, vocab_size, output_layer, cell_dropout_mode, hparams)

    attn_hparams = self._hparams['attention']
    attn_kwargs = attn_hparams['kwargs'].todict()

    # Parse the 'probability_fn' argument
    if 'probability_fn' in attn_kwargs:
        prob_fn = attn_kwargs['probability_fn']
        if prob_fn is not None and not callable(prob_fn):
            prob_fn = utils.get_function(
                prob_fn,
                ['tensorflow.nn', 'tensorflow.contrib.sparsemax',
                 'tensorflow.contrib.seq2seq'])
        attn_kwargs['probability_fn'] = prob_fn

    attn_kwargs.update({
        "memory_sequence_length": memory_sequence_length,
        "memory": memory})
    self._attn_kwargs = attn_kwargs
    attn_modules = ['tensorflow.contrib.seq2seq', 'texar.tf.custom']
    # Use variable_scope to ensure all trainable variables created in
    # the attention mechanism are collected
    with tf.variable_scope(self.variable_scope):
        attention_mechanism = utils.check_or_get_instance(
            attn_hparams["type"], attn_kwargs, attn_modules,
            classtype=tf.contrib.seq2seq.AttentionMechanism)

    self._attn_cell_kwargs = {
        "attention_layer_size": attn_hparams["attention_layer_size"],
        "alignment_history": attn_hparams["alignment_history"],
        "output_attention": attn_hparams["output_attention"],
    }
    self._cell_input_fn = cell_input_fn
    # Use variable_scope to ensure all trainable variables created in
    # AttentionWrapper are collected
    with tf.variable_scope(self.variable_scope):
        # if attention_layer is not None:
        #     self._attn_cell_kwargs["attention_layer_size"] = None
        attn_cell = AttentionWrapper(
            self._cell,
            attention_mechanism,
            cell_input_fn=self._cell_input_fn,
            # attention_layer=attention_layer,
            **self._attn_cell_kwargs)
        self._cell = attn_cell
Example #21
Source File: attention.py From avsr-tf1 with GNU General Public License v3.0 | 4 votes |
def add_attention(
        cells,
        attention_types,
        num_units,
        memory,
        memory_len,
        mode,
        batch_size,
        dtype,
        beam_search=False,
        beam_width=None,
        initial_state=None,
        write_attention_alignment=False,
        fusion_type='linear_fusion',
):
    r"""
    Wraps the decoder_cells with an AttentionWrapper
    Args:
        cells: instances of `RNNCell`
        beam_search: `bool` flag for beam search decoders
        batch_size: `Tensor` containing the batch size. Necessary to the initialisation of the initial state

    Returns:
        attention_cells: the Attention wrapped decoder cells
        initial_state: a proper initial state to be used with the returned cells
    """
    attention_mechanisms, attention_layers, attention_layer_sizes, output_attention = create_attention_mechanisms(
        beam_search=beam_search,
        beam_width=beam_width,
        memory=memory,
        memory_len=memory_len,
        num_units=num_units,
        attention_types=attention_types,
        fusion_type=fusion_type,
        mode=mode,
        dtype=dtype)

    if beam_search is True:
        initial_state = seq2seq.tile_batch(
            initial_state, multiplier=beam_width)

    attention_cells = seq2seq.AttentionWrapper(
        cell=cells,
        attention_mechanism=attention_mechanisms,
        attention_layer_size=attention_layer_sizes,
        # initial_cell_state=decoder_initial_state,
        alignment_history=write_attention_alignment,
        output_attention=output_attention,
        attention_layer=attention_layers,
    )

    attn_zero = attention_cells.zero_state(
        dtype=dtype,
        batch_size=batch_size * beam_width if beam_search is True else batch_size)

    if initial_state is not None:
        initial_state = attn_zero.clone(
            cell_state=initial_state)

    return attention_cells, initial_state
Example #22
Source File: seq2seq.py From retrosynthesis_planner with GNU General Public License v3.0 | 4 votes |
def _make_decoder(self, encoder_outputs, encoder_final_state, beam_search=False, reuse=False):
    """Create decoder"""
    with tf.variable_scope('decode', reuse=reuse):
        # Create decoder cells
        cells = [self._make_cell() for _ in range(self.depth)]

        if beam_search:
            # Tile inputs as needed for beam search
            encoder_outputs = seq2seq.tile_batch(
                encoder_outputs, multiplier=self.beam_width)
            encoder_final_state = nest.map_structure(
                lambda s: seq2seq.tile_batch(s, multiplier=self.beam_width),
                encoder_final_state)
            sequence_length = seq2seq.tile_batch(
                self.sequence_length, multiplier=self.beam_width)
        else:
            sequence_length = self.sequence_length

        # Prepare attention mechanism;
        # add only to last cell
        attention_mechanism = seq2seq.LuongAttention(
            num_units=self.hidden_size,
            memory=encoder_outputs,
            memory_sequence_length=sequence_length,
            name='attn')
        cells[-1] = seq2seq.AttentionWrapper(
            cells[-1],
            attention_mechanism,
            attention_layer_size=self.hidden_size,
            initial_cell_state=encoder_final_state[-1],
            cell_input_fn=lambda inp, attn: tf.layers.dense(tf.concat([inp, attn], -1), self.hidden_size),
            name='attnwrap'
        )

        # Copy encoder final state as decoder initial state
        decoder_initial_state = [s for s in encoder_final_state]

        # Set last initial state to be AttentionWrapperState
        batch_size = self.batch_size
        if beam_search:
            batch_size = self.batch_size * self.beam_width
        decoder_initial_state[-1] = cells[-1].zero_state(
            dtype=tf.float32, batch_size=batch_size)

        # Wrap up the cells
        cell = rnn.MultiRNNCell(cells)

        # Return initial state as a tuple
        # (required by tensorflow)
        return cell, tuple(decoder_initial_state)
Example #23
Source File: _rnn.py From DeepChatModels with MIT License | 4 votes |
def __init__(self,
             cell,
             attention_mechanism,
             initial_cell_state=None,
             name=None):
    """Construct the wrapper.

    Main tweak is creating the attention_layer with a tanh activation
    (Luong's choice) as opposed to linear (TensorFlow's choice). Also, since
    I am sticking with Luong's approach, parameters that are in the
    constructor of TensorFlow's AttentionWrapper have been removed, and the
    corresponding values are set to how Luong's paper defined them.

    Args:
        cell: instance of the Cell class above.
        attention_mechanism: instance of tf AttentionMechanism.
        initial_cell_state: The initial state value to use for the cell when
            the user calls `zero_state()`.
        name: Name to use when creating ops.
    """
    super(SimpleAttentionWrapper, self).__init__(name=name)

    # Assume that 'cell' is an instance of the custom 'Cell' class above.
    self._base_cell = cell._base_cell
    self._num_layers = cell._num_layers
    self._state_size = cell._state_size

    self._attention_size = attention_mechanism.values.get_shape()[-1].value
    self._attention_layer = layers_core.Dense(self._attention_size,
                                              activation=tf.nn.tanh,
                                              name="attention_layer",
                                              use_bias=False)

    self._cell = cell
    self._attention_mechanism = attention_mechanism
    with tf.name_scope(name, "AttentionWrapperInit"):
        if initial_cell_state is None:
            self._initial_cell_state = None
        else:
            final_state_tensor = nest.flatten(initial_cell_state)[-1]
            state_batch_size = (
                final_state_tensor.shape[0].value
                or tf.shape(final_state_tensor)[0])
            error_message = (
                "Constructor AttentionWrapper %s: " % self._base_name +
                "Non-matching batch sizes between the memory "
                "(encoder output) and initial_cell_state.")
            with tf.control_dependencies(
                    [tf.assert_equal(state_batch_size,
                                     self._attention_mechanism.batch_size,
                                     message=error_message)]):
                self._initial_cell_state = nest.map_structure(
                    lambda s: tf.identity(s, name="check_initial_cell_state"),
                    initial_cell_state)