Python tensorflow.python.util.nest.is_sequence() Examples
The following are 30 code examples of `tensorflow.python.util.nest.is_sequence()`.
You can vote up the ones you like or vote down the ones you don't like,
and go to the original project or source file by following the links above each example.
You may also want to check out all available functions/classes of the module `tensorflow.python.util.nest`, or try the search function.
Example #1
Source File: rnn_cell.py From ROLO with Apache License 2.0 | 6 votes |
def __init__(self, cells, state_is_tuple=True):
    """Build a multi-layer RNN cell that runs `cells` one after another.

    Args:
      cells: list of RNNCells, applied in the given order.
      state_is_tuple: when True, states passed in and returned are n-tuples
        with `n = len(cells)`. When False, all states are concatenated along
        the column axis (this mode is slated for deprecation).

    Raises:
      ValueError: if `cells` is empty, or if `state_is_tuple` is False while
        some cell reports a nested (tuple-valued) `state_size`.
    """
    if not cells:
        raise ValueError("Must specify at least one cell for MultiRNNCell.")
    self._cells = cells
    self._state_is_tuple = state_is_tuple
    if state_is_tuple:
        return
    # Concatenated states only work when every cell has a flat state size.
    has_nested_state = any(nest.is_sequence(cell.state_size)
                           for cell in self._cells)
    if has_nested_state:
        sizes = [cell.state_size for cell in self._cells]
        raise ValueError("Some cells return tuples of states, but the flag "
                         "state_is_tuple is not set. State sizes are: %s"
                         % str(sizes))
Example #2
Source File: FastWeightsRNN.py From AssociativeRetrieval with Apache License 2.0 | 6 votes |
def _fwlinear(self, args, output_size, scope=None):
    """Fast-weights linear map: ``args[0] @ MatrixW + args[1] @ MatrixC``.

    `MatrixW` is an `output_size x output_size` variable initialized to
    0.05 times the identity. `MatrixC` is a
    `[args[1].shape[1], output_size]` variable.

    Args:
      args: a sequence of exactly two tensors. `args[0]` must have
        `output_size` columns (dimension 1).
      output_size: int, width of the result.
      scope: optional variable scope name; defaults to "Linear".

    Returns:
      The sum of the two matrix products.

    Raises:
      ValueError: if `args` is None or empty, does not hold exactly two
        tensors, or `args[0]` does not have `output_size` columns.
    """
    if args is None or (nest.is_sequence(args) and not args):
        raise ValueError("`args` must be specified")
    if not nest.is_sequence(args):
        args = [args]
    # Explicit checks rather than `assert`, which is stripped under `python -O`.
    if len(args) != 2:
        raise ValueError("`args` must contain exactly two tensors, got %d"
                         % len(args))
    first_cols = args[0].get_shape().as_list()[1]
    if first_cols != output_size:
        raise ValueError("args[0] must have %d columns, got %s"
                         % (output_size, first_cols))
    dtype = args[0].dtype
    with vs.variable_scope(scope or "Linear"):
        # NOTE(review): the initializer is always float32 regardless of
        # `dtype`; confirm inputs are float32 or this may conflict.
        matrixW = vs.get_variable(
            "MatrixW", dtype=dtype,
            initializer=tf.convert_to_tensor(
                np.eye(output_size, dtype=np.float32) * .05))
        matrixC = vs.get_variable(
            "MatrixC", [args[1].get_shape().as_list()[1], output_size],
            dtype=dtype)
        res = tf.matmul(args[0], matrixW) + tf.matmul(args[1], matrixC)
        return res
Example #3
Source File: nn.py From AmusingPythonCodes with MIT License | 6 votes |
def linear(args, output_size, bias, bias_start=0.0, scope=None, squeeze=False,
           wd=0.0, input_keep_prob=1.0, is_train=None):
    """Linear projection over the last dimension of one or more tensors.

    Args:
      args: a tensor or a non-empty sequence of tensors.
      output_size: width of the projected last dimension.
      bias: whether `_linear` adds a bias term.
      bias_start: constant used to initialize that bias.
      scope: variable scope name for the projection (default 'Linear').
      squeeze: if True, remove the last axis of the result.
      wd: if non-zero, passed to `add_wd` after the projection is built.
      input_keep_prob: dropout keep probability; dropout is only wired in
        when this is below 1.0.
      is_train: boolean tensor choosing dropout vs. identity; required when
        `input_keep_prob < 1.0`.

    Returns:
      The projection, reshaped by `reconstruct` to match `args[0]`.

    Raises:
      ValueError: if `args` is None or an empty sequence.
    """
    if args is None or (nest.is_sequence(args) and not args):
        raise ValueError("`args` must be specified")
    tensors = args if nest.is_sequence(args) else [args]
    flat_args = [flatten(t, 1) for t in tensors]
    if input_keep_prob < 1.0:
        assert is_train is not None

        def _train_dropout(x):
            # tf.cond invokes both branch callables immediately, so `x` is bound.
            return tf.cond(is_train,
                           lambda: tf.nn.dropout(x, input_keep_prob),
                           lambda: x)

        flat_args = [_train_dropout(x) for x in flat_args]
    with tf.variable_scope(scope or 'Linear'):
        flat_out = _linear(flat_args, output_size, bias,
                           bias_initializer=tf.constant_initializer(bias_start))
    out = reconstruct(flat_out, tensors[0], 1)
    if squeeze:
        last_axis = len(tensors[0].get_shape().as_list()) - 1
        out = tf.squeeze(out, [last_axis])
    if wd:
        add_wd(wd)
    return out
Example #4
Source File: rnn_cell_impl.py From lambda-packs with MIT License | 6 votes |
def call(self, inputs, state):
    """Run each stacked cell in turn on `inputs`, starting from `state`.

    Args:
      inputs: input to the first cell.
      state: a sequence of per-cell states when `self._state_is_tuple`;
        otherwise one tensor with all cell states laid out along axis 1.

    Returns:
      `(output, new_state)`: the last cell's output, and the new per-cell
      states as a tuple or concatenated along axis 1.

    Raises:
      ValueError: if tuple states are expected but `state` is not a sequence.
    """
    offset = 0
    layer_input = inputs
    collected = []
    for idx, cell in enumerate(self._cells):
        with vs.variable_scope("cell_%d" % idx):
            if not self._state_is_tuple:
                # Flat state: take this cell's columns out of the concatenation.
                layer_state = array_ops.slice(state, [0, offset],
                                              [-1, cell.state_size])
                offset += cell.state_size
            elif nest.is_sequence(state):
                layer_state = state[idx]
            else:
                raise ValueError(
                    "Expected state to be a tuple of length %d, but received: %s"
                    % (len(self.state_size), state))
            layer_input, next_state = cell(layer_input, layer_state)
            collected.append(next_state)
    if self._state_is_tuple:
        merged = tuple(collected)
    else:
        merged = array_ops.concat(collected, 1)
    return layer_input, merged
Example #5
Source File: mod_core_rnn_cell_impl.py From RGAN with MIT License | 6 votes |
def __call__(self, inputs, state, scope=None):
    """Feed `inputs` through every stacked cell, starting from `state`.

    Args:
      inputs: input to the first cell.
      state: per-cell states as a sequence when `self._state_is_tuple` is
        True; otherwise a single tensor holding all cell states side by side
        along axis 1.
      scope: optional variable scope name (default "multi_rnn_cell").

    Returns:
      A pair `(output, new_state)`: the last cell's output, and the new
      per-cell states as a tuple or as one tensor concatenated on axis 1.

    Raises:
      ValueError: if tuple states are expected but `state` is not a sequence.
    """
    with vs.variable_scope(scope or "multi_rnn_cell"):
        column = 0
        layer_output = inputs
        per_cell_states = []
        for index, cell in enumerate(self._cells):
            with vs.variable_scope("cell_%d" % index):
                if not self._state_is_tuple:
                    # Slice this cell's block of columns out of the flat state.
                    cell_state = array_ops.slice(
                        state, [0, column], [-1, cell.state_size])
                    column += cell.state_size
                elif nest.is_sequence(state):
                    cell_state = state[index]
                else:
                    raise ValueError(
                        "Expected state to be a tuple of length %d, but received: %s"
                        % (len(self.state_size), state))
                layer_output, cell_new_state = cell(layer_output, cell_state)
                per_cell_states.append(cell_new_state)
    if self._state_is_tuple:
        combined = tuple(per_cell_states)
    else:
        combined = array_ops.concat(per_cell_states, 1)
    return layer_output, combined
Example #6
Source File: core_rnn_cell_impl.py From auto-alt-text-lambda-api with MIT License | 6 votes |
def __init__(self, cells, state_is_tuple=True):
    """Compose `cells` into a single sequential multi-layer RNN cell.

    Args:
      cells: list of RNNCells, composed in the given order.
      state_is_tuple: if True, states are n-tuples with `n = len(cells)`.
        If False, all states are concatenated along the column axis (a mode
        that will soon be deprecated).

    Raises:
      ValueError: if `cells` is empty, or if any cell has a tuple-valued
        `state_size` while `state_is_tuple` is False.
    """
    if not cells:
        raise ValueError("Must specify at least one cell for MultiRNNCell.")
    self._cells = cells
    self._state_is_tuple = state_is_tuple
    if not state_is_tuple and any(
            nest.is_sequence(layer.state_size) for layer in self._cells):
        raise ValueError(
            "Some cells return tuples of states, but the flag "
            "state_is_tuple is not set. State sizes are: %s"
            % str([layer.state_size for layer in self._cells]))
Example #7
Source File: core_rnn_cell_impl.py From auto-alt-text-lambda-api with MIT License | 6 votes |
def __call__(self, inputs, state, scope=None):
    """Run the stacked cells on `inputs`, starting from `state`.

    Each cell gets its own "cell_<i>" variable scope nested under `scope`
    (default "multi_rnn_cell"). The output of cell i is the input of
    cell i+1.

    Returns:
      `(output, new_state)` where `new_state` is a tuple of per-cell states
      if `self._state_is_tuple`, else the states concatenated on axis 1.

    Raises:
      ValueError: if tuple states are expected but `state` is not a sequence.
    """
    with vs.variable_scope(scope or "multi_rnn_cell"):
        pos = 0
        current = inputs
        outs = []
        for n, cell in enumerate(self._cells):
            with vs.variable_scope("cell_%d" % n):
                if self._state_is_tuple:
                    if not nest.is_sequence(state):
                        raise ValueError(
                            "Expected state to be a tuple of length %d, but received: %s"
                            % (len(self.state_size), state))
                    sub_state = state[n]
                else:
                    width = cell.state_size
                    sub_state = array_ops.slice(state, [0, pos], [-1, width])
                    pos += width
                current, produced = cell(current, sub_state)
                outs.append(produced)
    final_state = (tuple(outs) if self._state_is_tuple
                   else array_ops.concat(outs, 1))
    return current, final_state
Example #8
Source File: deeppyramid_utils.py From BERT with Apache License 2.0 | 6 votes |
def linear(args, output_size, bias, bias_start=0.0, scope=None, squeeze=False,
           wd=0.0, input_keep_prob=1.0, is_train=None):
    """Linear projection of one or more tensors inside a variable scope.

    Args:
      args: a tensor or a non-empty sequence of tensors.
      output_size: size of the projected last dimension, passed to `_linear`.
      bias: forwarded to `_linear` (presumably toggles a bias term).
      bias_start: accepted for API compatibility; unused here.
      scope: variable scope name (default "linear").
      squeeze: if True, drop the last axis of the result.
      wd: accepted for API compatibility; unused here.
      input_keep_prob: keep probability passed to `tf.nn.dropout`.
      is_train: boolean tensor selecting dropout; must not be None.

    Returns:
      The projection, reshaped via `reconstruct` against `args[0]`.

    Raises:
      ValueError: if `args` is None or an empty sequence.
    """
    with tf.variable_scope(scope or "linear"):
        if args is None or (nest.is_sequence(args) and not args):
            raise ValueError("`args` must be specified")
        tensors = args if nest.is_sequence(args) else [args]
        flat_args = [flatten(t, 1) for t in tensors]
        # NOTE(review): dropout is wired in unconditionally (no
        # `input_keep_prob < 1.0` guard), so `is_train` is always required.
        # Presumably this lets `input_keep_prob` be a tensor; confirm.
        assert is_train is not None

        def _train_dropout(x):
            return tf.cond(is_train,
                           lambda: tf.nn.dropout(x, input_keep_prob),
                           lambda: x)

        flat_args = [_train_dropout(x) for x in flat_args]
        flat_out = _linear(flat_args, output_size, bias)
        out = reconstruct(flat_out, tensors[0], 1)
        if squeeze:
            out = tf.squeeze(out, [len(tensors[0].get_shape().as_list()) - 1])
    return out
Example #9
Source File: nn.py From BERT with Apache License 2.0 | 6 votes |
def linear(args, output_size, bias, bias_start=0.0, scope=None, squeeze=False,
           wd=0.0, input_keep_prob=1.0, is_train=None):
    """Linear projection of one or more tensors, with optional weight decay.

    Args:
      args: a tensor or a non-empty sequence of tensors.
      output_size: size of the projected last dimension, passed to `_linear`.
      bias: forwarded to `_linear` (presumably toggles a bias term).
      bias_start: accepted for API compatibility; unused here.
      scope: variable scope name (default "linear").
      squeeze: if True, drop the last axis of the result.
      wd: if non-zero, passed to `add_wd` inside the scope.
      input_keep_prob: keep probability passed to `tf.nn.dropout`.
      is_train: boolean tensor selecting dropout; must not be None.

    Returns:
      The projection, reshaped via `reconstruct` against `args[0]`.

    Raises:
      ValueError: if `args` is None or an empty sequence.
    """
    with tf.variable_scope(scope or "linear"):
        if args is None or (nest.is_sequence(args) and not args):
            raise ValueError("`args` must be specified")
        tensors = args if nest.is_sequence(args) else [args]
        flat_args = [flatten(t, 1) for t in tensors]
        # NOTE(review): dropout is wired in unconditionally (no
        # `input_keep_prob < 1.0` guard), so `is_train` is always required.
        # Presumably this lets `input_keep_prob` be a tensor; confirm.
        assert is_train is not None

        def _train_dropout(x):
            return tf.cond(is_train,
                           lambda: tf.nn.dropout(x, input_keep_prob),
                           lambda: x)

        flat_args = [_train_dropout(x) for x in flat_args]
        flat_out = _linear(flat_args, output_size, bias)
        out = reconstruct(flat_out, tensors[0], 1)
        if squeeze:
            out = tf.squeeze(out, [len(tensors[0].get_shape().as_list()) - 1])
        if wd:
            add_wd(wd)
    return out
Example #10
Source File: rnn_cell.py From deep_image_model with Apache License 2.0 | 6 votes |
def __call__(self, inputs, state, scope=None):
    """Run the stacked cells on `inputs`, starting from `state`.

    Variables live under `scope` (default: the class name, i.e.
    "MultiRNNCell"), with a "Cell<i>" sub-scope per layer.

    Returns:
      `(output, new_state)`; `new_state` is a tuple of per-cell states when
      `self._state_is_tuple`, otherwise the states joined on axis 1 using
      the axis-first `concat(axis, values)` order of older TensorFlow.

    Raises:
      ValueError: if tuple states are expected but `state` is not a sequence.
    """
    with vs.variable_scope(scope or type(self).__name__):  # "MultiRNNCell"
        col = 0
        hidden = inputs
        gathered = []
        for layer, cell in enumerate(self._cells):
            with vs.variable_scope("Cell%d" % layer):
                if not self._state_is_tuple:
                    layer_state = array_ops.slice(
                        state, [0, col], [-1, cell.state_size])
                    col += cell.state_size
                elif nest.is_sequence(state):
                    layer_state = state[layer]
                else:
                    raise ValueError(
                        "Expected state to be a tuple of length %d, but received: %s"
                        % (len(self.state_size), state))
                hidden, layer_new = cell(hidden, layer_state)
                gathered.append(layer_new)
    if self._state_is_tuple:
        result_state = tuple(gathered)
    else:
        result_state = array_ops.concat(1, gathered)
    return hidden, result_state
Example #11
Source File: rnn_cell.py From deep_image_model with Apache License 2.0 | 6 votes |
def __init__(self, cells, state_is_tuple=True):
    """Create an RNN cell that applies `cells` sequentially.

    Args:
      cells: list of RNNCells composed in this order.
      state_is_tuple: True for n-tuple states (`n = len(cells)`). False
        concatenates all states along the column axis, which will soon be
        deprecated.

    Raises:
      ValueError: for an empty `cells`, or when a cell exposes a tuple
        `state_size` while `state_is_tuple` is False.
    """
    if not cells:
        raise ValueError("Must specify at least one cell for MultiRNNCell.")
    self._cells = cells
    self._state_is_tuple = state_is_tuple
    if not state_is_tuple:
        nested = [nest.is_sequence(c.state_size) for c in self._cells]
        if True in nested:
            raise ValueError("Some cells return tuples of states, but the flag "
                             "state_is_tuple is not set. State sizes are: %s"
                             % str([c.state_size for c in self._cells]))
Example #12
Source File: tf_utils.py From spl with GNU General Public License v3.0 | 6 votes |
def call(self, inputs, state):
    """Run the stacked cells, optionally returning every layer's output.

    Args:
      inputs: input to the first cell.
      state: per-cell states (sequence) when `self._state_is_tuple`,
        otherwise one tensor with all states laid out along axis 1.

    Returns:
      `(outputs, new_state)`. If `self._intermediate_outputs` is set,
      `outputs` holds every layer's output: a tuple when
      `self._state_is_tuple` is True, else the outputs concatenated on
      axis 1. Otherwise `outputs` is just the last layer's output.
      `new_state` follows the same tuple/concat rule.

    Raises:
      ValueError: if tuple states are expected but `state` is not a sequence.
    """
    col = 0
    layer_out = inputs
    states_out = []
    layer_outputs = []
    for k, cell in enumerate(self._cells):
        with vs.variable_scope("cell_%d" % k):
            if not self._state_is_tuple:
                sub_state = array_ops.slice(state, [0, col], [-1, cell.state_size])
                col += cell.state_size
            elif nest.is_sequence(state):
                sub_state = state[k]
            else:
                raise ValueError("Expected state to be a tuple of length %d, but received: %s" % (len(self.state_size), state))
            layer_out, sub_new_state = cell(layer_out, sub_state)
            states_out.append(sub_new_state)
            layer_outputs.append(layer_out)
    if self._state_is_tuple:
        states_out = tuple(states_out)
    else:
        states_out = array_ops.concat(states_out, 1)
    if not self._intermediate_outputs:
        return layer_out, states_out
    # The tuple/concat choice for outputs deliberately mirrors the state flag.
    if self._state_is_tuple:
        layer_outputs = tuple(layer_outputs)
    else:
        layer_outputs = array_ops.concat(layer_outputs, 1)
    return layer_outputs, states_out
Example #13
Source File: eval_tools.py From hart with GNU General Public License v3.0 | 6 votes |
def log_values(writer, itr, tags=None, values=None, dict=None):
    """Write scalar summaries for one training iteration.

    Args:
      writer: object with ``add_summary(summary, step)`` (e.g. a TensorFlow
        summary FileWriter).
      itr: step number passed to ``writer.add_summary``.
      tags: a single tag or a sequence of tags; must be None if `dict` is
        given.
      values: a single value or a sequence matching `tags`; must be None if
        `dict` is given.
      dict: optional mapping of tag -> value. The name shadows the builtin
        but is kept for keyword callers.

    Raises:
      ValueError: if `tags` and `values` are sequences of different lengths.
    """
    if dict is not None:
        assert tags is None and values is None
        tags = dict.keys()
        values = dict.values()
    elif not nest.is_sequence(tags):
        # A single tag/value pair: wrap both so the loop below handles it.
        tags, values = [tags], [values]
    elif len(tags) != len(values):
        raise ValueError('tag and value have different lengths:'
                         ' {} vs {}'.format(len(tags), len(values)))

    for t, v in zip(tags, values):
        summary = tf.Summary.Value(tag=t, simple_value=v)
        summary = tf.Summary(value=[summary])
        writer.add_summary(summary, itr)
Example #14
Source File: nested_utils.py From Gun-Detector with Apache License 2.0 | 6 votes |
def map_nested(map_fn, nested):
    """Executes map_fn on every element in a (potentially) nested structure.

    Args:
      map_fn: A callable to execute on each element in 'nested'.
      nested: A potentially nested combination of sequence objects. Sequence
        objects include tuples, lists, namedtuples, and all subclasses of
        collections.abc.Sequence except strings. See nest.is_sequence for
        details. For example [1, ('hello', 4.3)] is a nested structure
        containing elements 1, 'hello', and 4.3.

    Returns:
      out_structure: A potentially nested combination of sequence objects with
        the same structure as the 'nested' input argument. out_structure
        contains the result of applying map_fn to each element in 'nested'.
        For example map_nested(lambda x: x+1, [1, (3, 4.3)]) returns
        [2, (4, 5.3)].
    """
    # In Python 3 `map` returns a lazy iterator. `pack_sequence_as` needs a
    # sized sequence, so materialize the results into a list first.
    out = list(map(map_fn, nest.flatten(nested)))
    return nest.pack_sequence_as(nested, out)
Example #15
Source File: nested_utils.py From yolo_v2 with Apache License 2.0 | 6 votes |
def map_nested(map_fn, nested):
    """Executes map_fn on every element in a (potentially) nested structure.

    Args:
      map_fn: A callable to execute on each element in 'nested'.
      nested: A potentially nested combination of sequence objects. Sequence
        objects include tuples, lists, namedtuples, and all subclasses of
        collections.abc.Sequence except strings. See nest.is_sequence for
        details. For example [1, ('hello', 4.3)] is a nested structure
        containing elements 1, 'hello', and 4.3.

    Returns:
      out_structure: A potentially nested combination of sequence objects with
        the same structure as the 'nested' input argument. out_structure
        contains the result of applying map_fn to each element in 'nested'.
        For example map_nested(lambda x: x+1, [1, (3, 4.3)]) returns
        [2, (4, 5.3)].
    """
    # In Python 3 `map` returns a lazy iterator. `pack_sequence_as` needs a
    # sized sequence, so materialize the results into a list first.
    out = list(map(map_fn, nest.flatten(nested)))
    return nest.pack_sequence_as(nested, out)
Example #16
Source File: nested_utils.py From MultitaskAIS with MIT License | 6 votes |
def map_nested(map_fn, nested):
    """Apply `map_fn` to every leaf of a nested structure, keeping its shape.

    Args:
      map_fn: callable invoked once per leaf, in `nest.flatten` order.
      nested: a possibly nested mix of tuples, lists, namedtuples and other
        non-string sequences (see `nest.is_sequence`), e.g.
        ``[1, ('hello', 4.3)]`` whose leaves are 1, 'hello' and 4.3.

    Returns:
      A structure shaped like `nested` whose leaves are the `map_fn`
      results, e.g. ``map_nested(lambda x: x + 1, [1, (3, 4.3)])`` gives
      ``[2, (4, 5.3)]``.
    """
    leaves = nest.flatten(nested)
    mapped = [map_fn(leaf) for leaf in leaves]
    return nest.pack_sequence_as(nested, mapped)
Example #17
Source File: rnn_cell.py From ROLO with Apache License 2.0 | 6 votes |
def __call__(self, inputs, state, scope=None):
    """Apply the stacked cells to `inputs`, beginning from `state`.

    The outer variable scope defaults to the class name ("MultiRNNCell").
    Each layer gets a "Cell<i>" sub-scope.

    Returns:
      `(output, new_state)`: the top cell's output, and the per-cell states
      as a tuple or joined on axis 1 with the axis-first `concat(axis,
      values)` argument order of older TensorFlow.

    Raises:
      ValueError: if tuple states are expected but `state` is not a sequence.
    """
    with vs.variable_scope(scope or type(self).__name__):  # "MultiRNNCell"
        start_col = 0
        x = inputs
        produced_states = []
        for i, cell in enumerate(self._cells):
            with vs.variable_scope("Cell%d" % i):
                if self._state_is_tuple:
                    if not nest.is_sequence(state):
                        raise ValueError(
                            "Expected state to be a tuple of length %d, but received: %s"
                            % (len(self.state_size), state))
                    s = state[i]
                else:
                    s = array_ops.slice(state, [0, start_col],
                                        [-1, cell.state_size])
                    start_col += cell.state_size
                x, s_next = cell(x, s)
                produced_states.append(s_next)
    out_state = (tuple(produced_states) if self._state_is_tuple
                 else array_ops.concat(1, produced_states))
    return x, out_state
Example #18
Source File: rnn_cell.py From ecm with Apache License 2.0 | 6 votes |
def __init__(self, cells, state_is_tuple=True):
    """Stack `cells` into one RNN cell that runs them in order.

    Args:
      cells: non-empty list of RNNCells.
      state_is_tuple: if True, states are n-tuples (`n = len(cells)`).
        Otherwise all states are concatenated column-wise, a behavior that
        will soon be deprecated.

    Raises:
      ValueError: if `cells` is empty, or some cell's `state_size` is a
        tuple while `state_is_tuple` is False.
    """
    if not cells:
        raise ValueError("Must specify at least one cell for MultiRNNCell.")
    self._cells = cells
    self._state_is_tuple = state_is_tuple
    if state_is_tuple:
        return
    for member in self._cells:
        if nest.is_sequence(member.state_size):
            all_sizes = str([c.state_size for c in self._cells])
            raise ValueError("Some cells return tuples of states, but the flag "
                             "state_is_tuple is not set. State sizes are: %s"
                             % all_sizes)
Example #19
Source File: mod_core_rnn_cell_impl.py From RGAN with MIT License | 5 votes |
def __init__(self, cells, state_is_tuple=True):
    """Create a sequential multi-layer RNN cell from a list of cells.

    Args:
      cells: list (or tuple) of RNNCells composed in this order.
      state_is_tuple: if True, states are n-tuples with `n = len(cells)`;
        if False, states are concatenated along the column axis (soon to be
        deprecated).

    Raises:
      ValueError: if `cells` is empty, or a cell has a tuple `state_size`
        while `state_is_tuple` is False.
      TypeError: if `cells` is not a sequence.
    """
    if not cells:
        raise ValueError("Must specify at least one cell for MultiRNNCell.")
    if not nest.is_sequence(cells):
        raise TypeError(
            "cells must be a list or tuple, but saw: %s." % cells)
    self._cells = cells
    self._state_is_tuple = state_is_tuple
    needs_tuple = (not state_is_tuple) and any(
        nest.is_sequence(c.state_size) for c in self._cells)
    if needs_tuple:
        raise ValueError("Some cells return tuples of states, but the flag "
                         "state_is_tuple is not set. State sizes are: %s"
                         % str([c.state_size for c in self._cells]))
Example #20
Source File: rnn_cell.py From ecm with Apache License 2.0 | 5 votes |
def zero_state(self, batch_size, dtype):
    """Build zero-filled initial state tensor(s) for this cell.

    Args:
      batch_size: int, float, or unit Tensor giving the batch size.
      dtype: data type of the zero tensors.

    Returns:
      If `self.state_size` is an int or TensorShape, a single tensor of shape
      `[batch_size x state_size]` filled with zeros. If it is a nested
      list/tuple, a structure of the same shape holding one zero tensor of
      shape `[batch_size x s]` per leaf `s`. Static shapes are set with an
      unknown (None) batch dimension. `array_ops.pack` is the older name of
      `stack`.
    """
    state_size = self.state_size
    if not nest.is_sequence(state_size):
        full_size = _state_size_with_prefix(state_size, prefix=[batch_size])
        zeros = array_ops.zeros(array_ops.pack(full_size), dtype=dtype)
        zeros.set_shape(_state_size_with_prefix(state_size, prefix=[None]))
        return zeros

    leaf_sizes = nest.flatten(state_size)
    leaf_zeros = []
    for leaf in leaf_sizes:
        shape = array_ops.pack(_state_size_with_prefix(leaf, prefix=[batch_size]))
        leaf_zeros.append(array_ops.zeros(shape, dtype=dtype))
    # Static shapes are set only after all zero tensors exist, as before.
    for leaf, tensor in zip(leaf_sizes, leaf_zeros):
        tensor.set_shape(_state_size_with_prefix(leaf, prefix=[None]))
    return nest.pack_sequence_as(structure=state_size, flat_sequence=leaf_zeros)
Example #21
Source File: nn.py From AmusingPythonCodes with MIT License | 5 votes |
def sum_logits(args, mask=None, name=None):
    """Sum each tensor over its last axis and add the results together.

    Args:
      args: a tensor or non-empty sequence of tensors. The rank of the first
        one fixes the reduced axis.
      mask: optional mask applied with `exp_mask` to the summed logits.
      name: name scope (default "sum_logits").

    Returns:
      The (optionally masked) sum of the per-tensor reductions.

    Raises:
      ValueError: if `args` is None or an empty sequence.
    """
    with tf.name_scope(name or "sum_logits"):
        if args is None or (nest.is_sequence(args) and not args):
            raise ValueError("`args` must be specified")
        tensors = args if nest.is_sequence(args) else [args]
        last_axis = len(tensors[0].get_shape()) - 1
        logits = sum(tf.reduce_sum(t, last_axis) for t in tensors)
        return logits if mask is None else exp_mask(logits, mask)
Example #22
Source File: eval_tools.py From forge with GNU General Public License v3.0 | 5 votes |
def log_values(writer, itr, tags=None, values=None, dict=None):
    """Writes scalar summaries to Tensorboard.

    Values can be passed as either a list of string tags and float values,
    or as a dictionary. In the latter case, the keys are used as summary
    tags. A single (non-sequence) tag is wrapped together with its value.

    :param writer: tf.summary.Filewriter
    :param itr: int, training iteration
    :param tags: list of strings or None
    :param values: list of floats or None
    :param dict: dict of {string: float} or None (shadows the builtin name;
        kept for keyword callers)
    :raises ValueError: if `tags` and `values` have different lengths.
    """
    if dict is not None:
        assert tags is None and values is None
        tags = dict.keys()
        values = dict.values()
    elif not nest.is_sequence(tags):
        tags, values = [tags], [values]
    elif len(tags) != len(values):
        raise ValueError('tag and value have different lengths:'
                         ' {} vs {}'.format(len(tags), len(values)))

    for t, v in zip(tags, values):
        summary = tf.Summary.Value(tag=t, simple_value=v)
        summary = tf.Summary(value=[summary])
        writer.add_summary(summary, itr)
Example #23
Source File: rnn_cell.py From ecm with Apache License 2.0 | 5 votes |
def __call__(self, inputs, state, emotion, imemory, scope=None):
    """Run this multi-layer cell on inputs, starting from state.

    Besides the usual stacked-cell step, this optionally gates an internal
    memory `imemory`. A read gate `r` scales the memory fed to the first
    cell. After the step, a write gate `w` (computed from the new states)
    scales the memory that is returned.

    Args:
      inputs: input to the first cell.
      state: per-cell states. It is iterated unconditionally below, so it
        must be iterable even when `self._state_is_tuple` is False.
        NOTE(review): confirm callers always pass a sequence here.
      emotion: optional extra input. When not None it joins the read-gate
        features and is passed to the first cell.
      imemory: optional memory tensor, rank 2 per `with_rank(2)`. When None,
        no gates are built.
      scope: optional variable scope name for the gates and the cell stack.

    Returns:
      `(output, new_states, new_imemory)`. `new_imemory` is `imemory`
      unchanged when it is None, otherwise `w * imemory`.

    Raises:
      ValueError: if tuple states are expected but `state` is not a sequence.
    """
    # Features for the read gate: inputs (+ emotion) followed by each
    # per-cell state. Only consumed when imemory is not None.
    if emotion is None:
        x = [inputs] + [ i for i in state]
    else:
        x = [inputs, emotion] + [ i for i in state]
    if imemory is not None:
        # NOTE(review): if `scope` is passed, this, the cell stack and the
        # write gate all open the same scope name, so the two `_linear` calls
        # may try to create same-named variables. Verify before passing scope.
        with vs.variable_scope(scope or 'IMemoryReadGate'):
            r = sigmoid(_linear(x, imemory.get_shape().with_rank(2)[1], True, 1.0))
    with vs.variable_scope(scope or 'MultiRNNCell'): # "MultiRNNCell"
        cur_state_pos = 0
        cur_inp = inputs
        new_states = []
        for i, cell in enumerate(self._cells):
            with vs.variable_scope("Cell%d" % i):
                if self._state_is_tuple:
                    if not nest.is_sequence(state):
                        raise ValueError(
                            "Expected state to be a tuple of length %d, but received: %s"
                            % (len(self.state_size), state))
                    cur_state = state[i]
                else:
                    # Flat state: slice this cell's columns.
                    cur_state = array_ops.slice(
                        state, [0, cur_state_pos], [-1, cell.state_size])
                    cur_state_pos += cell.state_size
                if i == 0:
                    # Only the bottom cell receives emotion and the memory,
                    # which is scaled by the read gate when present.
                    if imemory is None:
                        cur_inp, new_state = cell(cur_inp, cur_state, emotion, imemory)
                    else:
                        cur_inp, new_state = cell(cur_inp, cur_state, emotion, r * imemory)
                else:
                    cur_inp, new_state = cell(cur_inp, cur_state)
                new_states.append(new_state)
    # `concat(1, values)` is the axis-first argument order of older TensorFlow.
    new_states = (tuple(new_states) if self._state_is_tuple
                  else array_ops.concat(1, new_states))
    new_imemory = imemory
    if imemory is not None:
        # Write gate computed from the new states; presumably `_linear`
        # accepts this (possibly nested) state structure. TODO confirm.
        with vs.variable_scope(scope or 'IMemoryWriteGate'):
            w = sigmoid(_linear(new_states, imemory.get_shape().with_rank(2)[1], True, 1.0))
        new_imemory = w * imemory
    return cur_inp, new_states, new_imemory
Example #24
Source File: rnn.py From deep_image_model with Apache License 2.0 | 5 votes |
def _infer_state_dtype(explicit_dtype, state):
    """Work out the dtype of an RNN state.

    Args:
      explicit_dtype: a dtype declared by the caller, or None.
      state: the RNN hidden state: a Tensor or a nested structure of Tensors.

    Returns:
      `explicit_dtype` if given. Otherwise the dtype shared by every tensor
      in `state`.

    Raises:
      ValueError: if `state` is an empty structure or mixes dtypes.
    """
    if explicit_dtype is not None:
        return explicit_dtype
    if not nest.is_sequence(state):
        return state.dtype
    dtypes = [leaf.dtype for leaf in nest.flatten(state)]
    if not dtypes:
        raise ValueError("Unable to infer dtype from empty state.")
    first = dtypes[0]
    if not all(dt == first for dt in dtypes):
        raise ValueError(
            "State has tensors of different inferred_dtypes. Unable to infer a "
            "single representative dtype.")
    return first
Example #25
Source File: rnn.py From MIMN with MIT License | 5 votes |
def _infer_state_dtype(explicit_dtype, state):
    """Return the dtype to use for an RNN state.

    Args:
      explicit_dtype: dtype supplied explicitly by the caller, or None.
      state: RNN hidden state: a Tensor or a nested iterable of Tensors.

    Returns:
      `explicit_dtype` when provided, else the single dtype of `state`.

    Raises:
      ValueError: if `state` is empty or its tensors disagree on dtype.
    """
    if explicit_dtype is not None:
        return explicit_dtype
    if not nest.is_sequence(state):
        return state.dtype
    found = [tensor.dtype for tensor in nest.flatten(state)]
    if len(found) == 0:
        raise ValueError("Unable to infer dtype from empty state.")
    reference = found[0]
    uniform = all(other == reference for other in found)
    if not uniform:
        raise ValueError(
            "State has tensors of different inferred_dtypes. Unable to infer a "
            "single representative dtype.")
    return reference
# pylint: disable=unused-argument
Example #26
Source File: rnn.py From RFL with MIT License | 5 votes |
def rnn(cell, inputs, initial_state, scope=None):
    """Unroll a gated RNN `cell` over a list of per-step inputs.

    The cell is expected to return `(output, state, input_gate, forget_gate,
    output_gate)` on every call. Variables are created on the first step and
    reused afterwards. Unless the enclosing scope already sets a caching
    device, each variable is cached on the same device as the op reading it.

    Args:
      cell: an instance of `tf.contrib.rnn.RNNCell`.
      inputs: non-empty sequence of per-time-step inputs.
      initial_state: state fed to the first step.
      scope: variable scope name (default "RNN").

    Returns:
      `(outputs, final_state, input_gates, forget_gates, output_gates)`,
      where each list holds one entry per time step.

    Raises:
      TypeError: if `cell` is not an RNNCell or `inputs` is not a sequence.
      ValueError: if `inputs` is empty.
    """
    if not isinstance(cell, tf.contrib.rnn.RNNCell):
        raise TypeError("cell must be an instance of RNNCell")
    if not nest.is_sequence(inputs):
        raise TypeError("inputs must be a sequence")
    if not inputs:
        raise ValueError("inputs must not be empty")

    outputs, input_gates, forget_gates, output_gates = [], [], [], []
    with tf.variable_scope(scope or "RNN") as varscope:
        if varscope.caching_device is None:
            varscope.set_caching_device(lambda op: op.device)
        state = initial_state
        for step, input_ in enumerate(inputs):
            if step > 0:
                varscope.reuse_variables()
            output, state, in_gate, f_gate, o_gate = cell(input_, state)
            outputs.append(output)
            input_gates.append(in_gate)
            forget_gates.append(f_gate)
            output_gates.append(o_gate)
    return (outputs, state, input_gates, forget_gates, output_gates)
Example #27
Source File: gnmt_model.py From parallax with Apache License 2.0 | 5 votes |
def __call__(self, inputs, state, scope=None):
    """Run the cell with bottom layer's attention copied to all upper layers.

    Cell 0 is the attention cell. Every upper cell's input is its lower
    layer's output concatenated (last axis) with an attention vector. That
    vector is the freshly computed one if `self.use_new_attention`, else the
    one carried in the incoming `state[0]`.

    Returns:
      `(output, new_states)` with `new_states` a tuple of per-cell states.

    Raises:
      ValueError: if `state` is not a sequence.
    """
    if not nest.is_sequence(state):
        raise ValueError(
            "Expected state to be a tuple of length %d, but received: %s"
            % (len(self.state_size), state))

    with tf.variable_scope(scope or "multi_rnn_cell"):
        with tf.variable_scope("cell_0_attention"):
            attention_state = state[0]
            cur_inp, new_attention_state = self._cells[0](inputs, attention_state)
        collected = [new_attention_state]

        for layer, upper_cell in enumerate(self._cells[1:], start=1):
            with tf.variable_scope("cell_%d" % layer):
                source = (new_attention_state if self.use_new_attention
                          else attention_state)
                cur_inp = tf.concat([cur_inp, source.attention], -1)
                cur_inp, layer_state = upper_cell(cur_inp, state[layer])
                collected.append(layer_state)

    return cur_inp, tuple(collected)
Example #28
Source File: bridges.py From NJUNMT-tf with Apache License 2.0 | 5 votes |
def _create(self, encoder_output, decoder_state_size, **kwargs):
    """ Creates decoder's initial RNN states according to `decoder_state_size`.

    Passes the final state of encoder to each layer in decoder.

    Args:
        encoder_output: An instance of `collections.namedtuple`
          from `Encoder.encode()`.
        decoder_state_size: RNN decoder state size. Either a sequence of
          per-layer state sizes or a single layer's state size.
        **kwargs: unused.

    Returns: The decoder states with the structure determined
      by `decoder_state_size`: one copy of the encoder's final state per
      layer when `decoder_state_size` is a sequence, otherwise that state
      itself.

    Raises:
        ValueError: if the structure of encoder RNN state does not
          have the same structure of decoder RNN state.
    """
    batch_size = tf.shape(encoder_output.attention_length)[0]
    # of type LSTMStateTuple
    enc_final_state = _final_state(
        encoder_output.final_states, direction=self.params["direction"])
    is_multi_layer = nest.is_sequence(decoder_state_size)
    # Compare against a single layer's state size. Only index into
    # `decoder_state_size` when it is a sequence; previously a non-sequence
    # size crashed here before the non-sequence branch below was reached.
    layer_state_size = (decoder_state_size[0] if is_multi_layer
                        else decoder_state_size)
    assert_state_is_compatible(rnn_cell_impl._zero_state_tensors(
        layer_state_size, batch_size, tf.float32), enc_final_state)
    if is_multi_layer:
        return tuple([enc_final_state for _ in decoder_state_size])
    return enc_final_state
Example #29
Source File: gnmt_model.py From NETransliteration-COLING2018 with MIT License | 5 votes |
def __call__(self, inputs, state, scope=None):
    """Run the cell with bottom layer's attention copied to all upper layers.

    Unlike input concatenation, this variant appends the attention vector to
    the `h` component of each upper layer's LSTM state (axis 1) before
    calling that layer. The new attention is used if
    `self.use_new_attention`, else the attention carried in `state[0]`.

    Returns:
      `(output, new_states)` with `new_states` a tuple of per-cell states.

    Raises:
      ValueError: if `state` is not a sequence.
      TypeError: if an upper layer's state is not an `LSTMStateTuple`.
    """
    if not nest.is_sequence(state):
        raise ValueError(
            "Expected state to be a tuple of length %d, but received: %s"
            % (len(self.state_size), state))

    with tf.variable_scope(scope or "multi_rnn_cell"):
        with tf.variable_scope("cell_0_attention"):
            attention_state = state[0]
            cur_inp, new_attention_state = self._cells[0](inputs, attention_state)
        collected = [new_attention_state]

        for layer, upper_cell in enumerate(self._cells[1:], start=1):
            with tf.variable_scope("cell_%d" % layer):
                layer_state = state[layer]
                if not isinstance(layer_state, tf.contrib.rnn.LSTMStateTuple):
                    raise TypeError("`state[{}]` must be a LSTMStateTuple".format(layer))
                source = (new_attention_state if self.use_new_attention
                          else attention_state)
                layer_state = layer_state._replace(
                    h=tf.concat([layer_state.h, source.attention], 1))
                cur_inp, next_state = upper_cell(cur_inp, layer_state)
                collected.append(next_state)

    return cur_inp, tuple(collected)
Example #30
Source File: rnn.py From lambda-packs with MIT License | 5 votes |
def _infer_state_dtype(explicit_dtype, state):
    """Determine a single dtype for an RNN hidden state.

    Args:
      explicit_dtype: dtype given explicitly, or None to infer it.
      state: a Tensor or a nested iterable containing Tensors.

    Returns:
      `explicit_dtype` if not None. Otherwise the common dtype of the
      tensors in `state`.

    Raises:
      ValueError: if `state` is empty or holds tensors of differing dtypes.
    """
    if explicit_dtype is not None:
        return explicit_dtype
    if nest.is_sequence(state):
        candidates = [item.dtype for item in nest.flatten(state)]
        if not candidates:
            raise ValueError("Unable to infer dtype from empty state.")
        head = candidates[0]
        if all(c == head for c in candidates):
            return head
        raise ValueError(
            "State has tensors of different inferred_dtypes. Unable to infer a "
            "single representative dtype.")
    return state.dtype
# pylint: disable=unused-argument