Python tensorflow.compat.v1.trainable_variables() Examples

The following are 30 code examples of tensorflow.compat.v1.trainable_variables(). You can vote up the ones you like or vote down the ones you don't like, and go to the original project or source file by following the links above each example. You may also want to check out all available functions/classes of the module tensorflow.compat.v1, or try the search function.
Example #1
Source File: variable_mgr_util.py    From benchmarks with Apache License 2.0 6 votes vote down vote up
def trainable_variables_on_device(self, rel_device_num, abs_device_num,
                                    writable):
    """Return the trainable variables as they should be used on one device.

    Args:
      rel_device_num: local worker device index; selects which device's
        staging entries are consulted when `writable` is False.
      abs_device_num: global graph device index (ignored here).
      writable: if True, the trainable variables themselves are returned;
        otherwise, for each variable, the second element of its two-element
        staging entry on this device is returned.

    Returns:
      A list with one entry per variable in `tf.trainable_variables()`, in
      the same order.
    """
    del abs_device_num  # Unused by this implementation.
    trainable_refs = tf.trainable_variables()
    if writable:
      return trainable_refs

    read_ops = []
    for ref in trainable_refs:
      # Staging entries are keyed by the variable name without its ":N" suffix.
      bare_name = ref.name.split(':')[0]
      staged_pair = self.variable_mgr.staging_vars_on_devices[rel_device_num][
          bare_name]
      _, read_op = staged_pair
      read_ops.append(read_op)
    return read_ops
Example #2
Source File: efficientnet_lite_builder_test.py    From Object_Detection_Tracking with Apache License 2.0 6 votes vote down vote up
def _test_model_params(self,
                         model_name,
                         input_size,
                         expected_params,
                         override_params=None,
                         features_only=False,
                         pooled_features_only=False):
    """Build an EfficientNet-Lite model and check its trainable element count.

    Args:
      model_name: name forwarded to `efficientnet_lite_builder.build_model`.
      input_size: height and width of the single all-zeros input image.
      expected_params: expected total number of elements summed over all
        variables in `tf.trainable_variables()` after building.
      override_params: forwarded to `build_model`.
      features_only: forwarded to `build_model`.
      pooled_features_only: forwarded to `build_model`.
    """
    dummy_images = tf.zeros((1, input_size, input_size, 3), dtype=tf.float32)
    build_kwargs = dict(
        model_name=model_name,
        override_params=override_params,
        training=True,
        features_only=features_only,
        pooled_features_only=pooled_features_only)
    efficientnet_lite_builder.build_model(dummy_images, **build_kwargs)
    per_variable_counts = [
        np.prod(variable.shape) for variable in tf.trainable_variables()
    ]
    self.assertEqual(np.sum(per_variable_counts), expected_params)
Example #3
Source File: efficientnet_builder_test.py    From Object_Detection_Tracking with Apache License 2.0 6 votes vote down vote up
def _test_model_params(self,
                         model_name,
                         input_size,
                         expected_params,
                         override_params=None,
                         features_only=False,
                         pooled_features_only=False):
    """Build an EfficientNet model and check its trainable element count.

    Args:
      model_name: name forwarded to `efficientnet_builder.build_model`.
      input_size: height and width of the single all-zeros input image.
      expected_params: expected total number of elements summed over all
        variables in `tf.trainable_variables()` after building.
      override_params: forwarded to `build_model`.
      features_only: forwarded to `build_model`.
      pooled_features_only: forwarded to `build_model`.
    """
    dummy_images = tf.zeros((1, input_size, input_size, 3), dtype=tf.float32)
    build_kwargs = dict(
        model_name=model_name,
        override_params=override_params,
        training=True,
        features_only=features_only,
        pooled_features_only=pooled_features_only)
    efficientnet_builder.build_model(dummy_images, **build_kwargs)
    per_variable_counts = [
        np.prod(variable.shape) for variable in tf.trainable_variables()
    ]
    self.assertEqual(np.sum(per_variable_counts), expected_params)
Example #4
Source File: agent.py    From Hierarchical-Actor-Critc-HAC- with MIT License 6 votes vote down vote up
def initialize_networks(self):
        """Create the saver, prepare the model directory and set up weights.

        Builds a `tf.train.Saver` over the trainable variables, makes sure
        `<cwd>/models` exists, runs the global variable initializer and,
        unless `self.FLAGS.retrain` is truthy, restores the latest checkpoint
        found in that directory.

        Raises:
          ValueError: if restoring was requested but the model directory
            contains no checkpoint.
        """
        model_vars = tf.trainable_variables()
        self.saver = tf.train.Saver(model_vars)

        # Set up directory for saving models
        self.model_dir = os.path.join(os.getcwd(), 'models')
        self.model_loc = os.path.join(self.model_dir, 'HAC.ckpt')
        # exist_ok avoids a race between an existence check and creation.
        os.makedirs(self.model_dir, exist_ok=True)

        # Initialize actor/critic networks
        self.sess.run(tf.global_variables_initializer())

        # If not retraining from scratch, restore previously saved weights.
        if not self.FLAGS.retrain:
            checkpoint = tf.train.latest_checkpoint(self.model_dir)
            if checkpoint is None:
                raise ValueError(
                    'No checkpoint found in %s; enable retrain to train '
                    'from scratch.' % self.model_dir)
            self.saver.restore(self.sess, checkpoint)


    # Save neural network parameters 
Example #5
Source File: agent.py    From Hierarchical-Actor-Critc-HAC- with MIT License 6 votes vote down vote up
def initialize_networks(self):
        """Create the saver, prepare the model directory and set up weights.

        Builds a `tf.train.Saver` over the trainable variables, makes sure
        `<cwd>/models` exists, runs the global variable initializer and,
        unless `self.FLAGS.retrain` is truthy, restores the latest checkpoint
        found in that directory.

        Raises:
          ValueError: if restoring was requested but the model directory
            contains no checkpoint.
        """
        model_vars = tf.trainable_variables()
        self.saver = tf.train.Saver(model_vars)

        # Set up directory for saving models
        self.model_dir = os.path.join(os.getcwd(), 'models')
        self.model_loc = os.path.join(self.model_dir, 'HAC.ckpt')
        # exist_ok avoids a race between an existence check and creation.
        os.makedirs(self.model_dir, exist_ok=True)

        # Initialize actor/critic networks
        self.sess.run(tf.global_variables_initializer())

        # If not retraining from scratch, restore previously saved weights.
        if not self.FLAGS.retrain:
            checkpoint = tf.train.latest_checkpoint(self.model_dir)
            if checkpoint is None:
                raise ValueError(
                    'No checkpoint found in %s; enable retrain to train '
                    'from scratch.' % self.model_dir)
            self.saver.restore(self.sess, checkpoint)


    # Save neural network parameters 
Example #6
Source File: rnn_test.py    From magenta with Apache License 2.0 6 votes vote down vote up
def testCompatibleNames(self):
    """LSTMBlockCell must create the same variable names/shapes as LSTMCell."""

    def _collect_variable_shapes(make_cell, make_peephole_cell):
      # Build a basic and a peephole static RNN in a fresh graph and map each
      # trainable variable's name to its static shape.
      with self.session(use_gpu=True, graph=tf.Graph()):
        plain_cell = make_cell()
        peephole_cell = make_peephole_cell()
        step_inputs = [tf.zeros([4, 5])] * 6
        for scope_name, cell_obj in (("basic", plain_cell),
                                     ("peephole", peephole_cell)):
          tf.nn.static_rnn(cell_obj, step_inputs, dtype=tf.float32,
                           scope=scope_name)
        return {var.name: var.get_shape() for var in tf.trainable_variables()}

    basic_names = _collect_variable_shapes(
        lambda: rnn_cell.LSTMCell(10),
        lambda: rnn_cell.LSTMCell(10, use_peepholes=True))
    block_names = _collect_variable_shapes(
        lambda: contrib_rnn.LSTMBlockCell(10),
        lambda: contrib_rnn.LSTMBlockCell(10, use_peephole=True))
    self.assertEqual(basic_names, block_names)
Example #7
Source File: agent.py    From Hierarchical-Actor-Critc-HAC- with MIT License 6 votes vote down vote up
def initialize_networks(self):
        """Create the saver, prepare the model directory and set up weights.

        Builds a `tf.train.Saver` over the trainable variables, makes sure
        `<cwd>/models` exists, runs the global variable initializer and,
        unless `self.FLAGS.retrain` is truthy, restores the latest checkpoint
        found in that directory.

        Raises:
          ValueError: if restoring was requested but the model directory
            contains no checkpoint.
        """
        model_vars = tf.trainable_variables()
        self.saver = tf.train.Saver(model_vars)

        # Set up directory for saving models
        self.model_dir = os.path.join(os.getcwd(), 'models')
        self.model_loc = os.path.join(self.model_dir, 'HAC.ckpt')
        # exist_ok avoids a race between an existence check and creation.
        os.makedirs(self.model_dir, exist_ok=True)

        # Initialize actor/critic networks
        self.sess.run(tf.global_variables_initializer())

        # If not retraining from scratch, restore previously saved weights.
        if not self.FLAGS.retrain:
            checkpoint = tf.train.latest_checkpoint(self.model_dir)
            if checkpoint is None:
                raise ValueError(
                    'No checkpoint found in %s; enable retrain to train '
                    'from scratch.' % self.model_dir)
            self.saver.restore(self.sess, checkpoint)


    # Save neural network parameters 
Example #8
Source File: agent.py    From Hierarchical-Actor-Critc-HAC- with MIT License 6 votes vote down vote up
def initialize_networks(self):
        """Create the saver, prepare the model directory and set up weights.

        Builds a `tf.train.Saver` over the trainable variables, makes sure
        `<cwd>/models` exists, runs the global variable initializer and,
        unless `self.FLAGS.retrain` is truthy, restores the latest checkpoint
        found in that directory.

        Raises:
          ValueError: if restoring was requested but the model directory
            contains no checkpoint.
        """
        model_vars = tf.trainable_variables()
        self.saver = tf.train.Saver(model_vars)

        # Set up directory for saving models
        self.model_dir = os.path.join(os.getcwd(), 'models')
        self.model_loc = os.path.join(self.model_dir, 'HAC.ckpt')
        # exist_ok avoids a race between an existence check and creation.
        os.makedirs(self.model_dir, exist_ok=True)

        # Initialize actor/critic networks
        self.sess.run(tf.global_variables_initializer())

        # If not retraining from scratch, restore previously saved weights.
        if not self.FLAGS.retrain:
            checkpoint = tf.train.latest_checkpoint(self.model_dir)
            if checkpoint is None:
                raise ValueError(
                    'No checkpoint found in %s; enable retrain to train '
                    'from scratch.' % self.model_dir)
            self.saver.restore(self.sess, checkpoint)


    # Save neural network parameters 
Example #9
Source File: optimization_test.py    From albert with Apache License 2.0 6 votes vote down vote up
def test_adam(self):
    """AdamWeightDecayOptimizer should drive `w` onto the constant target."""
    with self.test_session() as sess:
      w = tf.get_variable(
          "w",
          shape=[3],
          initializer=tf.constant_initializer([0.1, -0.2, -0.1]))
      target = tf.constant([0.4, 0.2, -0.5])
      # Mean squared error between the variable and the target.
      loss = tf.reduce_mean(tf.square(target - w))
      trainable = tf.trainable_variables()
      grads_and_vars = list(zip(tf.gradients(loss, trainable), trainable))
      global_step = tf.train.get_or_create_global_step()
      optimizer = optimization.AdamWeightDecayOptimizer(learning_rate=0.2)
      train_op = optimizer.apply_gradients(grads_and_vars, global_step)
      sess.run(tf.group(tf.global_variables_initializer(),
                        tf.local_variables_initializer()))
      for _ in range(100):
        sess.run(train_op)
      self.assertAllClose(sess.run(w).flat, [0.4, 0.2, -0.5],
                          rtol=1e-2, atol=1e-2)
Example #10
Source File: agent.py    From Hierarchical-Actor-Critc-HAC- with MIT License 6 votes vote down vote up
def initialize_networks(self):
        """Create the saver, prepare the model directory and set up weights.

        Builds a `tf.train.Saver` over the trainable variables, makes sure
        `<cwd>/models` exists, runs the global variable initializer and,
        unless `self.FLAGS.retrain` is truthy, restores the latest checkpoint
        found in that directory.

        Raises:
          ValueError: if restoring was requested but the model directory
            contains no checkpoint.
        """
        model_vars = tf.trainable_variables()
        self.saver = tf.train.Saver(model_vars)

        # Set up directory for saving models
        self.model_dir = os.path.join(os.getcwd(), 'models')
        self.model_loc = os.path.join(self.model_dir, 'HAC.ckpt')
        # exist_ok avoids a race between an existence check and creation.
        os.makedirs(self.model_dir, exist_ok=True)

        # Initialize actor/critic networks
        self.sess.run(tf.global_variables_initializer())

        # If not retraining from scratch, restore previously saved weights.
        if not self.FLAGS.retrain:
            checkpoint = tf.train.latest_checkpoint(self.model_dir)
            if checkpoint is None:
                raise ValueError(
                    'No checkpoint found in %s; enable retrain to train '
                    'from scratch.' % self.model_dir)
            self.saver.restore(self.sess, checkpoint)


    # Save neural network parameters 
Example #11
Source File: maml_inner_loop_test.py    From tensor2robot with Apache License 2.0 6 votes vote down vote up
def test_inner_loop_reuse(self, learn_inner_lr):
    """The inner loop creates exactly one network's worth of trainable vars.

    The 'inner_loop' scope should hold as many trainable variables as a
    direct call to `inference_network_fn` creates. Learned learning rates and
    learned loss variables do not adapt, so they are expected to live
    *outside* that scope.
    """
    with tf.Session(graph=tf.Graph()):
      inputs = create_inputs()
      features, _ = inputs

      # Reference: trainable vars created by one plain network call.
      with tf.variable_scope('test_scope'):
        inference_network_fn(features)
      reference_count = len(tf.trainable_variables(scope='test_scope'))

      inner_loop = maml_inner_loop.MAMLInnerLoopGradientDescent(
          learning_rate=LEARNING_RATE, learn_inner_lr=learn_inner_lr)
      inner_loop.inner_loop([inputs] * 3, inference_network_fn,
                            learned_model_train_fn)
      self.assertEqual(reference_count,
                       len(tf.trainable_variables(scope='inner_loop')))
Example #12
Source File: post_training_quantization.py    From models with Apache License 2.0 6 votes vote down vote up
def restore_model(sess, checkpoint_path, enable_ema=True):
  """Initialize all variables, then restore them from a checkpoint.

  Args:
    sess: A tensorflow session where the checkpoint will be loaded.
    checkpoint_path: Path to the trained checkpoint.
    enable_ema: (optional) When True, the trainable variables, the
      "moving_vars" collection and any global variable whose name contains
      "moving_mean" or "moving_variance" are restored through the name map
      produced by `ExponentialMovingAverage.variables_to_restore`, i.e. the
      exponential moving average (ema) versions are loaded. When False, the
      Saver is built with its default variable list. Defaults to True.
  """
  var_dict = None
  if enable_ema:
    ema = tf.train.ExponentialMovingAverage(decay=0.0)
    candidates = tf.trainable_variables() + tf.get_collection("moving_vars")
    candidates.extend(
        v for v in tf.global_variables()
        if "moving_mean" in v.name or "moving_variance" in v.name)
    var_dict = ema.variables_to_restore(list(set(candidates)))

  sess.run(tf.global_variables_initializer())
  saver = tf.train.Saver(var_dict, max_to_keep=1)
  saver.restore(sess, checkpoint_path)
Example #13
Source File: optimize.py    From tensor2tensor with Apache License 2.0 6 votes vote down vote up
def summarize_variables(var_list=None, tag=None):
  """Emit a histogram summary for each variable.

  Args:
    var_list: a list of variables; defaults to trainable_variables.
    tag: prefix prepended to each variable name to form the summary name;
      defaults to "training_variables/".
  """
  variables = tf.trainable_variables() if var_list is None else var_list
  prefix = "training_variables/" if tag is None else tag

  # Keyed by name, so variables sharing a name yield a single histogram
  # (the last one with that name wins).
  by_name = {v.name: v for v in variables}
  for name, variable in by_name.items():
    tf.summary.histogram(prefix + name, variable)
Example #14
Source File: model_deploy_test.py    From models with Apache License 2.0 6 votes vote down vote up
def testCreateSingleclone(self):
    """One BatchNormClassifier clone: 5 variables, 2 update ops, GPU grads."""
    graph = tf.Graph()
    with graph.as_default():
      tf.set_random_seed(0)
      tf_inputs = tf.constant(self._inputs, dtype=tf.float32)
      tf_labels = tf.constant(self._labels, dtype=tf.float32)

      model_fn = BatchNormClassifier
      clone_args = (tf_inputs, tf_labels)
      deploy_config = model_deploy.DeploymentConfig(num_clones=1)

      self.assertEqual(slim.get_variables(), [])
      clones = model_deploy.create_clones(deploy_config, model_fn, clone_args)
      self.assertEqual(len(slim.get_variables()), 5)
      update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
      self.assertEqual(len(update_ops), 2)

      optimizer = tf.train.GradientDescentOptimizer(learning_rate=1.0)
      total_loss, grads_and_vars = model_deploy.optimize_clones(clones,
                                                                optimizer)
      self.assertEqual(len(grads_and_vars), len(tf.trainable_variables()))
      self.assertEqual(total_loss.op.name, 'total_loss')
      # Distinct loop names so the graph handle is not rebound.
      for grad, var in grads_and_vars:
        self.assertDeviceEqual(grad.device, 'GPU:0')
        self.assertDeviceEqual(var.device, 'CPU:0')
Example #15
Source File: variable_mgr.py    From benchmarks with Apache License 2.0 6 votes vote down vote up
def trainable_variables_on_device(self,
                                    rel_device_num,
                                    abs_device_num,
                                    writable=False):
    """Return the trainable variables that belong to one device.

    Args:
      rel_device_num: local worker device index (ignored here).
      abs_device_num: global graph device index; when every tower has its own
        variables, only those whose name starts with 'v<abs_device_num>/' are
        returned.
      writable: whether to get a reference to the underlying variable
        (ignored here).

    Returns:
      A list of trainable variables.
    """
    del rel_device_num, writable  # Unused by this implementation.
    if not self.each_tower_has_variables():
      return tf.trainable_variables()
    tower_prefix = 'v%s/' % abs_device_num
    return [v for v in tf.trainable_variables()
            if v.name.startswith(tower_prefix)]
Example #16
Source File: optimize.py    From tensor2tensor with Apache License 2.0 6 votes vote down vote up
def weight_decay_and_noise(loss, hparams, learning_rate, var_list=None):
  """Apply weight decay and weight noise.

  Args:
    loss: loss tensor to which the weight-decay term is added.
    hparams: hyperparameters; `weight_decay` and `weight_noise` are read.
    learning_rate: forwarded to `weight_noise`.
    var_list: iterable of variables; defaults to `tf.trainable_variables()`.
      One-shot iterables such as generators are supported.

  Returns:
    `loss`, wrapped in an identity that depends on the weight-noise ops,
    plus the weight-decay loss.
  """
  # Materialize once so a one-shot iterable feeds both selections below
  # instead of being exhausted by the first one.
  if var_list is None:
    var_list = tf.trainable_variables()
  else:
    var_list = list(var_list)

  decay_vars = list(var_list)
  # Only variables under a "/body/" scope receive weight noise.
  noise_vars = [v for v in var_list if "/body/" in v.name]

  weight_decay_loss = weight_decay(hparams.weight_decay, decay_vars)
  if hparams.weight_decay and common_layers.should_generate_summaries():
    tf.summary.scalar("losses/weight_decay", weight_decay_loss)
  weight_noise_ops = weight_noise(hparams.weight_noise, learning_rate,
                                  noise_vars)

  # Ensure the noise ops run before the loss is consumed.
  with tf.control_dependencies(weight_noise_ops):
    loss = tf.identity(loss)

  loss += weight_decay_loss
  return loss
Example #17
Source File: training.py    From lamb with Apache License 2.0 6 votes vote down vote up
def _load_checkpoint(checkpoint_filename, extra_vars, trainable_only=False):
  """Restore saveables into the default session from a checkpoint.

  Args:
    checkpoint_filename: checkpoint path, or a directory, in which case its
      latest checkpoint is used.
    extra_vars: saveables dropped on the fallback attempt if the first restore
      fails with ValueError or NotFoundError (the log message indicates this
      targets an older checkpoint format).
    trainable_only: if True, only trainable variables are restored.
  """
  if tf.gfile.IsDirectory(checkpoint_filename):
    checkpoint_filename = tf.train.latest_checkpoint(checkpoint_filename)
  logging.info('Loading checkpoint %s', checkpoint_filename)
  saveables = (tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES) +
               tf.get_collection(tf.GraphKeys.SAVEABLE_OBJECTS))
  if trainable_only:
    trainable = set(tf.trainable_variables())
    # Deduplicate and filter in collection order (rather than through set
    # operations) so the Saver's var_list, and the restore ops built from it,
    # are identical from run to run.
    saveables = [s for s in dict.fromkeys(saveables) if s in trainable]
  # Try to restore all saveables, if that fails try without extra_vars.
  try:
    saver = tf.train.Saver(var_list=saveables)
    saver.restore(tf.get_default_session(), checkpoint_filename)
  except (ValueError, tf.errors.NotFoundError):
    logging.info('Missing key in checkpoint. Trying old checkpoint format.')
    excluded = set(extra_vars)
    fallback = [s for s in dict.fromkeys(saveables) if s not in excluded]
    saver = tf.train.Saver(var_list=fallback)
    saver.restore(tf.get_default_session(), checkpoint_filename)
Example #18
Source File: model_deploy_test.py    From models with Apache License 2.0 6 votes vote down vote up
def testCreateLogisticClassifier(self):
    """One LogisticClassifier clone: 2 variables, no update ops, GPU grads."""
    graph = tf.Graph()
    with graph.as_default():
      tf.set_random_seed(0)
      tf_inputs = tf.constant(self._inputs, dtype=tf.float32)
      tf_labels = tf.constant(self._labels, dtype=tf.float32)

      model_fn = LogisticClassifier
      clone_args = (tf_inputs, tf_labels)
      deploy_config = model_deploy.DeploymentConfig(num_clones=1)

      self.assertEqual(slim.get_variables(), [])
      clones = model_deploy.create_clones(deploy_config, model_fn, clone_args)
      self.assertEqual(len(slim.get_variables()), 2)
      update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
      self.assertEqual(update_ops, [])

      optimizer = tf.train.GradientDescentOptimizer(learning_rate=1.0)
      total_loss, grads_and_vars = model_deploy.optimize_clones(clones,
                                                                optimizer)
      self.assertEqual(len(grads_and_vars), len(tf.trainable_variables()))
      self.assertEqual(total_loss.op.name, 'total_loss')
      # Distinct loop names so the graph handle is not rebound.
      for grad, var in grads_and_vars:
        self.assertDeviceEqual(grad.device, 'GPU:0')
        self.assertDeviceEqual(var.device, 'CPU:0')
Example #19
Source File: variable_mgr.py    From benchmarks with Apache License 2.0 6 votes vote down vote up
def savable_variables(self):
    """Returns a dict of checkpoint name -> variable to pass to tf.train.Saver.

    Global variables are keyed by name with the shadow-variable prefix (if
    present) and the ':N' port removed. Non-trainable local variables under
    'v0/' are added keyed by name with the port removed.
    """
    params = {}
    for v in tf.global_variables():
      assert (v.name.startswith(variable_mgr_util.PS_SHADOW_VAR_PREFIX + '/v0/')
              or v.name in ('global_step:0', 'loss_scale:0',
                            'loss_scale_normal_steps:0')), (
                                'Invalid global variable: %s' % v)
      # We store variables in the checkpoint with the shadow variable prefix
      # removed so we can evaluate checkpoints in non-distributed replicated
      # mode. The checkpoints can also be loaded for training in
      # distributed_replicated mode.
      name = self._strip_port(self._remove_shadow_var_prefix_if_present(v.name))
      params[name] = v
    # Built once so the membership test below is O(1) per local variable
    # instead of rebuilding and scanning the trainable list every iteration.
    trainable = set(tf.trainable_variables())
    for v in tf.local_variables():
      # Non-trainable variables, such as batch norm moving averages, do not have
      # corresponding global shadow variables, so we add them here. Trainable
      # local variables have corresponding global shadow variables, which were
      # added in the global variable loop above.
      if v.name.startswith('v0/') and v not in trainable:
        params[self._strip_port(v.name)] = v
    return params
Example #20
Source File: utils.py    From lamb with Apache License 2.0 6 votes vote down vote up
def find_var(name, vars_=None):
  """Return the first variable whose `.name` equals `name`, else None.

  Args:
    name: fully qualified variable name, including all enclosing scopes; it
      is compared for equality with each candidate's `.name`.
    vars_: iterable of candidates. Only when it is None are the candidates
      taken from `tf.trainable_variables()` (an empty list stays empty).

  Returns:
    The first matching candidate, or None if there is no match.
  """
  candidates = tf.trainable_variables() if vars_ is None else vars_
  for candidate in candidates:
    if candidate.name == name:
      return candidate
  return None
Example #21
Source File: train_image_classifier.py    From models with Apache License 2.0 6 votes vote down vote up
def _get_variables_to_train():
  """Returns a list of variables to train.

  Reads the comma-separated `FLAGS.trainable_scopes`. If it is None, or
  contains no non-blank scope names, all trainable variables are returned.
  Otherwise the trainable variables matching each listed scope are returned
  in flag order, each at most once.

  Returns:
    A list of variables to train by the optimizer.
  """
  if FLAGS.trainable_scopes is None:
    return tf.trainable_variables()

  # Drop blank entries (e.g. from a trailing comma): an empty scope filter
  # matches every variable and would silently re-add the whole model.
  scopes = [scope.strip() for scope in FLAGS.trainable_scopes.split(',')
            if scope.strip()]
  if not scopes:
    return tf.trainable_variables()

  variables_to_train = []
  for scope in scopes:
    variables = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope)
    variables_to_train.extend(variables)
  # Overlapping scopes can select the same variable twice; keep the first.
  return list(dict.fromkeys(variables_to_train))
Example #22
Source File: discriminative_eval.py    From language with Apache License 2.0 5 votes vote down vote up
def _restore_checkpoint(init_checkpoint):
  """Initialize trainable variables from `init_checkpoint`.

  The assignment map comes from `modeling.get_assignment_map_from_checkpoint`
  (whose second return value is ignored) and is handed to
  `tf.train.init_from_checkpoint`.
  """
  trainable = tf.trainable_variables()
  assignment_map, _unused = modeling.get_assignment_map_from_checkpoint(
      trainable, init_checkpoint)
  tf.train.init_from_checkpoint(init_checkpoint, assignment_map)
Example #23
Source File: pnasnet_test.py    From models with Apache License 2.0 5 votes vote down vote up
def testBuildNonExistingLayerMobileModel(self):
    """Tests that the model is built correctly without unnecessary layers."""
    images = tf.random.uniform((5, 224, 224, 3))
    tf.train.create_global_step()
    with slim.arg_scope(pnasnet.pnasnet_mobile_arg_scope()):
      pnasnet.build_pnasnet_mobile(images, 1000)
    trainable_op_names = [var.op.name for var in tf.trainable_variables()]
    # The stem's 1x1 conv must exist; this cell_stem_1 branch must not.
    self.assertIn('cell_stem_0/1x1/weights', trainable_op_names)
    self.assertNotIn('cell_stem_1/comb_iter_0/right/1x1/weights',
                     trainable_op_names)
Example #24
Source File: model_deploy_test.py    From models with Apache License 2.0 5 votes vote down vote up
def testCreateMulticloneCPU(self):
    """Four CPU clones share 5 variables; each clone adds 2 update ops."""
    graph = tf.Graph()
    with graph.as_default():
      tf.set_random_seed(0)
      tf_inputs = tf.constant(self._inputs, dtype=tf.float32)
      tf_labels = tf.constant(self._labels, dtype=tf.float32)

      model_fn = BatchNormClassifier
      model_args = (tf_inputs, tf_labels)
      num_clones = 4
      deploy_config = model_deploy.DeploymentConfig(num_clones=num_clones,
                                                    clone_on_cpu=True)

      self.assertEqual(slim.get_variables(), [])
      clones = model_deploy.create_clones(deploy_config, model_fn, model_args)
      self.assertEqual(len(slim.get_variables()), 5)
      update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
      self.assertEqual(len(update_ops), num_clones * 2)

      optimizer = tf.train.GradientDescentOptimizer(learning_rate=1.0)
      total_loss, grads_and_vars = model_deploy.optimize_clones(clones,
                                                                optimizer)
      self.assertEqual(len(grads_and_vars), len(tf.trainable_variables()))
      self.assertEqual(total_loss.op.name, 'total_loss')
      # Distinct loop names so the graph handle is not rebound.
      for grad, var in grads_and_vars:
        self.assertDeviceEqual(grad.device, '')
        self.assertDeviceEqual(var.device, 'CPU:0')
Example #25
Source File: all_reduce_benchmark.py    From benchmarks with Apache License 2.0 5 votes vote down vote up
def get_var_shapes(model):
  """Returns the list of variable shapes for a tf_cnn_benchmarks Model.

  The network is built in a throwaway graph; each entry is the list of
  integer dimensions of one trainable variable.
  """
  with tf.Graph().as_default():
    # Variable shapes do not depend on the batch size, so the training input
    # shape is used as-is for the placeholder.
    input_shape = model.get_input_shapes('train')[0]
    model.build_network([tf.placeholder(tf.float32, input_shape)])
    shapes = []
    for variable in tf.trainable_variables():
      shapes.append([int(dim) for dim in variable.shape.dims])
    return shapes
Example #26
Source File: model_executor.py    From mesh with Apache License 2.0 5 votes vote down vote up
def _print_variable_values(sess):
  """Log name, shape and flattened value of every trainable variable.

  Also changes numpy's global print options (precision 4, line width 1000).
  May give `Protocol buffer too large` error.
  """
  np.set_printoptions(precision=4, linewidth=1000)
  tf.logging.info('Printing variables.')
  tf.logging.info('===================')
  trainable = tf.trainable_variables()
  fetched = sess.run(trainable)
  for variable, value in zip(trainable, fetched):
    tf.logging.info('{}, {}'.format(variable.name, value.shape))
    tf.logging.info('{}'.format(np.array(value).flatten()))
Example #27
Source File: run_recurrent_model_boolq.py    From language with Apache License 2.0 5 votes vote down vote up
def evaluate():
  """Evaluate a model on the dev set and print its accuracy.

  Restores trainable variables from `FLAGS.checkpoint_file` when it is set
  (otherwise runs the global initializer, i.e. random weights), streams the
  dev dataset through `tf.metrics.accuracy`, and prints the final value.
  The session is closed on exit, including on error.
  """
  sess = tf.Session()
  try:
    tf.logging.info("Building graph...")

    embeddings = load_embeddings()
    tf_data = load_batched_dataset(False, embeddings)
    it = tf_data.make_initializable_iterator()
    features, labels = it.get_next()

    logits = predict(False, embeddings, features["premise"],
                     features["hypothesis"])
    accuracy, update_ops = tf.metrics.accuracy(
        tf.argmax(logits, 1, output_type=tf.int32), tf.to_int32(labels))

    tf.logging.info("Running initializers...")
    checkpoint_file = FLAGS.checkpoint_file
    if checkpoint_file is not None:
      # NOTE(review): only trainable variables are restored and the global
      # initializer is not run on this path; any non-trainable global
      # variables would stay uninitialized -- confirm the model has none.
      saver = tf.train.Saver(tf.trainable_variables())
      tf.logging.info("Restoring from checkpoint: %s", checkpoint_file)
      saver.restore(sess, checkpoint_file)
    else:
      tf.logging.warning(
          "No checkpoint given, evaling model with random weights")
      sess.run(tf.global_variables_initializer())
    sess.run(tf.local_variables_initializer())
    sess.run(tf.tables_initializer())
    sess.run(it.initializer)

    tf.logging.info("Starting loop....")
    # Accumulate the streaming metric until the iterator is exhausted.
    while True:
      try:
        sess.run(update_ops)
      except tf.errors.OutOfRangeError:
        break
    tf.logging.info("Done")

    accuracy_value = sess.run(accuracy)
    print("Accuracy: %f" % accuracy_value)
  finally:
    # Release the session's resources even if evaluation fails.
    sess.close()
Example #28
Source File: search_utils.py    From language with Apache License 2.0 5 votes vote down vote up
def init_from_checkpoint(checkpoint_path,
                         checkpoint_prefix=None,
                         variable_prefix=None,
                         target_variables=None):
  """Initializes variables from the checkpoint at `checkpoint_path`.

  Names are matched after optional prefixing: each checkpoint variable name
  becomes `checkpoint_prefix + "/" + name`, and each target variable name,
  stripped of its ":N" suffix, becomes `variable_prefix + "/" + name`.
  Matching pairs are passed to `tf.train.init_from_checkpoint`; loaded,
  uninitialized and unused names are logged.

  Args:
    checkpoint_path: checkpoint passed to `tf.train.list_variables` and
      `tf.train.init_from_checkpoint`.
    checkpoint_prefix: optional prefix added to checkpoint names for matching.
    variable_prefix: optional prefix added to target names for matching.
    target_variables: variables to initialize; defaults to
      `tf.trainable_variables()`.
  """
  tf.logging.info("Loading variables from %s", checkpoint_path)
  ckpt_prefix = "" if checkpoint_prefix is None else checkpoint_prefix + "/"
  # Matching key -> name as stored in the checkpoint.
  checkpoint_names = {
      ckpt_prefix + name: name
      for name, _ in tf.train.list_variables(checkpoint_path)
  }

  if target_variables is None:
    target_variables = tf.trainable_variables()
  var_prefix = "" if variable_prefix is None else variable_prefix + "/"
  # Matching key -> variable (later duplicates of a key win).
  targets = {
      var_prefix + var.name.split(":")[0]: var for var in target_variables
  }

  checkpoint_keys = set(checkpoint_names)
  target_keys = set(targets)
  matched = target_keys & checkpoint_keys
  tf.train.init_from_checkpoint(
      checkpoint_path, {checkpoint_names[key]: targets[key] for key in matched})

  log_variables("Loaded variables", matched)
  log_variables("Uninitialized variables", target_keys - checkpoint_keys)
  log_variables("Unused variables", checkpoint_keys - target_keys)
Example #29
Source File: load_from_checkpoint.py    From language with Apache License 2.0 5 votes vote down vote up
def init_model_from_checkpoint(checkpoint_dir,
                               use_tpu=False,
                               checkpoint_file=None,
                               reinitialize_type_embeddings=False):
  """Initializes trainable variables from a pretrained checkpoint.

  Args:
    checkpoint_dir: Path to the checkpoint dir.
    use_tpu: Whether to use TPU to train. If True, initialization is deferred
      to the returned scaffold function; otherwise it happens immediately.
    checkpoint_file: Name of the checkpoint file inside `checkpoint_dir`. If
      empty or None, the latest checkpoint in `checkpoint_dir` is used.
    reinitialize_type_embeddings: Whether to re-initialize the type embeddings
      used in the BERT model (forwarded to the assignment-map helper).

  Returns:
    A tuple `(initialized_variable_names, scaffold_fn)`: the second value
    returned by `_get_assignment_map_from_checkpoint`, and a function that
    performs the initialization and returns a `tf.train.Scaffold` when
    `use_tpu` is True (None otherwise).

  Raises:
    ValueError: if no `checkpoint_file` is given and `checkpoint_dir` holds no
      checkpoint.
  """
  if checkpoint_file:
    checkpoint_path = os.path.join(checkpoint_dir, checkpoint_file)
  else:
    checkpoint_path = tf.train.latest_checkpoint(checkpoint_dir)
    if checkpoint_path is None:
      # Fail early with a clear message instead of an obscure error later
      # (on TPU, only once the scaffold function runs).
      raise ValueError("No checkpoint found in %s" % checkpoint_dir)
  (assignment_map,
   initialized_variable_names) = _get_assignment_map_from_checkpoint(
       tf.trainable_variables(), checkpoint_path, reinitialize_type_embeddings)
  tf.logging.info("Pretrained parameter assignment_map: %s", assignment_map)
  scaffold_fn = None
  # We have to pass scaffold_fn to TPUEstimatorSpec for initing checkpoint from
  # Bert otherwise it will fail to init TPU system.
  if use_tpu:

    def tpu_scaffold():
      tf.train.init_from_checkpoint(checkpoint_path, assignment_map)
      return tf.train.Scaffold()

    scaffold_fn = tpu_scaffold
  else:
    tf.train.init_from_checkpoint(checkpoint_path, assignment_map)
  return initialized_variable_names, scaffold_fn
Example #30
Source File: coherence_eval.py    From language with Apache License 2.0 5 votes vote down vote up
def _restore_checkpoint(init_checkpoint):
  """Initialize trainable variables from `init_checkpoint`.

  The assignment map comes from `modeling.get_assignment_map_from_checkpoint`
  (whose second return value is ignored) and is handed to
  `tf.train.init_from_checkpoint`.
  """
  trainable = tf.trainable_variables()
  assignment_map, _unused = modeling.get_assignment_map_from_checkpoint(
      trainable, init_checkpoint)
  tf.train.init_from_checkpoint(init_checkpoint, assignment_map)