Python tensorflow.contrib.seq2seq.AttentionWrapperState() Examples

The following are 9 code examples of tensorflow.contrib.seq2seq.AttentionWrapperState(), taken from open-source projects. The source file and license are noted above each example. You may also want to check out all available functions and classes of the tensorflow.contrib.seq2seq module.
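Before the examples, here is a minimal sketch of where AttentionWrapperState comes from (my own illustration, assuming TensorFlow 1.x with the contrib packages available): it is the namedtuple state produced by AttentionWrapper.zero_state() and threaded through every decoder step, and its clone() method is what several of the examples below use to inject an encoder final state.

import tensorflow as tf
from tensorflow.contrib import seq2seq

batch_size = 4
# Fake encoder outputs with shape [batch, max_time, num_units].
encoder_outputs = tf.random_normal([batch_size, 10, 32])

attention_mechanism = seq2seq.LuongAttention(num_units=32, memory=encoder_outputs)
cell = seq2seq.AttentionWrapper(tf.nn.rnn_cell.GRUCell(32), attention_mechanism)

# zero_state() returns an AttentionWrapperState namedtuple.
init_state = cell.zero_state(batch_size, tf.float32)
assert isinstance(init_state, seq2seq.AttentionWrapperState)

# clone() swaps in new field values while keeping the attention bookkeeping,
# e.g. passing the encoder's final state, as Examples #1 and #2 below do:
# init_state = init_state.clone(cell_state=encoder_final_state)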
Example #1
Source File: decoders.py    From language with Apache License 2.0
def _build_init_state(self, batch_size, enc_state, rnn_cell, mode, hparams):
    """Builds initial states for the given RNN cells."""
    del mode  # Unused.

    # Build init state.
    init_state = rnn_cell.zero_state(batch_size, tf.float32)

    if hparams.pass_hidden_state:
      # Non-GNMT RNN cell returns an AttentionWrapperState.
      if isinstance(init_state, contrib_seq2seq.AttentionWrapperState):
        init_state = init_state.clone(cell_state=enc_state)
      # GNMT RNN cell returns a tuple state.
      elif isinstance(init_state, tuple):
        init_state = tuple(
            zs.clone(cell_state=es) if isinstance(
                zs, contrib_seq2seq.AttentionWrapperState) else es
            for zs, es in zip(init_state, enc_state))
      else:
        ValueError("RNN cell returns zero states of unknown type: %s"
                   % str(type(init_state)))

    return init_state 
Example #2
Source File: decoders.py    From language with Apache License 2.0
def _build_init_state(self, batch_size, enc_state, rnn_cell, mode, hparams):
    """Builds initial states for the given RNN cells."""
    # Build init state.
    init_state = rnn_cell.zero_state(batch_size, tf.float32)
    inner_state = init_state.cell_state

    if hparams.pass_hidden_state:
      # Non-GNMT RNN cell returns an AttentionWrapperState.
      if isinstance(inner_state, contrib_seq2seq.AttentionWrapperState):
        init_state = init_state.clone(
            cell_state=inner_state.clone(cell_state=enc_state))
      # GNMT RNN cell returns a tuple state.
      elif isinstance(init_state.cell_state, tuple):
        init_state = init_state.clone(
            cell_state=tuple(
                zs.clone(cell_state=es) if isinstance(
                    zs, contrib_seq2seq.AttentionWrapperState) else es
                for zs, es in zip(inner_state, enc_state)))
      else:
        ValueError("RNN cell returns zero states of unknown type: %s"
                   % str(type(init_state)))

    return init_state 
Example #3
Source File: _rnn.py    From DeepChatModels with MIT License
def wrapper(self, state):
        """Some RNN states are wrapped in namedtuples. 
        (TensorFlow decision, definitely not mine...). 
        
        This is here for derived classes to specify their wrapper state. 
        Some examples: LSTMStateTuple and AttentionWrapperState.
        
        Args:
            state: tensor state tuple, will be unpacked into the wrapper tuple.
        """
        if self._wrapper is None:
            return state
        else:
            return self._wrapper(*state) 
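As a quick illustration (hypothetical, not taken from the project): if a derived class set self._wrapper to tf.nn.rnn_cell.LSTMStateTuple, the self._wrapper(*state) call above would simply repack a plain (c, h) tuple into that namedtuple.

import tensorflow as tf

c = tf.zeros([4, 128])  # cell state
h = tf.zeros([4, 128])  # hidden state
state = (c, h)          # a plain tensor state tuple

# Equivalent to self._wrapper(*state) when _wrapper is LSTMStateTuple.
wrapped = tf.nn.rnn_cell.LSTMStateTuple(*state)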
Example #4
Source File: _rnn.py    From DeepChatModels with MIT License
def zero_state(self, batch_size, dtype):
        with tf.name_scope(type(self).__name__ + "ZeroState", values=[batch_size]):
            if self._initial_cell_state is not None:
                cell_state = self._initial_cell_state
            else:
                cell_state = self._cell.zero_state(batch_size, dtype)
            error_message = (
                "zero_state of AttentionWrapper %s: " % self._base_name +
                "Non-matching batch sizes between the memory "
                "(encoder output) and the requested batch size.")
            with tf.control_dependencies(
                [tf.assert_equal(batch_size,
                    self._attention_mechanism.batch_size,
                    message=error_message)]):
                cell_state = nest.map_structure(
                    lambda s: tf.identity(s, name="checked_cell_state"),
                    cell_state)
            alignment_history = ()

            _zero_state_tensors = rnn_cell_impl._zero_state_tensors
            return AttentionWrapperState(
                cell_state=cell_state,
                time=tf.zeros([], dtype=tf.int32),
                attention=_zero_state_tensors(
                    self._attention_size, batch_size, dtype),
                alignments=self._attention_mechanism.initial_alignments(
                    batch_size, dtype),
                alignment_history=alignment_history) 
Example #5
Source File: _rnn.py    From DeepChatModels with MIT License
def state_size(self):
        return AttentionWrapperState(
            cell_state=self._cell.state_size,
            attention=self._attention_size,
            time=tf.TensorShape([]),
            alignments=self._attention_mechanism.alignments_size,
            alignment_history=()) 
Example #6
Source File: _rnn.py    From DeepChatModels with MIT License
def shape(self):
        return AttentionWrapperState(
            cell_state=self._cell.shape,
            attention=tf.TensorShape([None, self._attention_size]),
            time=tf.TensorShape(None),
            alignments=tf.TensorShape([None, None]),
            alignment_history=()) 
Example #7
Source File: _rnn.py    From DeepChatModels with MIT License
def call(self, inputs, state):
        """First computes the cell state and output in the usual way, 
        then works through the attention pipeline:
            h --> a --> c --> h_tilde
        using the naming/notation from Luong et. al, 2015.

        Args:
            inputs: `2-D` tensor with shape `[batch_size x input_size]`.
            state: An instance of `AttentionWrapperState` containing the 
                tensors from the prev timestep.
     
        Returns:
            A tuple `(attention_or_cell_output, next_state)`, where:
            - `attention_or_cell_output` depending on `output_attention`.
            - `next_state` is an instance of `DynamicAttentionWrapperState`
                containing the state calculated at this time step.
        """

        # Concatenate the previous h_tilde with inputs (input-feeding).
        cell_inputs = tf.concat([inputs, state.attention], -1)

        # 1. (hidden) Compute the hidden state (cell_output).
        cell_output, next_cell_state = self._cell(cell_inputs,
                                                  state.cell_state)

        # 2. (align) Compute the normalized alignment scores. [B, L_enc].
        # where L_enc is the max seq len in the encoder outputs for the (B)atch.
        score = self._attention_mechanism(
            cell_output, previous_alignments=state.alignments)
        alignments = tf.nn.softmax(score)

        # Reshape from [B, L_enc] to [B, 1, L_enc]
        expanded_alignments = tf.expand_dims(alignments, 1)
        # (Possibly projected) encoder outputs: [B, L_enc, state_size]
        encoder_outputs = self._attention_mechanism.values
        # 3. (context) Take the inner product. [B, 1, state_size].
        context = tf.matmul(expanded_alignments, encoder_outputs)
        context = tf.squeeze(context, [1])

        # 4. (h_tilde) Compute tanh(W [c, h]).
        attention = self._attention_layer(
            tf.concat([cell_output, context], -1))

        next_state = AttentionWrapperState(
            cell_state=next_cell_state,
            attention=attention,
            time=state.time + 1,
            alignments=alignments,
            alignment_history=())

        return attention, next_state 
Example #8
Source File: sync_attention_wrapper.py    From aster with MIT License
def call(self, inputs, state):
    if not isinstance(state, seq2seq.AttentionWrapperState):
      raise TypeError("Expected state to be instance of AttentionWrapperState. "
                      "Received type %s instead."  % type(state))

    if self._is_multi:
      previous_alignments = state.alignments
      previous_alignment_history = state.alignment_history
    else:
      previous_alignments = [state.alignments]
      previous_alignment_history = [state.alignment_history]

    all_alignments = []
    all_attentions = []
    all_histories = []
    for i, attention_mechanism in enumerate(self._attention_mechanisms):
      if isinstance(self._cell, rnn.LSTMCell):
        rnn_cell_state = state.cell_state.h
      else:
        rnn_cell_state = state.cell_state
      attention, alignments = _compute_attention(
          attention_mechanism, rnn_cell_state, previous_alignments[i],
          self._attention_layers[i] if self._attention_layers else None)
      alignment_history = previous_alignment_history[i].write(
          state.time, alignments) if self._alignment_history else ()

      all_alignments.append(alignments)
      all_histories.append(alignment_history)
      all_attentions.append(attention)

    attention = array_ops.concat(all_attentions, 1)

    cell_inputs = self._cell_input_fn(inputs, attention)
    cell_output, next_cell_state = self._cell(cell_inputs, state.cell_state)

    next_state = seq2seq.AttentionWrapperState(
        time=state.time + 1,
        cell_state=next_cell_state,
        attention=attention,
        alignments=self._item_or_tuple(all_alignments),
        alignment_history=self._item_or_tuple(all_histories))
    
    if self._output_attention:
      return attention, next_state
    else:
      return cell_output, next_state 
Example #9
Source File: sync_attention_wrapper.py    From AON with MIT License
def call(self, inputs, state):
    if not isinstance(state, seq2seq.AttentionWrapperState):
      raise TypeError("Expected state to be instance of AttentionWrapperState. "
                      "Received type %s instead."  % type(state))

    if self._is_multi:
      previous_alignments = state.alignments
      previous_alignment_history = state.alignment_history
    else:
      previous_alignments = [state.alignments]
      previous_alignment_history = [state.alignment_history]

    all_alignments = []
    all_attentions = []
    all_attention_states = []
    all_histories = []
    for i, attention_mechanism in enumerate(self._attention_mechanisms):
      if isinstance(self._cell, rnn.LSTMCell):
        rnn_cell_state = state.cell_state.h
      else:
        rnn_cell_state = state.cell_state
      attention, alignments, next_attention_state = _compute_attention(
          attention_mechanism, rnn_cell_state, previous_alignments[i],
          self._attention_layers[i] if self._attention_layers else None)
      alignment_history = previous_alignment_history[i].write(
          state.time, alignments) if self._alignment_history else ()

      all_attention_states.append(next_attention_state)
      all_alignments.append(alignments)
      all_histories.append(alignment_history)
      all_attentions.append(attention)

    attention = array_ops.concat(all_attentions, 1)

    cell_inputs = self._cell_input_fn(inputs, attention)
    cell_output, next_cell_state = self._cell(cell_inputs, state.cell_state)

    next_state = seq2seq.AttentionWrapperState(
        time=state.time + 1,
        cell_state=next_cell_state,
        attention=attention,
        attention_state=self._item_or_tuple(all_attention_states),
        alignments=self._item_or_tuple(all_alignments),
        alignment_history=self._item_or_tuple(all_histories))
    
    if self._output_attention:
      return attention, next_state
    else:
      return cell_output, next_state
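Note the difference between Examples #8 and #9: later TensorFlow 1.x releases added an attention_state field to AttentionWrapperState (and a third return value to the internal _compute_attention helper), which is why Example #9 passes attention_state=self._item_or_tuple(all_attention_states) while Example #8 does not. If you adapt either wrapper, match the field set of the AttentionWrapperState in your installed TensorFlow version.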