Python sonnet.get_variables_in_module() Examples
The following are 30 code examples of `sonnet.get_variables_in_module()`.
You can vote up the ones you like or vote down the ones you don't like,
and go to the original project or source file by following the links above each example.
You may also want to check out all available functions/classes of the module `sonnet`, or try the search function.
Example #1
Source File: networks.py From learning-to-learn with Apache License 2.0 | 6 votes |
def save(network, sess, filename=None):
  """Save the variables contained by a network to disk.

  Each variable is keyed by the last two components of its name (with the
  ":<n>" output suffix dropped): the enclosing scope becomes ``module_name``
  and the final component becomes ``variable_name``.

  Args:
    network: Sonnet module whose variables are saved.
    sess: Session used to fetch the variable values.
    filename: Optional path; when truthy, the result is pickled to it.

  Returns:
    A ``collections.defaultdict(dict)`` mapping
    ``module_name -> {variable_name: value}``.
  """
  to_save = collections.defaultdict(dict)
  variables = list(snt.get_variables_in_module(network))
  # Fetch all values in one session round-trip instead of one
  # ``v.eval(sess)`` call per variable; skip the call when there is nothing
  # to fetch so the empty case still never touches the session.
  values = sess.run(variables) if variables else []
  for v, value in zip(variables, values):
    split = v.name.split(":")[0].split("/")
    module_name = split[-2]
    variable_name = split[-1]
    to_save[module_name][variable_name] = value
  if filename:
    with open(filename, "wb") as f:
      pickle.dump(to_save, f)
  return to_save
Example #2
Source File: fasterrcnn.py From luminoth with BSD 3-Clause "New" or "Revised" License | 6 votes |
def get_trainable_vars(self):
    """Return the variables this module should train.

    The variables created inside this module are always included. When
    ``self._config.model.base_network.trainable`` is truthy, the variables
    returned by ``self.base_network.get_trainable_vars()`` are appended.
    """
    own_vars = snt.get_variables_in_module(self)
    if not self._config.model.base_network.trainable:
        tf.logging.info('Not training variables from pretrained module')
        return own_vars

    base_vars = self.base_network.get_trainable_vars()
    if not len(base_vars):
        tf.logging.info('No vars from pretrained module to train.')
    else:
        tf.logging.info(
            'Training {} vars from pretrained module; '
            'from "{}" to "{}".'.format(
                len(base_vars), base_vars[0].name, base_vars[-1].name
            )
        )
    own_vars += base_vars
    return own_vars
Example #3
Source File: fasterrcnn.py From Tabulo with BSD 3-Clause "New" or "Revised" License | 6 votes |
def get_trainable_vars(self):
    """Collect the variables this module should train.

    Returns:
        The variables defined inside this module. When
        ``self._config.model.base_network.trainable`` is truthy, the base
        network's trainable variables are appended.
    """
    result = snt.get_variables_in_module(self)
    train_base = self._config.model.base_network.trainable
    if train_base:
        extra = self.base_network.get_trainable_vars()
        if not len(extra):
            tf.logging.info('No vars from pretrained module to train.')
        else:
            message = (
                'Training {} vars from pretrained module; '
                'from "{}" to "{}".'
            ).format(len(extra), extra[0].name, extra[-1].name)
            tf.logging.info(message)
        result += extra
    else:
        tf.logging.info('Not training variables from pretrained module')
    return result
Example #4
Source File: fasterrcnn.py From Table-Detection-using-Deep-learning with BSD 3-Clause "New" or "Revised" License | 6 votes |
def get_trainable_vars(self):
    """Return this module's variables, plus the base network's when enabled.

    The base network's ``get_trainable_vars()`` result is appended only if
    ``self._config.model.base_network.trainable`` is truthy.
    """
    collected = snt.get_variables_in_module(self)
    if not self._config.model.base_network.trainable:
        tf.logging.info('Not training variables from pretrained module')
        return collected

    pretrained = self.base_network.get_trainable_vars()
    count = len(pretrained)
    if count:
        tf.logging.info(
            'Training {} vars from pretrained module; '
            'from "{}" to "{}".'.format(
                count, pretrained[0].name, pretrained[-1].name
            )
        )
    else:
        tf.logging.info('No vars from pretrained module to train.')
    collected += pretrained
    return collected
Example #5
Source File: common.py From models with Apache License 2.0 | 5 votes |
def w(self):
  """Return the module variable whose raw name is "w".

  Asserts that the module holds two variables when ``self.use_bias`` is
  truthy and one variable otherwise. Also asserts that exactly one variable
  has raw name "w", as computed by ``self._raw_name``.
  """
  module_vars = snt.get_variables_in_module(self)
  if not self.use_bias:
    assert len(module_vars) == 1, "Found not 1 but %d" % len(module_vars)
  else:
    assert len(module_vars) == 2, "Found not 2 but %d" % len(module_vars)
  weights = [v for v in module_vars if self._raw_name(v.name) == "w"]
  assert len(weights) == 1
  return weights[0]
Example #6
Source File: common.py From g-tensorflow-models with Apache License 2.0 | 5 votes |
def w(self):
  """Return the module variable for which ``self._raw_name`` gives "w".

  Asserts that exactly one such variable exists.
  """
  matching = [
      var for var in snt.get_variables_in_module(self)
      if self._raw_name(var.name) == "w"
  ]
  assert len(matching) == 1
  return matching[0]
Example #7
Source File: common.py From g-tensorflow-models with Apache License 2.0 | 5 votes |
def b(self):
  """Return the module variable for which ``self._raw_name`` gives "b".

  Asserts that exactly one such variable exists.
  """
  found = []
  for var in snt.get_variables_in_module(self):
    if self._raw_name(var.name) == "b":
      found.append(var)
  assert len(found) == 1
  return found[0]
Example #8
Source File: common.py From g-tensorflow-models with Apache License 2.0 | 5 votes |
def w(self):
  """Return the single variable whose raw name is "w".

  The module must hold two variables when ``self.use_bias`` is truthy and
  one otherwise (checked with ``assert``). Exactly one of them must have raw
  name "w".
  """
  module_vars = snt.get_variables_in_module(self)
  if self.use_bias:
    assert len(module_vars) == 2, "Found not 2 but %d" % len(module_vars)
  else:
    assert len(module_vars) == 1, "Found not 1 but %d" % len(module_vars)
  found = []
  for var in module_vars:
    if self._raw_name(var.name) == "w":
      found.append(var)
  assert len(found) == 1
  return found[0]
Example #9
Source File: common.py From g-tensorflow-models with Apache License 2.0 | 5 votes |
def b(self):
  """Return the module variable whose raw name is "b".

  Asserts that the module holds exactly two variables and that exactly one
  of them has raw name "b", as computed by ``self._raw_name``.
  """
  module_vars = snt.get_variables_in_module(self)
  assert len(module_vars) == 2, "Found not 2 but %d" % len(module_vars)
  biases = [v for v in module_vars if self._raw_name(v.name) == "b"]
  assert len(biases) == 1
  return biases[0]
Example #10
Source File: more_local_weight_update.py From g-tensorflow-models with Apache License 2.0 | 5 votes |
def remote_variables(self):
  """Return this module's trainable and moving-average variables.

  The result is a new list: the TRAINABLE_VARIABLES collection first, then
  the MOVING_AVERAGE_VARIABLES collection.
  """
  gathered = []
  for key in (tf.GraphKeys.TRAINABLE_VARIABLES,
              tf.GraphKeys.MOVING_AVERAGE_VARIABLES):
    gathered.extend(snt.get_variables_in_module(self, key))
  return gathered
Example #11
Source File: linear_regression.py From g-tensorflow-models with Apache License 2.0 | 5 votes |
def local_variables(self):
  """Variables that must be updated on every evaluation.

  These are the module's TRAINABLE_VARIABLES. They should not be stored on
  a parameter server, and should be reset each time a meta_objective loss
  is computed.

  Returns:
    vars: list of tf.Variable
  """
  trainable = snt.get_variables_in_module(
      self, tf.GraphKeys.TRAINABLE_VARIABLES)
  return list(trainable)
Example #12
Source File: utils.py From g-tensorflow-models with Apache License 2.0 | 5 votes |
def get_variables_in_modules(module_list):
  """Concatenate the variables of each module in ``module_list``.

  Args:
    module_list: iterable of Sonnet modules.

  Returns:
    A list of variables, in module order, as reported by
    ``snt.get_variables_in_module``.
  """
  return [var
          for module in module_list
          for var in snt.get_variables_in_module(module)]
Example #13
Source File: common.py From models with Apache License 2.0 | 5 votes |
def w(self):
  """Return the module variable for which ``self._raw_name`` gives "w".

  Asserts that exactly one such variable exists.
  """
  found = []
  for var in snt.get_variables_in_module(self):
    if self._raw_name(var.name) == "w":
      found.append(var)
  assert len(found) == 1
  return found[0]
Example #14
Source File: common.py From models with Apache License 2.0 | 5 votes |
def b(self):
  """Return the module variable for which ``self._raw_name`` gives "b".

  Asserts that exactly one such variable exists.
  """
  matching = [
      var for var in snt.get_variables_in_module(self)
      if self._raw_name(var.name) == "b"
  ]
  assert len(matching) == 1
  return matching[0]
Example #15
Source File: networks_test.py From learning-to-learn with Apache License 2.0 | 5 votes |
def testTrainable(self):
  """Tests the network contains trainable variables."""
  kernel_shape = [5, 5]
  # The input has to be 4-dimensional.
  gradients = tf.random_normal(kernel_shape + [2, 2])
  net = networks.KernelDeepLSTM(layers=(1,), kernel_shape=kernel_shape)
  # Connect the module once so that its variables get created.
  net(gradients, net.initial_state_for_inputs(gradients))
  # Weights and biases for two layers.
  self.assertEqual(len(snt.get_variables_in_module(net)), 4)
Example #16
Source File: common.py From models with Apache License 2.0 | 5 votes |
def b(self):
  """Return the variable whose raw name is "b".

  Asserts that the module holds exactly two variables and that exactly one
  of them has raw name "b".
  """
  module_vars = snt.get_variables_in_module(self)
  assert len(module_vars) == 2, "Found not 2 but %d" % len(module_vars)
  found = []
  for var in module_vars:
    if self._raw_name(var.name) == "b":
      found.append(var)
  assert len(found) == 1
  return found[0]
Example #17
Source File: more_local_weight_update.py From models with Apache License 2.0 | 5 votes |
def remote_variables(self):
  """Return a new list of trainable followed by moving-average variables."""
  trainable = list(snt.get_variables_in_module(
      self, tf.GraphKeys.TRAINABLE_VARIABLES))
  moving_averages = list(snt.get_variables_in_module(
      self, tf.GraphKeys.MOVING_AVERAGE_VARIABLES))
  return trainable + moving_averages
Example #18
Source File: linear_regression.py From models with Apache License 2.0 | 5 votes |
def local_variables(self):
  """Variables that must be updated on every evaluation.

  These are the module's TRAINABLE_VARIABLES. They should not be stored on
  a parameter server, and should be reset each time a meta_objective loss
  is computed.

  Returns:
    vars: list of tf.Variable
  """
  trainable = snt.get_variables_in_module(
      self, tf.GraphKeys.TRAINABLE_VARIABLES)
  return list(trainable)
Example #19
Source File: utils.py From models with Apache License 2.0 | 5 votes |
def get_variables_in_modules(module_list):
  """Concatenate the variables of each module in ``module_list``.

  Args:
    module_list: iterable of Sonnet modules.

  Returns:
    A list of variables, in module order, as reported by
    ``snt.get_variables_in_module``.
  """
  return [var
          for module in module_list
          for var in snt.get_variables_in_module(module)]
Example #20
Source File: common.py From multilabel-image-classification-tensorflow with MIT License | 5 votes |
def w(self):
  """Return the module variable for which ``self._raw_name`` gives "w".

  Asserts that exactly one such variable exists.
  """
  matching = [
      var for var in snt.get_variables_in_module(self)
      if self._raw_name(var.name) == "w"
  ]
  assert len(matching) == 1
  return matching[0]
Example #21
Source File: common.py From multilabel-image-classification-tensorflow with MIT License | 5 votes |
def b(self):
  """Return the module variable for which ``self._raw_name`` gives "b".

  Asserts that exactly one such variable exists.
  """
  found = []
  for var in snt.get_variables_in_module(self):
    if self._raw_name(var.name) == "b":
      found.append(var)
  assert len(found) == 1
  return found[0]
Example #22
Source File: common.py From multilabel-image-classification-tensorflow with MIT License | 5 votes |
def w(self):
  """Return the module variable whose raw name is "w".

  Asserts that the module holds two variables when ``self.use_bias`` is
  truthy and one variable otherwise. Also asserts that exactly one variable
  has raw name "w", as computed by ``self._raw_name``.
  """
  module_vars = snt.get_variables_in_module(self)
  if not self.use_bias:
    assert len(module_vars) == 1, "Found not 1 but %d" % len(module_vars)
  else:
    assert len(module_vars) == 2, "Found not 2 but %d" % len(module_vars)
  weights = [v for v in module_vars if self._raw_name(v.name) == "w"]
  assert len(weights) == 1
  return weights[0]
Example #23
Source File: common.py From multilabel-image-classification-tensorflow with MIT License | 5 votes |
def b(self):
  """Return the module variable whose raw name is "b".

  Asserts that the module holds exactly two variables and that exactly one
  of them has raw name "b", as computed by ``self._raw_name``.
  """
  module_vars = snt.get_variables_in_module(self)
  assert len(module_vars) == 2, "Found not 2 but %d" % len(module_vars)
  biases = [v for v in module_vars if self._raw_name(v.name) == "b"]
  assert len(biases) == 1
  return biases[0]
Example #24
Source File: more_local_weight_update.py From multilabel-image-classification-tensorflow with MIT License | 5 votes |
def remote_variables(self):
  """Return this module's trainable and moving-average variables.

  The result is a new list: the TRAINABLE_VARIABLES collection first, then
  the MOVING_AVERAGE_VARIABLES collection.
  """
  gathered = []
  for key in (tf.GraphKeys.TRAINABLE_VARIABLES,
              tf.GraphKeys.MOVING_AVERAGE_VARIABLES):
    gathered.extend(snt.get_variables_in_module(self, key))
  return gathered
Example #25
Source File: linear_regression.py From multilabel-image-classification-tensorflow with MIT License | 5 votes |
def local_variables(self):
  """Variables that must be updated on every evaluation.

  These are the module's TRAINABLE_VARIABLES. They should not be stored on
  a parameter server, and should be reset each time a meta_objective loss
  is computed.

  Returns:
    vars: list of tf.Variable
  """
  trainable = snt.get_variables_in_module(
      self, tf.GraphKeys.TRAINABLE_VARIABLES)
  return list(trainable)
Example #26
Source File: utils.py From multilabel-image-classification-tensorflow with MIT License | 5 votes |
def get_variables_in_modules(module_list):
  """Concatenate the variables of each module in ``module_list``.

  Args:
    module_list: iterable of Sonnet modules.

  Returns:
    A list of variables, in module order, as reported by
    ``snt.get_variables_in_module``.
  """
  return [var
          for module in module_list
          for var in snt.get_variables_in_module(module)]
Example #27
Source File: base_network.py From Tabulo with BSD 3-Clause "New" or "Revised" License | 5 votes |
def _get_base_network_vars(self):
    """Returns all of the base network's variables.

    If ``self.pretrained_weights_scope`` is set, the variables come from the
    ``MODEL_VARIABLES`` collection filtered by that scope. Otherwise they are
    this Sonnet module's ``MODEL_VARIABLES``.

    Raises:
        AssertionError: if no variables are found. The error is raised
            explicitly instead of via ``assert``, so the check is not
            stripped when Python runs with ``-O``.
    """
    if self.pretrained_weights_scope:
        # We may have defined the base network in a particular scope
        module_variables = tf.get_collection(
            tf.GraphKeys.MODEL_VARIABLES, scope=self.pretrained_weights_scope
        )
    else:
        module_variables = snt.get_variables_in_module(
            self, tf.GraphKeys.MODEL_VARIABLES
        )
    if not module_variables:
        raise AssertionError(
            'No base network variables found '
            '(pretrained_weights_scope={!r}).'.format(
                self.pretrained_weights_scope
            )
        )
    return module_variables
Example #28
Source File: common.py From Gun-Detector with Apache License 2.0 | 5 votes |
def b(self):
  """Return the module variable for which ``self._raw_name`` gives "b".

  Asserts that exactly one such variable exists.
  """
  matching = [
      var for var in snt.get_variables_in_module(self)
      if self._raw_name(var.name) == "b"
  ]
  assert len(matching) == 1
  return matching[0]
Example #29
Source File: common.py From Gun-Detector with Apache License 2.0 | 5 votes |
def w(self):
  """Return the single variable whose raw name is "w".

  The module must hold two variables when ``self.use_bias`` is truthy and
  one otherwise (checked with ``assert``). Exactly one of them must have raw
  name "w".
  """
  module_vars = snt.get_variables_in_module(self)
  if self.use_bias:
    assert len(module_vars) == 2, "Found not 2 but %d" % len(module_vars)
  else:
    assert len(module_vars) == 1, "Found not 1 but %d" % len(module_vars)
  found = []
  for var in module_vars:
    if self._raw_name(var.name) == "w":
      found.append(var)
  assert len(found) == 1
  return found[0]
Example #30
Source File: common.py From Gun-Detector with Apache License 2.0 | 5 votes |
def b(self):
  """Return the variable whose raw name is "b".

  Asserts that the module holds exactly two variables and that exactly one
  of them has raw name "b".
  """
  module_vars = snt.get_variables_in_module(self)
  assert len(module_vars) == 2, "Found not 2 but %d" % len(module_vars)
  found = []
  for var in module_vars:
    if self._raw_name(var.name) == "b":
      found.append(var)
  assert len(found) == 1
  return found[0]