Python tensorflow.contrib.rnn.ResidualWrapper() Examples
The following are 5 code examples of tensorflow.contrib.rnn.ResidualWrapper().
You can vote up the ones you like or vote down the ones you don't like,
and go to the original project or source file by following the links above each example.
You may also want to check out all available functions/classes of the module tensorflow.contrib.rnn, or try the search function.
Example #1
Source File: seq2seq.py From retrosynthesis_planner with GNU General Public License v3.0 | 5 votes |
def _make_cell(self, hidden_size=None):
    """Build one RNN cell, optionally wrapped for dropout and residuals.

    Args:
        hidden_size: Number of units for the new cell. When falsy (including
            the default ``None``), ``self.hidden_size`` is used instead.

    Returns:
        An instance of ``self.cell_type``. It is wrapped in
        ``rnn.DropoutWrapper`` when ``self.dropout`` is truthy, and that
        result is wrapped in ``rnn.ResidualWrapper`` when ``self.residual``
        is truthy.
    """
    units = hidden_size if hidden_size else self.hidden_size
    new_cell = self.cell_type(units)

    if self.dropout:
        # keep_prob is passed positionally as DropoutWrapper's second
        # argument (presumably input_keep_prob in the TF1 API -- confirm).
        new_cell = rnn.DropoutWrapper(new_cell, self.keep_prob)

    if self.residual:
        new_cell = rnn.ResidualWrapper(new_cell)

    return new_cell
Example #2
Source File: model_seq2seq.py From nlp_xiaojiang with MIT License | 5 votes |
def build_single_cell(self, n_hidden, use_residual):
    """Build a single RNN cell.

    Args:
        n_hidden: number of neurons in the hidden layer.
        use_residual: whether to wrap the cell in a ResidualWrapper.

    Returns:
        A ``GRUCell`` when ``self.cell_type == 'gru'``, otherwise an
        ``LSTMCell``. It is wrapped in ``DropoutWrapper`` when
        ``self.use_dropout`` is truthy, and that result is wrapped in
        ``ResidualWrapper`` when ``use_residual`` is truthy.
    """
    # Any cell_type other than 'gru' falls back to an LSTM cell.
    cell_cls = GRUCell if self.cell_type == 'gru' else LSTMCell
    rnn_cell = cell_cls(n_hidden)

    if self.use_dropout:
        # Dropout is applied to the cell's outputs only; the keep probability
        # is read from self.keep_prob_placeholder.
        rnn_cell = DropoutWrapper(
            rnn_cell,
            dtype=tf.float32,
            output_keep_prob=self.keep_prob_placeholder,
            seed=self.seed,
        )

    if use_residual:
        rnn_cell = ResidualWrapper(rnn_cell)

    return rnn_cell
Example #3
Source File: rnn_factory.py From THRED with MIT License | 5 votes |
def create_cell(unit_type, hidden_units, num_layers, use_residual=False,
                input_keep_prob=1.0, output_keep_prob=1.0, devices=None):
    """Create an LSTM or GRU cell, or a MultiRNNCell stack of them.

    Args:
        unit_type: 'lstm' for ``tf.nn.rnn_cell.BasicLSTMCell`` or 'gru' for
            ``tf.contrib.rnn.GRUCell``.
        hidden_units: number of units in each cell.
        num_layers: number of cells. When greater than 1 the cells are
            stacked with ``tf.contrib.rnn.MultiRNNCell``; otherwise a single
            cell is returned.
        use_residual: when truthy and ``num_layers > 1``, every layer except
            the first is wrapped in ``rnn.ResidualWrapper``. It is ignored
            when only a single cell is built.
        input_keep_prob: input keep probability for ``rnn.DropoutWrapper``.
        output_keep_prob: output keep probability for ``rnn.DropoutWrapper``.
            Each cell is wrapped in dropout only when either keep probability
            is below 1.0.
        devices: optional sequence of device identifiers. Layer ``i`` is
            wrapped in ``rnn.DeviceWrapper(cell, devices[i])``, so when given
            it must have at least ``num_layers`` entries (an IndexError is
            raised otherwise). Falsy entries skip the device wrapper.

    Returns:
        A single (possibly wrapped) cell when ``num_layers <= 1``, otherwise a
        ``tf.contrib.rnn.MultiRNNCell`` over ``num_layers`` wrapped cells.

    Raises:
        ValueError: if ``unit_type`` is neither 'lstm' nor 'gru'.
    """
    if unit_type == 'lstm':
        def _new_cell():
            return tf.nn.rnn_cell.BasicLSTMCell(hidden_units)
    elif unit_type == 'gru':
        def _new_cell():
            return tf.contrib.rnn.GRUCell(hidden_units)
    else:
        # Name the actual parameter and echo the rejected value so the
        # caller can see what was passed.
        raise ValueError(
            "unit_type must be either 'lstm' or 'gru', got %r" % (unit_type,))

    def _new_cell_wrapper(residual_connection=False, device_id=None):
        """Build one cell and apply dropout, residual and device wrappers."""
        c = _new_cell()
        if input_keep_prob < 1.0 or output_keep_prob < 1.0:
            c = rnn.DropoutWrapper(c, input_keep_prob=input_keep_prob,
                                   output_keep_prob=output_keep_prob)
        if residual_connection:
            c = rnn.ResidualWrapper(c)
        if device_id:
            c = rnn.DeviceWrapper(c, device_id)
        return c

    if num_layers > 1:
        # Layer 0 never gets a residual connection, presumably because its
        # input size can differ from hidden_units, so input and output could
        # not be summed.
        cells = [
            _new_cell_wrapper(bool(use_residual and i > 0),
                              devices[i] if devices else None)
            for i in range(num_layers)
        ]
        return tf.contrib.rnn.MultiRNNCell(cells)

    return _new_cell_wrapper(device_id=devices[0] if devices else None)
Example #4
Source File: model_utils.py From language with Apache License 2.0 | 5 votes |
def _single_cell(unit_type, num_units, forget_bias, dropout, mode,
                 residual_connection=False, residual_fn=None, trainable=True):
    """Build one RNN cell of the requested type, with optional wrappers.

    Args:
        unit_type: one of "lstm", "gru", "layer_norm_lstm" or "nas".
        num_units: number of units in the cell.
        forget_bias: forget-gate bias; only the "lstm" and "layer_norm_lstm"
            cells receive it.
        dropout: drop probability (1 - keep_prob). It is replaced by 0.0
            unless ``mode`` equals ``tf.estimator.ModeKeys.TRAIN``.
        mode: compared against ``tf.estimator.ModeKeys.TRAIN``.
        residual_connection: when truthy, wrap the cell in
            ``contrib_rnn.ResidualWrapper``.
        residual_fn: forwarded to ``contrib_rnn.ResidualWrapper``.
        trainable: forwarded to the cell constructor.

    Returns:
        The cell, wrapped in input dropout if the effective dropout is
        positive, then in a residual wrapper if requested.

    Raises:
        ValueError: if ``unit_type`` is not one of the supported names.
    """
    # Dropout only applies while training; eval and infer run without it.
    is_training = mode == tf.estimator.ModeKeys.TRAIN
    if not is_training:
        dropout = 0.0

    if unit_type == "lstm":
        cell = contrib_rnn.LSTMCell(
            num_units, forget_bias=forget_bias, trainable=trainable)
    elif unit_type == "gru":
        cell = contrib_rnn.GRUCell(num_units, trainable=trainable)
    elif unit_type == "layer_norm_lstm":
        cell = contrib_rnn.LayerNormBasicLSTMCell(
            num_units,
            forget_bias=forget_bias,
            layer_norm=True,
            trainable=trainable,
        )
    elif unit_type == "nas":
        cell = contrib_rnn.NASCell(num_units, trainable=trainable)
    else:
        raise ValueError("Unknown unit type %s!" % unit_type)

    if dropout > 0.0:
        # Only the cell inputs are dropped; the keep probability is the
        # complement of the drop rate.
        keep_prob = 1.0 - dropout
        cell = contrib_rnn.DropoutWrapper(cell=cell, input_keep_prob=keep_prob)

    if residual_connection:
        cell = contrib_rnn.ResidualWrapper(cell, residual_fn=residual_fn)

    return cell
Example #5
Source File: lstm_utils.py From synvae with MIT License | 5 votes |
def rnn_cell(rnn_cell_size, dropout_keep_prob, residual, is_training=True):
    """Build a MultiRNNCell stack with one LSTMBlockCell per layer size.

    Args:
        rnn_cell_size: sequence of layer sizes; entry ``k`` is the number of
            units in layer ``k``.
        dropout_keep_prob: input keep probability used while training.
        residual: when truthy, wrap each layer in ``rnn.ResidualWrapper``.
        is_training: when falsy, the keep probability is forced to 1.0.

    Returns:
        ``rnn.MultiRNNCell`` over the per-layer cells. Each layer is wrapped
        in ``rnn.DropoutWrapper`` on its inputs.
    """
    keep_prob = dropout_keep_prob if is_training else 1.0

    layers = []
    for idx, size in enumerate(rnn_cell_size):
        layer = rnn.LSTMBlockCell(size)
        if residual:
            layer = rnn.ResidualWrapper(layer)
            # Project the inputs to `size` on the first layer or whenever
            # the size changes, presumably so the residual sum's shapes
            # agree.
            size_changed = idx == 0 or size != rnn_cell_size[idx - 1]
            if size_changed:
                layer = rnn.InputProjectionWrapper(layer, size)
        layers.append(rnn.DropoutWrapper(layer, input_keep_prob=keep_prob))

    return rnn.MultiRNNCell(layers)