Python tensorflow.contrib.seq2seq.AttentionWrapperState() Examples
The following are 9 code examples of tensorflow.contrib.seq2seq.AttentionWrapperState(). Each example is taken from an open-source project; the source file, project, and license are noted above each snippet. You may also want to check out the other available functions and classes of the tensorflow.contrib.seq2seq module.
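AttentionWrapperState is the namedtuple state that tf.contrib.seq2seq.AttentionWrapper threads through decoding: it bundles the inner cell state with the current attention vector, time step, alignments and, optionally, the alignment history. The examples below either build it field by field or obtain it from zero_state() and adjust it with clone(). As a point of reference, here is a minimal sketch of that common pattern; it assumes a TF 1.x environment where tf.contrib is available, and the encoder tensors, batch size, and unit sizes are made-up placeholders rather than code from any of the projects below.

import tensorflow as tf  # TF 1.x, where tf.contrib is still present

batch_size, num_units = 32, 128
# Placeholder encoder results; in a real model these come from an encoder RNN.
encoder_outputs = tf.zeros([batch_size, 50, num_units])    # [B, L_enc, units]
encoder_final_state = tf.zeros([batch_size, num_units])    # matches the GRU cell below

attention_mechanism = tf.contrib.seq2seq.LuongAttention(
    num_units, memory=encoder_outputs)
attn_cell = tf.contrib.seq2seq.AttentionWrapper(
    tf.nn.rnn_cell.GRUCell(num_units), attention_mechanism,
    attention_layer_size=num_units)

# zero_state() returns an AttentionWrapperState namedtuple...
init_state = attn_cell.zero_state(batch_size, tf.float32)
assert isinstance(init_state, tf.contrib.seq2seq.AttentionWrapperState)

# ...and clone() rebuilds it with selected fields replaced, e.g. to hand the
# encoder's final hidden state to the decoder (the pattern in Examples #1 and #2).
init_state = init_state.clone(cell_state=encoder_final_state)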
Example #1
Source File: decoders.py From language with Apache License 2.0
def _build_init_state(self, batch_size, enc_state, rnn_cell, mode, hparams):
  """Builds initial states for the given RNN cells."""
  del mode  # Unused.
  # Build init state.
  init_state = rnn_cell.zero_state(batch_size, tf.float32)
  if hparams.pass_hidden_state:
    # Non-GNMT RNN cell returns AttentionWrapperState.
    if isinstance(init_state, contrib_seq2seq.AttentionWrapperState):
      init_state = init_state.clone(cell_state=enc_state)
    # GNMT RNN cell returns a tuple state.
    elif isinstance(init_state, tuple):
      init_state = tuple(
          zs.clone(cell_state=es) if isinstance(
              zs, contrib_seq2seq.AttentionWrapperState) else es
          for zs, es in zip(init_state, enc_state))
    else:
      raise ValueError("RNN cell returns zero states of unknown type: %s" %
                       str(type(init_state)))
  return init_state
Example #2
Source File: decoders.py From language with Apache License 2.0
def _build_init_state(self, batch_size, enc_state, rnn_cell, mode, hparams):
  """Builds initial states for the given RNN cells."""
  # Build init state.
  init_state = rnn_cell.zero_state(batch_size, tf.float32)
  inner_state = init_state.cell_state
  if hparams.pass_hidden_state:
    # Non-GNMT RNN cell returns AttentionWrapperState.
    if isinstance(inner_state, contrib_seq2seq.AttentionWrapperState):
      init_state = init_state.clone(
          cell_state=inner_state.clone(cell_state=enc_state))
    # GNMT RNN cell returns a tuple state.
    elif isinstance(init_state.cell_state, tuple):
      init_state = init_state.clone(
          cell_state=tuple(
              zs.clone(cell_state=es) if isinstance(
                  zs, contrib_seq2seq.AttentionWrapperState) else es
              for zs, es in zip(inner_state, enc_state)))
    else:
      raise ValueError("RNN cell returns zero states of unknown type: %s" %
                       str(type(init_state)))
  return init_state
Example #3
Source File: _rnn.py From DeepChatModels with MIT License
def wrapper(self, state):
  """Some RNN states are wrapped in namedtuples.

  (TensorFlow decision, definitely not mine...). This is here for derived
  classes to specify their wrapper state. Some examples: LSTMStateTuple and
  AttentionWrapperState.

  Args:
    state: tensor state tuple, will be unpacked into the wrapper tuple.
  """
  if self._wrapper is None:
    return state
  else:
    return self._wrapper(*state)
Example #4
Source File: _rnn.py From DeepChatModels with MIT License
def zero_state(self, batch_size, dtype):
  with tf.name_scope(type(self).__name__ + "ZeroState", values=[batch_size]):
    if self._initial_cell_state is not None:
      cell_state = self._initial_cell_state
    else:
      cell_state = self._cell.zero_state(batch_size, dtype)
    error_message = (
        "zero_state of AttentionWrapper %s: " % self._base_name +
        "Non-matching batch sizes between the memory "
        "(encoder output) and the requested batch size.")
    with tf.control_dependencies(
        [tf.assert_equal(batch_size,
                         self._attention_mechanism.batch_size,
                         message=error_message)]):
      cell_state = nest.map_structure(
          lambda s: tf.identity(s, name="checked_cell_state"),
          cell_state)
    alignment_history = ()
    _zero_state_tensors = rnn_cell_impl._zero_state_tensors
    return AttentionWrapperState(
        cell_state=cell_state,
        time=tf.zeros([], dtype=tf.int32),
        attention=_zero_state_tensors(self._attention_size, batch_size, dtype),
        alignments=self._attention_mechanism.initial_alignments(
            batch_size, dtype),
        alignment_history=alignment_history)
Example #5
Source File: _rnn.py From DeepChatModels with MIT License
def state_size(self):
  return AttentionWrapperState(
      cell_state=self._cell.state_size,
      attention=self._attention_size,
      time=tf.TensorShape([]),
      alignments=self._attention_mechanism.alignments_size,
      alignment_history=())
Example #6
Source File: _rnn.py From DeepChatModels with MIT License
def shape(self):
  return AttentionWrapperState(
      cell_state=self._cell.shape,
      attention=tf.TensorShape([None, self._attention_size]),
      time=tf.TensorShape(None),
      alignments=tf.TensorShape([None, None]),
      alignment_history=())
Example #7
Source File: _rnn.py From DeepChatModels with MIT License
def call(self, inputs, state):
  """First computes the cell state and output in the usual way, then works
  through the attention pipeline: h --> a --> c --> h_tilde, using the
  naming/notation from Luong et al., 2015.

  Args:
    inputs: `2-D` tensor with shape `[batch_size x input_size]`.
    state: An instance of `AttentionWrapperState` containing the tensors
      from the previous timestep.

  Returns:
    A tuple `(attention_or_cell_output, next_state)`, where:
    - `attention_or_cell_output` depending on `output_attention`.
    - `next_state` is an instance of `DynamicAttentionWrapperState`
      containing the state calculated at this time step.
  """
  # Concatenate the previous h_tilde with inputs (input-feeding).
  cell_inputs = tf.concat([inputs, state.attention], -1)
  # 1. (hidden) Compute the hidden state (cell_output).
  cell_output, next_cell_state = self._cell(cell_inputs, state.cell_state)
  # 2. (align) Compute the normalized alignment scores. [B, L_enc],
  # where L_enc is the max seq len in the encoder outputs for the (B)atch.
  score = self._attention_mechanism(
      cell_output, previous_alignments=state.alignments)
  alignments = tf.nn.softmax(score)
  # Reshape from [B, L_enc] to [B, 1, L_enc].
  expanded_alignments = tf.expand_dims(alignments, 1)
  # (Possibly projected) encoder outputs: [B, L_enc, state_size].
  encoder_outputs = self._attention_mechanism.values
  # 3. (context) Take inner product. [B, 1, state_size].
  context = tf.matmul(expanded_alignments, encoder_outputs)
  context = tf.squeeze(context, [1])
  # 4. (h_tilde) Compute tanh(W [c, h]).
  attention = self._attention_layer(
      tf.concat([cell_output, context], -1))
  next_state = AttentionWrapperState(
      cell_state=next_cell_state,
      attention=attention,
      time=state.time + 1,
      alignments=alignments,
      alignment_history=())
  return attention, next_state
Example #8
Source File: sync_attention_wrapper.py From aster with MIT License
def call(self, inputs, state):
  if not isinstance(state, seq2seq.AttentionWrapperState):
    raise TypeError("Expected state to be instance of AttentionWrapperState. "
                    "Received type %s instead." % type(state))

  if self._is_multi:
    previous_alignments = state.alignments
    previous_alignment_history = state.alignment_history
  else:
    previous_alignments = [state.alignments]
    previous_alignment_history = [state.alignment_history]

  all_alignments = []
  all_attentions = []
  all_histories = []
  for i, attention_mechanism in enumerate(self._attention_mechanisms):
    if isinstance(self._cell, rnn.LSTMCell):
      rnn_cell_state = state.cell_state.h
    else:
      rnn_cell_state = state.cell_state
    attention, alignments = _compute_attention(
        attention_mechanism, rnn_cell_state, previous_alignments[i],
        self._attention_layers[i] if self._attention_layers else None)
    alignment_history = previous_alignment_history[i].write(
        state.time, alignments) if self._alignment_history else ()

    all_alignments.append(alignments)
    all_histories.append(alignment_history)
    all_attentions.append(attention)

  attention = array_ops.concat(all_attentions, 1)
  cell_inputs = self._cell_input_fn(inputs, attention)
  cell_output, next_cell_state = self._cell(cell_inputs, state.cell_state)

  next_state = seq2seq.AttentionWrapperState(
      time=state.time + 1,
      cell_state=next_cell_state,
      attention=attention,
      alignments=self._item_or_tuple(all_alignments),
      alignment_history=self._item_or_tuple(all_histories))

  if self._output_attention:
    return attention, next_state
  else:
    return cell_output, next_state
Example #9
Source File: sync_attention_wrapper.py From AON with MIT License
def call(self, inputs, state):
  if not isinstance(state, seq2seq.AttentionWrapperState):
    raise TypeError("Expected state to be instance of AttentionWrapperState. "
                    "Received type %s instead." % type(state))

  if self._is_multi:
    previous_alignments = state.alignments
    previous_alignment_history = state.alignment_history
  else:
    previous_alignments = [state.alignments]
    previous_alignment_history = [state.alignment_history]

  all_alignments = []
  all_attentions = []
  all_attention_states = []
  all_histories = []
  for i, attention_mechanism in enumerate(self._attention_mechanisms):
    if isinstance(self._cell, rnn.LSTMCell):
      rnn_cell_state = state.cell_state.h
    else:
      rnn_cell_state = state.cell_state
    attention, alignments, next_attention_state = _compute_attention(
        attention_mechanism, rnn_cell_state, previous_alignments[i],
        self._attention_layers[i] if self._attention_layers else None)
    alignment_history = previous_alignment_history[i].write(
        state.time, alignments) if self._alignment_history else ()

    all_attention_states.append(next_attention_state)
    all_alignments.append(alignments)
    all_histories.append(alignment_history)
    all_attentions.append(attention)

  attention = array_ops.concat(all_attentions, 1)
  cell_inputs = self._cell_input_fn(inputs, attention)
  cell_output, next_cell_state = self._cell(cell_inputs, state.cell_state)

  next_state = seq2seq.AttentionWrapperState(
      time=state.time + 1,
      cell_state=next_cell_state,
      attention=attention,
      attention_state=self._item_or_tuple(all_attention_states),
      alignments=self._item_or_tuple(all_alignments),
      alignment_history=self._item_or_tuple(all_histories))

  if self._output_attention:
    return attention, next_state
  else:
    return cell_output, next_state
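Once the wrapped cell and its AttentionWrapperState exist, they are typically handed to a decoder and threaded through tf.contrib.seq2seq.dynamic_decode. Below is a minimal sketch of that hand-off, continuing the hypothetical attn_cell and init_state from the sketch near the top of this page; the decoder inputs and lengths are again placeholders, not code from the projects above.

decoder_inputs = tf.zeros([batch_size, 20, num_units])  # [B, L_dec, emb_size]
decoder_lengths = tf.fill([batch_size], 20)

helper = tf.contrib.seq2seq.TrainingHelper(decoder_inputs, decoder_lengths)
decoder = tf.contrib.seq2seq.BasicDecoder(
    attn_cell, helper, initial_state=init_state)

# final_state is again an AttentionWrapperState, holding the state of the
# last decoding step (including alignments and, if enabled, their history).
outputs, final_state, _ = tf.contrib.seq2seq.dynamic_decode(decoder)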