Python sonnet.get_variables_in_module() Examples

The following are 30 code examples of sonnet.get_variables_in_module(). You can vote up the ones you like or vote down the ones you don't like, and go to the original project or source file by following the links above each example. You may also want to check out all available functions/classes of the module `sonnet`, or try the search function.
Example #1
Source File: networks.py    From learning-to-learn with Apache License 2.0 6 votes vote down vote up
def save(network, sess, filename=None):
  """Save the variables contained by a network to disk.

  Each variable name (minus its ``:N`` output suffix) is split on ``/``.
  The second-to-last component becomes the outer key and the last
  component the inner key, e.g. ``"net/linear/w:0"`` is stored as
  ``to_save["linear"]["w"]``.

  Args:
    network: module whose variables are collected with
      ``snt.get_variables_in_module``.
    sess: session used to evaluate the variables.
    filename: optional path; when truthy, the nested dict is pickled there.

  Returns:
    A ``collections.defaultdict(dict)`` mapping outer key -> inner key ->
    evaluated variable value.
  """
  to_save = collections.defaultdict(dict)
  variables = list(snt.get_variables_in_module(network))

  # Fetch every value in a single session call instead of one
  # `Session.run` round trip per variable (what `v.eval(sess)` does).
  values = sess.run(variables) if variables else []

  for v, value in zip(variables, values):
    split = v.name.split(":")[0].split("/")
    # NOTE: only the last two name components form the key, so variables that
    # share them under different outer scopes overwrite one another here.
    module_name = split[-2]
    variable_name = split[-1]
    to_save[module_name][variable_name] = value

  if filename:
    with open(filename, "wb") as f:
      pickle.dump(to_save, f)

  return to_save
Example #2
Source File: fasterrcnn.py    From luminoth with BSD 3-Clause "New" or "Revised" License 6 votes vote down vote up
def get_trainable_vars(self):
        """Get trainable vars included in the module.

        Always includes the variables reported by
        ``snt.get_variables_in_module(self)``. When
        ``self._config.model.base_network.trainable`` is truthy, it also
        appends the variables returned by
        ``self.base_network.get_trainable_vars()``.

        Returns:
            A list of variables.
        """
        # Sonnet 1.x returns a tuple here; copy into a list so appending the
        # pretrained vars works whether the base network returns a list or a
        # tuple (`tuple += list` raises TypeError).
        trainable_vars = list(snt.get_variables_in_module(self))
        if self._config.model.base_network.trainable:
            pretrained_trainable_vars = self.base_network.get_trainable_vars()
            if pretrained_trainable_vars:
                tf.logging.info(
                    'Training {} vars from pretrained module; '
                    'from "{}" to "{}".'.format(
                        len(pretrained_trainable_vars),
                        pretrained_trainable_vars[0].name,
                        pretrained_trainable_vars[-1].name,
                    )
                )
            else:
                tf.logging.info('No vars from pretrained module to train.')
            trainable_vars.extend(pretrained_trainable_vars)
        else:
            tf.logging.info('Not training variables from pretrained module')

        return trainable_vars
Example #3
Source File: fasterrcnn.py    From Tabulo with BSD 3-Clause "New" or "Revised" License 6 votes vote down vote up
def get_trainable_vars(self):
        """Get trainable vars included in the module.

        Always includes the variables reported by
        ``snt.get_variables_in_module(self)``. When
        ``self._config.model.base_network.trainable`` is truthy, it also
        appends the variables returned by
        ``self.base_network.get_trainable_vars()``.

        Returns:
            A list of variables.
        """
        # Sonnet 1.x returns a tuple here; copy into a list so appending the
        # pretrained vars works whether the base network returns a list or a
        # tuple (`tuple += list` raises TypeError).
        trainable_vars = list(snt.get_variables_in_module(self))
        if self._config.model.base_network.trainable:
            pretrained_trainable_vars = self.base_network.get_trainable_vars()
            if pretrained_trainable_vars:
                tf.logging.info(
                    'Training {} vars from pretrained module; '
                    'from "{}" to "{}".'.format(
                        len(pretrained_trainable_vars),
                        pretrained_trainable_vars[0].name,
                        pretrained_trainable_vars[-1].name,
                    )
                )
            else:
                tf.logging.info('No vars from pretrained module to train.')
            trainable_vars.extend(pretrained_trainable_vars)
        else:
            tf.logging.info('Not training variables from pretrained module')

        return trainable_vars
Example #4
Source File: fasterrcnn.py    From Table-Detection-using-Deep-learning with BSD 3-Clause "New" or "Revised" License 6 votes vote down vote up
def get_trainable_vars(self):
        """Get trainable vars included in the module.

        Always includes the variables reported by
        ``snt.get_variables_in_module(self)``. When
        ``self._config.model.base_network.trainable`` is truthy, it also
        appends the variables returned by
        ``self.base_network.get_trainable_vars()``.

        Returns:
            A list of variables.
        """
        # Sonnet 1.x returns a tuple here; copy into a list so appending the
        # pretrained vars works whether the base network returns a list or a
        # tuple (`tuple += list` raises TypeError).
        trainable_vars = list(snt.get_variables_in_module(self))
        if self._config.model.base_network.trainable:
            pretrained_trainable_vars = self.base_network.get_trainable_vars()
            if pretrained_trainable_vars:
                tf.logging.info(
                    'Training {} vars from pretrained module; '
                    'from "{}" to "{}".'.format(
                        len(pretrained_trainable_vars),
                        pretrained_trainable_vars[0].name,
                        pretrained_trainable_vars[-1].name,
                    )
                )
            else:
                tf.logging.info('No vars from pretrained module to train.')
            trainable_vars.extend(pretrained_trainable_vars)
        else:
            tf.logging.info('Not training variables from pretrained module')

        return trainable_vars
Example #5
Source File: common.py    From models with Apache License 2.0 5 votes vote down vote up
def w(self):
    """Return the module variable whose ``_raw_name`` is ``"w"``.

    Also asserts the total variable count: 2 when ``self.use_bias`` is
    truthy, otherwise 1. Exactly one match for ``"w"`` must exist.
    """
    module_vars = snt.get_variables_in_module(self)
    n_found = len(module_vars)
    if self.use_bias:
        assert n_found == 2, "Found not 2 but %d" % n_found
    else:
        assert n_found == 1, "Found not 1 but %d" % n_found
    weights = [var for var in module_vars if self._raw_name(var.name) == "w"]
    assert len(weights) == 1
    return weights[0]
Example #6
Source File: common.py    From g-tensorflow-models with Apache License 2.0 5 votes vote down vote up
def w(self):
    """Return the module variable whose ``_raw_name`` is ``"w"``.

    Asserts that exactly one such variable exists.
    """
    candidates = [
        var for var in snt.get_variables_in_module(self)
        if self._raw_name(var.name) == "w"
    ]
    assert len(candidates) == 1
    return candidates[0]
Example #7
Source File: common.py    From g-tensorflow-models with Apache License 2.0 5 votes vote down vote up
def b(self):
    """Return the module variable whose ``_raw_name`` is ``"b"``.

    Asserts that exactly one such variable exists.
    """
    candidates = [
        var for var in snt.get_variables_in_module(self)
        if self._raw_name(var.name) == "b"
    ]
    assert len(candidates) == 1
    return candidates[0]
Example #8
Source File: common.py    From g-tensorflow-models with Apache License 2.0 5 votes vote down vote up
def w(self):
    """Return the module variable whose ``_raw_name`` is ``"w"``.

    Also asserts the total variable count: 2 when ``self.use_bias`` is
    truthy, otherwise 1. Exactly one match for ``"w"`` must exist.
    """
    module_vars = snt.get_variables_in_module(self)
    n_found = len(module_vars)
    if self.use_bias:
        assert n_found == 2, "Found not 2 but %d" % n_found
    else:
        assert n_found == 1, "Found not 1 but %d" % n_found
    weights = [var for var in module_vars if self._raw_name(var.name) == "w"]
    assert len(weights) == 1
    return weights[0]
Example #9
Source File: common.py    From g-tensorflow-models with Apache License 2.0 5 votes vote down vote up
def b(self):
    """Return the module variable whose ``_raw_name`` is ``"b"``.

    Asserts the module holds exactly two variables and exactly one of them
    matches ``"b"``.
    """
    module_vars = snt.get_variables_in_module(self)
    n_found = len(module_vars)
    assert n_found == 2, "Found not 2 but %d" % n_found
    biases = [var for var in module_vars if self._raw_name(var.name) == "b"]
    assert len(biases) == 1
    return biases[0]
Example #10
Source File: more_local_weight_update.py    From g-tensorflow-models with Apache License 2.0 5 votes vote down vote up
def remote_variables(self):
    """Return this module's trainable variables followed by its
    moving-average variables, as a single list."""
    wanted_keys = (tf.GraphKeys.TRAINABLE_VARIABLES,
                   tf.GraphKeys.MOVING_AVERAGE_VARIABLES)
    gathered = []
    for key in wanted_keys:
        gathered.extend(snt.get_variables_in_module(self, key))
    return gathered
Example #11
Source File: linear_regression.py    From g-tensorflow-models with Apache License 2.0 5 votes vote down vote up
def local_variables(self):
    """Variables to refresh on every evaluation.

    These are meant to stay off the parameter server and to be reset each
    time a meta_objective loss is computed.

    Returns:
      vars: list of tf.Variable (the module's trainable variables).
    """
    trainable = snt.get_variables_in_module(
        self, tf.GraphKeys.TRAINABLE_VARIABLES)
    return list(trainable)
Example #12
Source File: utils.py    From g-tensorflow-models with Apache License 2.0 5 votes vote down vote up
def get_variables_in_modules(module_list):
  """Concatenate, in order, the variables of every module in module_list."""
  return [var
          for module in module_list
          for var in snt.get_variables_in_module(module)]
Example #13
Source File: common.py    From models with Apache License 2.0 5 votes vote down vote up
def w(self):
    """Return the module variable whose ``_raw_name`` is ``"w"``.

    Asserts that exactly one such variable exists.
    """
    candidates = [
        var for var in snt.get_variables_in_module(self)
        if self._raw_name(var.name) == "w"
    ]
    assert len(candidates) == 1
    return candidates[0]
Example #14
Source File: common.py    From models with Apache License 2.0 5 votes vote down vote up
def b(self):
    """Return the module variable whose ``_raw_name`` is ``"b"``.

    Asserts that exactly one such variable exists.
    """
    candidates = [
        var for var in snt.get_variables_in_module(self)
        if self._raw_name(var.name) == "b"
    ]
    assert len(candidates) == 1
    return candidates[0]
Example #15
Source File: networks_test.py    From learning-to-learn with Apache License 2.0 5 votes vote down vote up
def testTrainable(self):
    """Checks a KernelDeepLSTM with layers=(1,) exposes exactly 4 variables."""
    kernel_shape = [5, 5]
    # Two extra trailing dims make the input 4-dimensional, as required.
    inputs = tf.random_normal(kernel_shape + [2, 2])
    net = networks.KernelDeepLSTM(layers=(1,), kernel_shape=kernel_shape)
    net(inputs, net.initial_state_for_inputs(inputs))
    # Expected: weights and biases for two layers.
    self.assertEqual(len(snt.get_variables_in_module(net)), 4)
Example #16
Source File: common.py    From models with Apache License 2.0 5 votes vote down vote up
def b(self):
    """Return the module variable whose ``_raw_name`` is ``"b"``.

    Asserts the module holds exactly two variables and exactly one of them
    matches ``"b"``.
    """
    module_vars = snt.get_variables_in_module(self)
    n_found = len(module_vars)
    assert n_found == 2, "Found not 2 but %d" % n_found
    biases = [var for var in module_vars if self._raw_name(var.name) == "b"]
    assert len(biases) == 1
    return biases[0]
Example #17
Source File: more_local_weight_update.py    From models with Apache License 2.0 5 votes vote down vote up
def remote_variables(self):
    """Return this module's trainable variables followed by its
    moving-average variables, as a single list."""
    wanted_keys = (tf.GraphKeys.TRAINABLE_VARIABLES,
                   tf.GraphKeys.MOVING_AVERAGE_VARIABLES)
    gathered = []
    for key in wanted_keys:
        gathered.extend(snt.get_variables_in_module(self, key))
    return gathered
Example #18
Source File: linear_regression.py    From models with Apache License 2.0 5 votes vote down vote up
def local_variables(self):
    """Variables to refresh on every evaluation.

    These are meant to stay off the parameter server and to be reset each
    time a meta_objective loss is computed.

    Returns:
      vars: list of tf.Variable (the module's trainable variables).
    """
    trainable = snt.get_variables_in_module(
        self, tf.GraphKeys.TRAINABLE_VARIABLES)
    return list(trainable)
Example #19
Source File: utils.py    From models with Apache License 2.0 5 votes vote down vote up
def get_variables_in_modules(module_list):
  """Concatenate, in order, the variables of every module in module_list."""
  return [var
          for module in module_list
          for var in snt.get_variables_in_module(module)]
Example #20
Source File: common.py    From multilabel-image-classification-tensorflow with MIT License 5 votes vote down vote up
def w(self):
    """Return the module variable whose ``_raw_name`` is ``"w"``.

    Asserts that exactly one such variable exists.
    """
    candidates = [
        var for var in snt.get_variables_in_module(self)
        if self._raw_name(var.name) == "w"
    ]
    assert len(candidates) == 1
    return candidates[0]
Example #21
Source File: common.py    From multilabel-image-classification-tensorflow with MIT License 5 votes vote down vote up
def b(self):
    """Return the module variable whose ``_raw_name`` is ``"b"``.

    Asserts that exactly one such variable exists.
    """
    candidates = [
        var for var in snt.get_variables_in_module(self)
        if self._raw_name(var.name) == "b"
    ]
    assert len(candidates) == 1
    return candidates[0]
Example #22
Source File: common.py    From multilabel-image-classification-tensorflow with MIT License 5 votes vote down vote up
def w(self):
    """Return the module variable whose ``_raw_name`` is ``"w"``.

    Also asserts the total variable count: 2 when ``self.use_bias`` is
    truthy, otherwise 1. Exactly one match for ``"w"`` must exist.
    """
    module_vars = snt.get_variables_in_module(self)
    n_found = len(module_vars)
    if self.use_bias:
        assert n_found == 2, "Found not 2 but %d" % n_found
    else:
        assert n_found == 1, "Found not 1 but %d" % n_found
    weights = [var for var in module_vars if self._raw_name(var.name) == "w"]
    assert len(weights) == 1
    return weights[0]
Example #23
Source File: common.py    From multilabel-image-classification-tensorflow with MIT License 5 votes vote down vote up
def b(self):
    """Return the module variable whose ``_raw_name`` is ``"b"``.

    Asserts the module holds exactly two variables and exactly one of them
    matches ``"b"``.
    """
    module_vars = snt.get_variables_in_module(self)
    n_found = len(module_vars)
    assert n_found == 2, "Found not 2 but %d" % n_found
    biases = [var for var in module_vars if self._raw_name(var.name) == "b"]
    assert len(biases) == 1
    return biases[0]
Example #24
Source File: more_local_weight_update.py    From multilabel-image-classification-tensorflow with MIT License 5 votes vote down vote up
def remote_variables(self):
    """Return this module's trainable variables followed by its
    moving-average variables, as a single list."""
    wanted_keys = (tf.GraphKeys.TRAINABLE_VARIABLES,
                   tf.GraphKeys.MOVING_AVERAGE_VARIABLES)
    gathered = []
    for key in wanted_keys:
        gathered.extend(snt.get_variables_in_module(self, key))
    return gathered
Example #25
Source File: linear_regression.py    From multilabel-image-classification-tensorflow with MIT License 5 votes vote down vote up
def local_variables(self):
    """Variables to refresh on every evaluation.

    These are meant to stay off the parameter server and to be reset each
    time a meta_objective loss is computed.

    Returns:
      vars: list of tf.Variable (the module's trainable variables).
    """
    trainable = snt.get_variables_in_module(
        self, tf.GraphKeys.TRAINABLE_VARIABLES)
    return list(trainable)
Example #26
Source File: utils.py    From multilabel-image-classification-tensorflow with MIT License 5 votes vote down vote up
def get_variables_in_modules(module_list):
  """Concatenate, in order, the variables of every module in module_list."""
  return [var
          for module in module_list
          for var in snt.get_variables_in_module(module)]
Example #27
Source File: base_network.py    From Tabulo with BSD 3-Clause "New" or "Revised" License 5 votes vote down vote up
def _get_base_network_vars(self):
        """Return every MODEL_VARIABLES entry belonging to the base network.

        Looks the variables up under ``self.pretrained_weights_scope`` when
        that is set; otherwise asks Sonnet for this module's own
        MODEL_VARIABLES. At least one variable must be found.
        """
        scope = self.pretrained_weights_scope
        if not scope:
            module_variables = snt.get_variables_in_module(
                self, tf.GraphKeys.MODEL_VARIABLES
            )
        else:
            # The base network may have been built under a dedicated scope.
            module_variables = tf.get_collection(
                tf.GraphKeys.MODEL_VARIABLES,
                scope=scope
            )
        assert len(module_variables) > 0
        return module_variables
Example #28
Source File: common.py    From Gun-Detector with Apache License 2.0 5 votes vote down vote up
def b(self):
    """Return the module variable whose ``_raw_name`` is ``"b"``.

    Asserts that exactly one such variable exists.
    """
    candidates = [
        var for var in snt.get_variables_in_module(self)
        if self._raw_name(var.name) == "b"
    ]
    assert len(candidates) == 1
    return candidates[0]
Example #29
Source File: common.py    From Gun-Detector with Apache License 2.0 5 votes vote down vote up
def w(self):
    """Return the module variable whose ``_raw_name`` is ``"w"``.

    Also asserts the total variable count: 2 when ``self.use_bias`` is
    truthy, otherwise 1. Exactly one match for ``"w"`` must exist.
    """
    module_vars = snt.get_variables_in_module(self)
    n_found = len(module_vars)
    if self.use_bias:
        assert n_found == 2, "Found not 2 but %d" % n_found
    else:
        assert n_found == 1, "Found not 1 but %d" % n_found
    weights = [var for var in module_vars if self._raw_name(var.name) == "w"]
    assert len(weights) == 1
    return weights[0]
Example #30
Source File: common.py    From Gun-Detector with Apache License 2.0 5 votes vote down vote up
def b(self):
    """Return the module variable whose ``_raw_name`` is ``"b"``.

    Asserts the module holds exactly two variables and exactly one of them
    matches ``"b"``.
    """
    module_vars = snt.get_variables_in_module(self)
    n_found = len(module_vars)
    assert n_found == 2, "Found not 2 but %d" % n_found
    biases = [var for var in module_vars if self._raw_name(var.name) == "b"]
    assert len(biases) == 1
    return biases[0]