Python tensorflow.contrib.seq2seq.AttentionWrapper() Examples

The following are 23 code examples of tensorflow.contrib.seq2seq.AttentionWrapper(), drawn from open-source projects. The source file, project, license, and user vote count are listed above each example. You may also want to check out all available functions and classes of the module tensorflow.contrib.seq2seq.
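Before the project-specific examples, here is a minimal, generic sketch of the construction pattern most of them share. It is not taken from any of the listed projects; it assumes TensorFlow 1.x with tf.contrib available, and the placeholder shapes and unit sizes are purely illustrative.

import tensorflow as tf
from tensorflow.contrib import rnn, seq2seq

num_units = 128
# Encoder outputs serve as the attention memory: [batch, time, units].
encoder_outputs = tf.placeholder(tf.float32, [None, None, num_units])
encoder_lengths = tf.placeholder(tf.int32, [None])
batch_size = tf.shape(encoder_outputs)[0]

attention_mechanism = seq2seq.BahdanauAttention(
    num_units=num_units,
    memory=encoder_outputs,
    memory_sequence_length=encoder_lengths)

decoder_cell = seq2seq.AttentionWrapper(
    rnn.BasicLSTMCell(num_units),
    attention_mechanism,
    attention_layer_size=num_units,
    alignment_history=False,
    output_attention=True)

# The wrapper defines its own AttentionWrapperState as the zero state.
initial_state = decoder_cell.zero_state(batch_size=batch_size, dtype=tf.float32)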
Example #1
Source File: seq2seq_decoder_estimator.py    From icecaps with MIT License (6 votes)
def build_attention_wrapper(self, final_cell):
        self.feedforward_inputs = tf.cond(
            self.beam_search_decoding, lambda: seq2seq.tile_batch(
                self.features["inputs"], multiplier=self.hparams.beam_width),
            lambda: self.features["inputs"])
        self.feedforward_inputs_length = tf.cond(
            self.beam_search_decoding, lambda: seq2seq.tile_batch(
                self.features["length"], multiplier=self.hparams.beam_width),
            lambda: self.features["length"])
        attention_mechanism = self.build_attention_mechanism()
        return AttentionWrapper(
            cell=final_cell,
            attention_mechanism=attention_mechanism,
            attention_layer_size=self.hparams.hidden_units,
            cell_input_fn=self._attention_input_feeding,
            initial_cell_state=self.initial_state[-1] if self.hparams.depth > 1 else self.initial_state) 
Example #2
Source File: attention_predictor.py    From aster with MIT License (5 votes)
def _build_decoder_cell(self, feature_maps):
    attention_mechanism = self._build_attention_mechanism(feature_maps)
    wrapper_class = seq2seq.AttentionWrapper if not self._sync else sync_attention_wrapper.SyncAttentionWrapper
    attention_cell = wrapper_class(
      self._rnn_cell,
      attention_mechanism,
      output_attention=False)
    if not self._lm_rnn_cell:
      decoder_cell = attention_cell
    else:
      decoder_cell = ConcatOutputMultiRNNCell([attention_cell, self._lm_rnn_cell])

    return decoder_cell 
Example #3
Source File: module.py    From self-attention-tacotron with BSD 3-Clause "New" or "Revised" License (5 votes)
def __init__(self, cell, mgc_prenets: Tuple[PreNet], lf0_prenets: Tuple[PreNet],
                 attention_mechanism1,
                 attention_mechanism2,
                 trainable=True, name=None, **kwargs):
        super(DualSourceMgcLf0AttentionRNN, self).__init__(name=name, trainable=trainable, **kwargs)
        attention_cell = AttentionWrapper(
            cell,
            [attention_mechanism1, attention_mechanism2],
            alignment_history=True,
            output_attention=False)
        prenet_cell = DecoderMgcLf0PreNetWrapper(attention_cell, mgc_prenets, lf0_prenets)
        concat_cell = ConcatOutputAndAttentionWrapper(prenet_cell)
        self._cell = concat_cell 
Example #4
Source File: module.py    From self-attention-tacotron with BSD 3-Clause "New" or "Revised" License (5 votes)
def __init__(self, cell, mgc_prenets: Tuple[PreNet], lf0_prenets: Tuple[PreNet],
                 attention_mechanism,
                 trainable=True, name=None, **kwargs):
        super(MgcLf0AttentionRNN, self).__init__(name=name, trainable=trainable, **kwargs)
        attention_cell = AttentionWrapper(
            cell,
            attention_mechanism,
            alignment_history=True,
            output_attention=False)
        prenet_cell = DecoderMgcLf0PreNetWrapper(attention_cell, mgc_prenets, lf0_prenets)
        concat_cell = ConcatOutputAndAttentionWrapper(prenet_cell)
        self._cell = concat_cell 
Example #5
Source File: module.py    From self-attention-tacotron with BSD 3-Clause "New" or "Revised" License (5 votes)
def __init__(self, cell, prenets: Tuple[PreNet],
                 attention_mechanism1,
                 attention_mechanism2,
                 trainable=True, name=None, **kwargs):
        super(DualSourceAttentionRNN, self).__init__(name=name, trainable=trainable, **kwargs)
        attention_cell = AttentionWrapper(
            cell,
            [attention_mechanism1, attention_mechanism2],
            alignment_history=True,
            output_attention=False)
        prenet_cell = DecoderPreNetWrapper(attention_cell, prenets)
        concat_cell = ConcatOutputAndAttentionWrapper(prenet_cell)
        self._cell = concat_cell 
Example #6
Source File: rnn_decoders.py    From texar with Apache License 2.0 (5 votes)
def _alignments_size(self):
        # Reimplementation of the alignments_size of each of
        # AttentionWrapper.attention_mechanisms. The original implementation
        # of `_BaseAttentionMechanism._alignments_size`:
        #
        #    self._alignments_size = (self._keys.shape[1].value or
        #                       array_ops.shape(self._keys)[1])
        #
        # can be `None` when the sequence length of the encoder outputs is
        # not known a priori.
        alignments_size = []
        for am in self._cell._attention_mechanisms:
            az = (am._keys.shape[1].value or tf.shape(am._keys)[1:-1])
            alignments_size.append(az)
        return self._cell._item_or_tuple(alignments_size) 
Example #7
Source File: rnn_decoders.py    From texar with Apache License 2.0 (5 votes)
def _get_beam_search_cell(self, beam_width):
        """Returns the RNN cell for beam search decoding.
        """
        with tf.variable_scope(self.variable_scope, reuse=True):
            attn_kwargs = copy.copy(self._attn_kwargs)

            memory = attn_kwargs['memory']
            attn_kwargs['memory'] = tile_batch(memory, multiplier=beam_width)

            memory_seq_length = attn_kwargs['memory_sequence_length']
            if memory_seq_length is not None:
                attn_kwargs['memory_sequence_length'] = tile_batch(
                    memory_seq_length, beam_width)

            attn_modules = ['tensorflow.contrib.seq2seq', 'texar.tf.custom']
            bs_attention_mechanism = utils.check_or_get_instance(
                self._hparams.attention.type, attn_kwargs, attn_modules,
                classtype=tf.contrib.seq2seq.AttentionMechanism)

            bs_attn_cell = AttentionWrapper(
                self._cell._cell,
                bs_attention_mechanism,
                cell_input_fn=self._cell_input_fn,
                **self._attn_cell_kwargs)

            self._beam_search_cell = bs_attn_cell

            return bs_attn_cell 
Example #8
Source File: model_utils.py    From language with Apache License 2.0 (5 votes)
def __init__(self, attention_cell, cells, use_new_attention=False):
    """Creates a GNMTAttentionMultiCell.

    Args:
      attention_cell: An instance of AttentionWrapper.
      cells: A list of RNNCell wrapped with AttentionInputWrapper.
      use_new_attention: Whether to use the attention generated from the
        current step's bottom-layer output. Defaults to False.
    """
    cells = [attention_cell] + cells
    self.use_new_attention = use_new_attention
    super(GNMTAttentionMultiCell, self).__init__(cells, state_is_tuple=True) 
Example #9
Source File: rnn_wrappers.py    From tacotron2 with BSD 3-Clause "New" or "Revised" License (5 votes)
def __init__(self, cell, prenets: Tuple[PreNet],
                 attention_mechanism,
                 trainable=True, name=None, dtype=None, **kwargs):
        super(AttentionRNN, self).__init__(trainable=trainable, name=name, dtype=dtype, **kwargs)
        attention_cell = AttentionWrapper(
            cell,
            attention_mechanism,
            alignment_history=True,
            output_attention=False)
        # prenet -> attention
        prenet_cell = DecoderPreNetWrapper(attention_cell, prenets)
        # prenet -> attention -> concat
        concat_cell = ConcatOutputAndAttentionWrapper(prenet_cell)
        self._cell = concat_cell 
Example #10
Source File: seq2seq_model.py    From AmusingPythonCodes with MIT License (5 votes)
def _create_decoder_cell(self):
        enc_outputs, enc_states, enc_seq_len = self.enc_outputs, self.enc_states, self.enc_seq_len
        batch_size = self.batch_size * self.cfg.beam_size if self.use_beam_search else self.batch_size
        with tf.variable_scope("attention"):
            if self.cfg.attention == "luong":  # Luong attention mechanism
                attention_mechanism = LuongAttention(num_units=self.cfg.num_units, memory=enc_outputs,
                                                     memory_sequence_length=enc_seq_len)
            else:  # default using Bahdanau attention mechanism
                attention_mechanism = BahdanauAttention(num_units=self.cfg.num_units, memory=enc_outputs,
                                                        memory_sequence_length=enc_seq_len)

        def cell_input_fn(inputs, attention):  # keeps the cell's input dimension the same under attention input feeding (see the note after this example)
            # reference: https://www.tensorflow.org/api_docs/python/tf/contrib/seq2seq/AttentionWrapper
            if not self.cfg.use_attention_input_feeding:
                return inputs
            input_project = tf.layers.Dense(self.cfg.num_units, dtype=tf.float32, name='attn_input_feeding')
            return input_project(tf.concat([inputs, attention], axis=-1))

        if self.cfg.top_attention:  # apply attention mechanism only on the top decoder layer
            cells = [self._create_rnn_cell() for _ in range(self.cfg.num_layers)]
            cells[-1] = AttentionWrapper(cells[-1], attention_mechanism=attention_mechanism, name="Attention_Wrapper",
                                         attention_layer_size=self.cfg.num_units, initial_cell_state=enc_states[-1],
                                         cell_input_fn=cell_input_fn)
            initial_state = [state for state in enc_states]
            initial_state[-1] = cells[-1].zero_state(batch_size=batch_size, dtype=tf.float32)
            dec_init_states = tuple(initial_state)
            cells = MultiRNNCell(cells)
        else:
            cells = MultiRNNCell([self._create_rnn_cell() for _ in range(self.cfg.num_layers)])
            cells = AttentionWrapper(cells, attention_mechanism=attention_mechanism, name="Attention_Wrapper",
                                     attention_layer_size=self.cfg.num_units, initial_cell_state=enc_states,
                                     cell_input_fn=cell_input_fn)
            dec_init_states = cells.zero_state(batch_size=batch_size, dtype=tf.float32).clone(cell_state=enc_states)
        return cells, dec_init_states 
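A note on the cell_input_fn used in Example #10 above: when no cell_input_fn is supplied, AttentionWrapper (per the TF 1.x API documentation) simply concatenates the decoder input with the previous attention vector, which is why the input-feeding variant projects that concatenation back to num_units. A minimal sketch:

import tensorflow as tf

# Default behaviour when cell_input_fn is None (TF 1.x): the decoder input and
# the previous attention vector are concatenated, so the wrapped cell sees an
# input of size input_dim + attention_dim.
default_cell_input_fn = lambda inputs, attention: tf.concat([inputs, attention], axis=-1)

# The input-feeding variant in Example #10 instead projects this concatenation
# back to num_units, keeping the cell's input dimension fixed.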
Example #11
Source File: _rnn.py    From DeepChatModels with MIT License (5 votes)
def zero_state(self, batch_size, dtype):
        with tf.name_scope(type(self).__name__ + "ZeroState", values=[batch_size]):
            if self._initial_cell_state is not None:
                cell_state = self._initial_cell_state
            else:
                cell_state = self._cell.zero_state(batch_size, dtype)
            error_message = (
                "zero_state of AttentionWrapper %s: " % self._base_name +
                "Non-matching batch sizes between the memory "
                "(encoder output) and the requested batch size.")
            with tf.control_dependencies(
                [tf.assert_equal(batch_size,
                    self._attention_mechanism.batch_size,
                    message=error_message)]):
                cell_state = nest.map_structure(
                    lambda s: tf.identity(s, name="checked_cell_state"),
                    cell_state)
            alignment_history = ()

            _zero_state_tensors = rnn_cell_impl._zero_state_tensors
            return AttentionWrapperState(
                cell_state=cell_state,
                time=tf.zeros([], dtype=tf.int32),
                attention=_zero_state_tensors(self._attention_size, batch_size,
                dtype),
                alignments=self._attention_mechanism.initial_alignments(
                    batch_size, dtype),
                alignment_history=alignment_history) 
Example #12
Source File: rnn_decoders.py    From Counterfactual-StoryRW with MIT License (5 votes)
def _alignments_size(self):
        # Reimplementation of the alignments_size of each of
        # AttentionWrapper.attention_mechanisms. The original implementation
        # of `_BaseAttentionMechanism._alignments_size`:
        #
        #    self._alignments_size = (self._keys.shape[1].value or
        #                       array_ops.shape(self._keys)[1])
        #
        # can be `None` when the sequence length of the encoder outputs is
        # not known a priori.
        alignments_size = []
        for am in self._cell._attention_mechanisms:
            az = (am._keys.shape[1].value or tf.shape(am._keys)[1:-1])
            alignments_size.append(az)
        return self._cell._item_or_tuple(alignments_size) 
Example #13
Source File: rnn_decoders.py    From Counterfactual-StoryRW with MIT License (5 votes)
def _get_beam_search_cell(self, beam_width):
        """Returns the RNN cell for beam search decoding.
        """
        with tf.variable_scope(self.variable_scope, reuse=True):
            attn_kwargs = copy.copy(self._attn_kwargs)

            memory = attn_kwargs['memory']
            attn_kwargs['memory'] = tile_batch(memory, multiplier=beam_width)

            memory_seq_length = attn_kwargs['memory_sequence_length']
            if memory_seq_length is not None:
                attn_kwargs['memory_sequence_length'] = tile_batch(
                    memory_seq_length, beam_width)

            attn_modules = ['tensorflow.contrib.seq2seq', 'texar.custom']
            bs_attention_mechanism = utils.check_or_get_instance(
                self._hparams.attention.type, attn_kwargs, attn_modules,
                classtype=tf.contrib.seq2seq.AttentionMechanism)

            bs_attn_cell = AttentionWrapper(
                self._cell._cell,
                bs_attention_mechanism,
                cell_input_fn=self._cell_input_fn,
                **self._attn_cell_kwargs)

            self._beam_search_cell = bs_attn_cell

            return bs_attn_cell 
Example #14
Source File: decoder_bimodal.py    From avsr-tf1 with GNU General Public License v3.0 (4 votes)
def _build_decoder_beam_search(self):

        batch_size, _ = tf.unstack(tf.shape(self._labels))

        attention_mechanisms, layer_sizes = self._create_attention_mechanisms(beam_search=True)

        decoder_initial_state_tiled = seq2seq.tile_batch(
            self._decoder_initial_state, multiplier=self._hparams.beam_width)

        if self._hparams.enable_attention is True:

            attention_cells = seq2seq.AttentionWrapper(
                cell=self._decoder_cells,
                attention_mechanism=attention_mechanisms,
                attention_layer_size=layer_sizes,
                initial_cell_state=decoder_initial_state_tiled,
                alignment_history=self._hparams.write_attention_alignment,
                output_attention=self._output_attention)

            initial_state = attention_cells.zero_state(
                dtype=self._hparams.dtype, batch_size=batch_size * self._hparams.beam_width)

            initial_state = initial_state.clone(
                cell_state=decoder_initial_state_tiled)

            cells = attention_cells
        else:
            cells = self._decoder_cells
            initial_state = decoder_initial_state_tiled

        self._decoder_inference = seq2seq.BeamSearchDecoder(
            cell=cells,
            embedding=self._embedding_matrix,
            start_tokens=array_ops.fill([batch_size], self._GO_ID),
            end_token=self._EOS_ID,
            initial_state=initial_state,
            beam_width=self._hparams.beam_width,
            output_layer=self._dense_layer,
            length_penalty_weight=0.5,
        )

        outputs, states, lengths = seq2seq.dynamic_decode(
            self._decoder_inference,
            impute_finished=False,
            maximum_iterations=self._hparams.max_label_length,
            swap_memory=False)

        if self._hparams.write_attention_alignment is True:
            self.attention_summary = self._create_attention_alignments_summary(states)

        self.inference_outputs = outputs.beam_search_decoder_output
        self.inference_predicted_ids = outputs.predicted_ids[:, :, 0]  # return the first beam
        self.inference_predicted_beam = outputs.predicted_ids
        self.beam_search_output = outputs.beam_search_decoder_output 
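Note that _create_attention_mechanisms(beam_search=True) is not shown in this snippet; for beam search the attention memory (and its lengths) must be tiled in the same way as the decoder initial state. The following is a hedged sketch of what such a helper presumably does; the names and sizes are illustrative and not taken from the avsr-tf1 source.

import tensorflow as tf
from tensorflow.contrib import seq2seq

beam_width, num_units = 4, 256
encoder_outputs = tf.placeholder(tf.float32, [None, None, num_units])
encoder_lengths = tf.placeholder(tf.int32, [None])

# Tile the attention memory so its batch dimension matches the beam-tiled
# decoder state built with seq2seq.tile_batch above.
tiled_memory = seq2seq.tile_batch(encoder_outputs, multiplier=beam_width)
tiled_lengths = seq2seq.tile_batch(encoder_lengths, multiplier=beam_width)

attention_mechanism = seq2seq.LuongAttention(
    num_units=num_units,
    memory=tiled_memory,
    memory_sequence_length=tiled_lengths)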
Example #15
Source File: rnn_decoders.py    From Counterfactual-StoryRW with MIT License (4 votes)
def __init__(self,
                 memory,
                 memory_sequence_length=None,
                 cell=None,
                 cell_dropout_mode=None,
                 vocab_size=None,
                 output_layer=None,
                 #attention_layer=None, # TODO(zhiting): only valid for tf>=1.0
                 cell_input_fn=None,
                 hparams=None):
        RNNDecoderBase.__init__(
            self, cell, vocab_size, output_layer, cell_dropout_mode, hparams)

        attn_hparams = self._hparams['attention']
        attn_kwargs = attn_hparams['kwargs'].todict()

        # Parse the 'probability_fn' argument
        if 'probability_fn' in attn_kwargs:
            prob_fn = attn_kwargs['probability_fn']
            if prob_fn is not None and not callable(prob_fn):
                prob_fn = utils.get_function(
                    prob_fn,
                    ['tensorflow.nn', 'tensorflow.contrib.sparsemax',
                     'tensorflow.contrib.seq2seq'])
            attn_kwargs['probability_fn'] = prob_fn

        attn_kwargs.update({
            "memory_sequence_length": memory_sequence_length,
            "memory": memory})
        self._attn_kwargs = attn_kwargs
        attn_modules = ['tensorflow.contrib.seq2seq', 'texar.custom']
        # Use variable_scope to ensure all trainable variables created in
        # the attention mechanism are collected
        with tf.variable_scope(self.variable_scope):
            attention_mechanism = utils.check_or_get_instance(
                attn_hparams["type"], attn_kwargs, attn_modules,
                classtype=tf.contrib.seq2seq.AttentionMechanism)

        self._attn_cell_kwargs = {
            "attention_layer_size": attn_hparams["attention_layer_size"],
            "alignment_history": attn_hparams["alignment_history"],
            "output_attention": attn_hparams["output_attention"],
        }
        self._cell_input_fn = cell_input_fn
        # Use variable_scope to ensure all trainable variables created in
        # AttentionWrapper are collected
        with tf.variable_scope(self.variable_scope):
            #if attention_layer is not None:
            #    self._attn_cell_kwargs["attention_layer_size"] = None
            attn_cell = AttentionWrapper(
                self._cell,
                attention_mechanism,
                cell_input_fn=self._cell_input_fn,
                #attention_layer=attention_layer,
                **self._attn_cell_kwargs)
            self._cell = attn_cell 
Example #16
Source File: decoder_bimodal.py    From avsr-tf1 with GNU General Public License v3.0 (4 votes)
def _build_decoder_greedy(self):

        batch_size, _ = tf.unstack(tf.shape(self._labels))
        self._helper_greedy = seq2seq.GreedyEmbeddingHelper(
            embedding=self._embedding_matrix,
            start_tokens=tf.tile([self._GO_ID], [batch_size]),
            end_token=self._EOS_ID)

        if self._hparams.enable_attention is True:
            attention_mechanisms, layer_sizes = self._create_attention_mechanisms()

            attention_cells = seq2seq.AttentionWrapper(
                cell=self._decoder_cells,
                attention_mechanism=attention_mechanisms,
                attention_layer_size=layer_sizes,
                initial_cell_state=self._decoder_initial_state,
                alignment_history=self._hparams.write_attention_alignment,
                output_attention=self._output_attention
            )
            attn_zero = attention_cells.zero_state(
                dtype=self._hparams.dtype, batch_size=batch_size
            )
            initial_state = attn_zero.clone(
                cell_state=self._decoder_initial_state
            )
            cells = attention_cells
        else:
            cells = self._decoder_cells
            initial_state = self._decoder_initial_state

        self._decoder_inference = seq2seq.BasicDecoder(
            cell=cells,
            helper=self._helper_greedy,
            initial_state=initial_state,
            output_layer=self._dense_layer)

        outputs, states, lengths = seq2seq.dynamic_decode(
            self._decoder_inference,
            impute_finished=True,
            swap_memory=False,
            maximum_iterations=self._hparams.max_label_length)

        # self._result = outputs, states, lengths
        self.inference_outputs = outputs.rnn_output
        self.inference_predicted_ids = outputs.sample_id

        if self._hparams.write_attention_alignment is True:
            self.attention_summary = self._create_attention_alignments_summary(states) 
Example #17
Source File: decoder_bimodal.py    From avsr-tf1 with GNU General Public License v3.0 (4 votes)
def _build_decoder_train(self):

        self._labels_embedded = tf.nn.embedding_lookup(self._embedding_matrix, self._labels_padded_GO)

        self._helper_train = seq2seq.ScheduledEmbeddingTrainingHelper(
            inputs=self._labels_embedded,
            sequence_length=self._labels_len,
            embedding=self._embedding_matrix,
            sampling_probability=self._sampling_probability_outputs,
        )

        if self._hparams.enable_attention is True:
            attention_mechanisms, layer_sizes = self._create_attention_mechanisms()

            attention_cells = seq2seq.AttentionWrapper(
                cell=self._decoder_cells,
                attention_mechanism=attention_mechanisms,
                attention_layer_size=layer_sizes,
                initial_cell_state=self._decoder_initial_state,
                alignment_history=False,
                output_attention=self._output_attention,
            )
            batch_size, _ = tf.unstack(tf.shape(self._labels))

            attn_zero = attention_cells.zero_state(
                dtype=self._hparams.dtype, batch_size=batch_size
            )
            initial_state = attn_zero.clone(
                cell_state=self._decoder_initial_state
            )

            cells = attention_cells
        else:
            cells = self._decoder_cells
            initial_state = self._decoder_initial_state

        self._decoder_train = seq2seq.BasicDecoder(
            cell=cells,
            helper=self._helper_train,
            initial_state=initial_state,
            output_layer=self._dense_layer,
        )

        self._basic_decoder_train_outputs, self._final_states, self._final_seq_lens = seq2seq.dynamic_decode(
            self._decoder_train,
            output_time_major=False,
            impute_finished=True,
            swap_memory=False,
        )

        self._logits = self._basic_decoder_train_outputs.rnn_output 
Example #18
Source File: model_utils.py    From language with Apache License 2.0 (4 votes)
def create_rnn_cell(unit_type,
                    num_units,
                    num_layers,
                    num_residual_layers,
                    forget_bias,
                    dropout,
                    mode,
                    attention_mechanism=None,
                    attention_num_heads=1,
                    attention_layer_size=None,
                    output_attention=False,
                    single_cell_fn=None,
                    trainable=True):
  """Returns an instance of an RNN cell.

  Args:
    unit_type: A string that specifies the type of the recurrent unit. Must be
      one of {"lstm", "gru", "lstm_norm", "nas"}.
    num_units: An integer for the number of units per layer.
    num_layers: An integer for the number of recurrent layers.
    num_residual_layers: An integer for the number of residual layers.
    forget_bias: A float for the forget bias in LSTM cells.
    dropout: A float for the recurrent dropout rate.
    mode: TRAIN | EVAL | PREDICT
    attention_mechanism: An instance of tf.contrib.seq2seq.AttentionMechanism.
    attention_num_heads: An integer for the number of attention heads.
    attention_layer_size: Optional integer for the size of the attention layer.
    output_attention: A boolean indicating whether RNN cell outputs attention.
    single_cell_fn: A function for building a single RNN cell.
    trainable: A boolean indicating whether the cell weights are trainable.

  Returns:
    An RNNCell instance.
  """
  cell_list = _cell_list(
      unit_type=unit_type,
      num_units=num_units,
      num_layers=num_layers,
      forget_bias=forget_bias,
      dropout=dropout,
      mode=mode,
      num_residual_layers=num_residual_layers,
      single_cell_fn=single_cell_fn,
      trainable=trainable)

  if len(cell_list) == 1:  # Single layer.
    cell = cell_list[0]
  else:  # Multiple layers.
    cell = contrib_rnn.MultiRNNCell(cell_list)

  # Wrap with attention, if necessary.
  if attention_mechanism is not None:
    cell = contrib_seq2seq.AttentionWrapper(
        cell, [attention_mechanism] * attention_num_heads,
        attention_layer_size=[attention_layer_size] * attention_num_heads,
        alignment_history=False,
        output_attention=output_attention,
        name="attention")

  return cell 
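The wrapping above builds multi-head attention by passing a list of mechanisms (the same mechanism repeated) together with a matching list of attention_layer_size values. Below is a minimal standalone sketch of the same pattern with two explicit mechanisms; the names and sizes are illustrative.

import tensorflow as tf
from tensorflow.contrib import rnn, seq2seq

num_units = 128
memory = tf.placeholder(tf.float32, [None, None, num_units])
memory_len = tf.placeholder(tf.int32, [None])

# With a list of mechanisms, attention_layer_size must be a list of the same
# length; the per-head context vectors are concatenated into the attention output.
mechanisms = [
    seq2seq.LuongAttention(num_units, memory, memory_sequence_length=memory_len)
    for _ in range(2)]

multi_head_cell = seq2seq.AttentionWrapper(
    rnn.GRUCell(num_units),
    mechanisms,
    attention_layer_size=[num_units, num_units],
    alignment_history=False,
    output_attention=False,
    name="attention")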
Example #19
Source File: model_utils.py    From language with Apache License 2.0 (4 votes)
def create_gnmt_rnn_cell(unit_type,
                         num_units,
                         num_layers,
                         num_residual_layers,
                         forget_bias,
                         dropout,
                         mode,
                         attention_mechanism,
                         attention_num_heads=1,
                         attention_layer_size=None,
                         output_attention=False,
                         single_cell_fn=None):
  """Returns an instance of an GNMT-style RNN cell.

  Args:
    unit_type: A string that specifies the type of the recurrent unit. Must be
      one of {"lstm", "gru", "lstm_norm", "nas"}.
    num_units: An integer for the number of units per layer.
    num_layers: An integer for the number of recurrent layers.
    num_residual_layers: An integer for the number of residual layers.
    forget_bias: A float for the forget bias in LSTM cells.
    dropout: A float for the recurrent dropout rate.
    mode: TRAIN | EVAL | PREDICT
    attention_mechanism: An instance of tf.contrib.seq2seq.AttentionMechanism.
    attention_num_heads: An integer for the number of attention heads.
    attention_layer_size: Optional integer for the size of the attention layer.
    output_attention: A boolean indicating whether RNN cell outputs attention.
    single_cell_fn: A function for building a single RNN cell.

  Returns:
    An RNNCell instance.
  """
  cell_list = _cell_list(
      unit_type=unit_type,
      num_units=num_units,
      num_layers=num_layers,
      forget_bias=forget_bias,
      dropout=dropout,
      mode=mode,
      num_residual_layers=num_residual_layers,
      single_cell_fn=single_cell_fn,
      residual_fn=gnmt_residual_fn)

  if attention_num_heads > 1:
    attention_mechanism = [attention_mechanism] * attention_num_heads
    attention_layer_size = [attention_layer_size] * attention_num_heads

  # Only wrap the bottom layer with the attention mechanism.
  attention_cell = cell_list.pop(0)
  attention_cell = contrib_seq2seq.AttentionWrapper(
      attention_cell,
      attention_mechanism,
      attention_layer_size=attention_layer_size,
      alignment_history=False,
      output_attention=output_attention,
      name="gnmt_attention")

  cell = GNMTAttentionMultiCell(
      attention_cell, cell_list, use_new_attention=True)

  return cell 
Example #20
Source File: rnn_decoders.py    From texar with Apache License 2.0 (4 votes)
def __init__(self,
                 memory,
                 memory_sequence_length=None,
                 cell=None,
                 cell_dropout_mode=None,
                 vocab_size=None,
                 output_layer=None,
                 # attention_layer=None, # TODO(zhiting): only valid for tf>=1.0
                 cell_input_fn=None,
                 hparams=None):
        RNNDecoderBase.__init__(
            self, cell, vocab_size, output_layer, cell_dropout_mode, hparams)

        attn_hparams = self._hparams['attention']
        attn_kwargs = attn_hparams['kwargs'].todict()

        # Parse the 'probability_fn' argument
        if 'probability_fn' in attn_kwargs:
            prob_fn = attn_kwargs['probability_fn']
            if prob_fn is not None and not callable(prob_fn):
                prob_fn = utils.get_function(
                    prob_fn,
                    ['tensorflow.nn', 'tensorflow.contrib.sparsemax',
                     'tensorflow.contrib.seq2seq'])
            attn_kwargs['probability_fn'] = prob_fn

        attn_kwargs.update({
            "memory_sequence_length": memory_sequence_length,
            "memory": memory})
        self._attn_kwargs = attn_kwargs
        attn_modules = ['tensorflow.contrib.seq2seq', 'texar.tf.custom']
        # Use variable_scope to ensure all trainable variables created in
        # the attention mechanism are collected
        with tf.variable_scope(self.variable_scope):
            attention_mechanism = utils.check_or_get_instance(
                attn_hparams["type"], attn_kwargs, attn_modules,
                classtype=tf.contrib.seq2seq.AttentionMechanism)

        self._attn_cell_kwargs = {
            "attention_layer_size": attn_hparams["attention_layer_size"],
            "alignment_history": attn_hparams["alignment_history"],
            "output_attention": attn_hparams["output_attention"],
        }
        self._cell_input_fn = cell_input_fn
        # Use variable_scope to ensure all trainable variables created in
        # AttentionWrapper are collected
        with tf.variable_scope(self.variable_scope):
            # if attention_layer is not None:
            #    self._attn_cell_kwargs["attention_layer_size"] = None
            attn_cell = AttentionWrapper(
                self._cell,
                attention_mechanism,
                cell_input_fn=self._cell_input_fn,
                # attention_layer=attention_layer,
                **self._attn_cell_kwargs)
            self._cell = attn_cell 
Example #21
Source File: attention.py    From avsr-tf1 with GNU General Public License v3.0 (4 votes)
def add_attention(
        cells,
        attention_types,
        num_units,
        memory,
        memory_len,
        mode,
        batch_size,
        dtype,
        beam_search=False,
        beam_width=None,
        initial_state=None,
        write_attention_alignment=False,
        fusion_type='linear_fusion',
):
    r"""
    Wraps the decoder_cells with an AttentionWrapper
    Args:
        cells: instances of `RNNCell`
        beam_search: `bool` flag for beam search decoders
        batch_size: `Tensor` containing the batch size. Needed to build the initial state

    Returns:
        attention_cells: the Attention wrapped decoder cells
        initial_state: a proper initial state to be used with the returned cells
    """
    attention_mechanisms, attention_layers, attention_layer_sizes, output_attention = create_attention_mechanisms(
        beam_search=beam_search,
        beam_width=beam_width,
        memory=memory,
        memory_len=memory_len,
        num_units=num_units,
        attention_types=attention_types,
        fusion_type=fusion_type,
        mode=mode,
        dtype=dtype)

    if beam_search is True:
        initial_state = seq2seq.tile_batch(
            initial_state, multiplier=beam_width)

    attention_cells = seq2seq.AttentionWrapper(
        cell=cells,
        attention_mechanism=attention_mechanisms,
        attention_layer_size=attention_layer_sizes,
        # initial_cell_state=decoder_initial_state,
        alignment_history=write_attention_alignment,
        output_attention=output_attention,
        attention_layer=attention_layers,
    )

    attn_zero = attention_cells.zero_state(
        dtype=dtype,
        batch_size=batch_size * beam_width if beam_search is True else batch_size)

    if initial_state is not None:
        initial_state = attn_zero.clone(
            cell_state=initial_state)

    return attention_cells, initial_state 
Example #22
Source File: seq2seq.py    From retrosynthesis_planner with GNU General Public License v3.0 (4 votes)
def _make_decoder(self, encoder_outputs, encoder_final_state, beam_search=False, reuse=False):
        """Create decoder"""
        with tf.variable_scope('decode', reuse=reuse):
            # Create decoder cells
            cells = [self._make_cell() for _ in range(self.depth)]

            if beam_search:
                # Tile inputs as needed for beam search
                encoder_outputs = seq2seq.tile_batch(
                    encoder_outputs, multiplier=self.beam_width)
                encoder_final_state = nest.map_structure(
                    lambda s: seq2seq.tile_batch(s, multiplier=self.beam_width),
                    encoder_final_state)
                sequence_length = seq2seq.tile_batch(
                    self.sequence_length, multiplier=self.beam_width)
            else:
                sequence_length = self.sequence_length

            # Prepare attention mechanism;
            # add only to last cell
            attention_mechanism = seq2seq.LuongAttention(
                num_units=self.hidden_size, memory=encoder_outputs,
                memory_sequence_length=sequence_length, name='attn')
            cells[-1] = seq2seq.AttentionWrapper(
                cells[-1], attention_mechanism, attention_layer_size=self.hidden_size,
                initial_cell_state=encoder_final_state[-1],
                cell_input_fn=lambda inp, attn: tf.layers.dense(tf.concat([inp, attn], -1), self.hidden_size),
                name='attnwrap'
            )

            # Copy encoder final state as decoder initial state
            decoder_initial_state = [s for s in encoder_final_state]

            # Set last initial state to be AttentionWrapperState
            batch_size = self.batch_size
            if beam_search: batch_size = self.batch_size * self.beam_width
            decoder_initial_state[-1] = cells[-1].zero_state(
                dtype=tf.float32, batch_size=batch_size)

            # Wrap up the cells
            cell = rnn.MultiRNNCell(cells)

            # Return initial state as a tuple
            # (required by tensorflow)
            return cell, tuple(decoder_initial_state) 
Example #23
Source File: _rnn.py    From DeepChatModels with MIT License (4 votes)
def __init__(self,
                 cell,
                 attention_mechanism,
                 initial_cell_state=None,
                 name=None):
        """Construct the wrapper.
        
        Main tweak is creating the attention_layer with a tanh activation 
        (Luong's choice) as opposed to linear (TensorFlow's choice). Also,
        since I am sticking with Luong's approach, parameters that are in the
        constructor of TensorFlow's AttentionWrapper have been removed, and 
        the corresponding values are set to how Luong's paper defined them.
        
        Args:
            cell: instance of the Cell class above.
            attention_mechanism: instance of tf AttentionMechanism.
            initial_cell_state: The initial state value to use for the cell when
                the user calls `zero_state()`.
            name: Name to use when creating ops.
        """

        super(SimpleAttentionWrapper, self).__init__(name=name)

        # Assume that 'cell' is an instance of the custom 'Cell' class above.
        self._base_cell = cell._base_cell
        self._num_layers = cell._num_layers
        self._state_size = cell._state_size

        self._attention_size = attention_mechanism.values.get_shape()[-1].value
        self._attention_layer = layers_core.Dense(self._attention_size,
                                                  activation=tf.nn.tanh,
                                                  name="attention_layer",
                                                  use_bias=False)

        self._cell = cell
        self._attention_mechanism = attention_mechanism
        with tf.name_scope(name, "AttentionWrapperInit"):
            if initial_cell_state is None:
                self._initial_cell_state = None
            else:
                final_state_tensor = nest.flatten(initial_cell_state)[-1]
                state_batch_size = (
                    final_state_tensor.shape[0].value
                    or tf.shape(final_state_tensor)[0])
                error_message = (
                    "Constructor AttentionWrapper %s: " % self._base_name +
                    "Non-matching batch sizes between the memory "
                    "(encoder output) and initial_cell_state.")
                with tf.control_dependencies(
                    [tf.assert_equal(state_batch_size,
                        self._attention_mechanism.batch_size,
                        message=error_message)]):
                    self._initial_cell_state = nest.map_structure(
                        lambda s: tf.identity(s, name="check_initial_cell_state"),
                        initial_cell_state)
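For reference, a minimal sketch of the tanh attention layer described in the docstring above (Luong et al., 2015): the attentional hidden state is h_tilde = tanh(W_c [context; cell_output]), whereas the stock AttentionWrapper applies the same projection without an activation. The names below are illustrative.

import tensorflow as tf
from tensorflow.python.layers import core as layers_core

attention_size = 128
# Dense projection with a tanh activation, matching the constructor above.
attention_layer = layers_core.Dense(
    attention_size, activation=tf.nn.tanh, use_bias=False, name="attention_layer")

def attentional_hidden_state(context, cell_output):
    # h_tilde = tanh(W_c [context; cell_output]) -- Luong-style attentional state
    return attention_layer(tf.concat([context, cell_output], axis=-1))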