Python tensorflow.contrib.slim.assign_from_checkpoint_fn() Examples

The following are 26 code examples of tensorflow.contrib.slim.assign_from_checkpoint_fn(), drawn from open-source projects. The line above each example names its source file, original project, and license. You may also want to check out the other available functions and classes of the tensorflow.contrib.slim module.
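Before the project examples, here is a minimal sketch of the pattern they all share: build the graph, ask slim for a restore callable, then invoke that callable on a live session. The checkpoint path, scope name, and the one-layer model below are placeholders for this sketch, not taken from any project on this page.

import tensorflow as tf
import tensorflow.contrib.slim as slim

# Build some model first; assign_from_checkpoint_fn only returns a callable
# that restores the listed variables into whatever session you hand it.
images = tf.placeholder(tf.float32, [None, 224, 224, 3])
net = slim.conv2d(images, 8, [3, 3], scope='my_model/conv1')  # placeholder model

init_fn = slim.assign_from_checkpoint_fn(
    'my_model.ckpt',                        # checkpoint file or prefix (placeholder)
    slim.get_model_variables('my_model'),   # variables to restore
    ignore_missing_vars=False)              # fail loudly on name mismatches

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())  # initialize everything first
    init_fn(sess)                                # then overwrite from the checkpoint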
Example #1
Source File: model_wrappers.py    From nips-2017-adversarial with MIT License
def load_ckpt(ckpt_name, var_scope_name, scope, constructor, input_tensor, label_offset, load_weights, **kwargs):
    """ 
    Arguments
        ckpt_name       file name of the checkpoint
        var_scope_name  name of the variable scope
        scope           arg_scope
        constructor     constructor of the model
        input_tensor    tensor of input image
        label_offset    whether it is 1000 classes or 1001 classes, if it is 1001, remove class 0
        load_weights    whether to load weights
        kwargs 
            is_training 
            create_aux_logits 
    """
    with slim.arg_scope(scope):
        logits, endpoints = constructor(
            input_tensor, num_classes=1000 + label_offset,
            scope=var_scope_name, **kwargs)
    if load_weights:
        init_fn = slim.assign_from_checkpoint_fn(
            ckpt_name, slim.get_model_variables(var_scope_name))
        init_fn(K.get_session())
    return logits, endpoints
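Note that K here is presumably the Keras backend (imported elsewhere in the file, e.g. from keras import backend as K), so init_fn(K.get_session()) restores the checkpointed weights into the same session that Keras itself uses.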
Example #2
Source File: pretrained.py    From SSD_tensorflow_VOC with Apache License 2.0
def get_init_fn(self, checkpoint_path):
        """Returns a function run by the chief worker to warm-start the training."""
        checkpoint_exclude_scopes = ["InceptionV4/Logits", "InceptionV4/AuxLogits"]
        
        exclusions = [scope.strip() for scope in checkpoint_exclude_scopes]
    
        variables_to_restore = []
        for var in slim.get_model_variables():
            excluded = False
            for exclusion in exclusions:
                if var.op.name.startswith(exclusion):
                    excluded = True
                    break
            if not excluded:
                variables_to_restore.append(var)
    
        return slim.assign_from_checkpoint_fn(
          checkpoint_path,
          variables_to_restore) 
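The manual exclusion loop above can usually be collapsed with slim.get_variables_to_restore, which takes an exclude list of scope prefixes (Example #6 below uses exactly this call). A sketch of the roughly equivalent form, with the caveat that it filters all restorable variables rather than only slim's model variables:

variables_to_restore = slim.get_variables_to_restore(
    exclude=["InceptionV4/Logits", "InceptionV4/AuxLogits"])
return slim.assign_from_checkpoint_fn(checkpoint_path, variables_to_restore)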
Example #3
Source File: im_model.py    From tumblr-emotions with Apache License 2.0
def get_init_fn(checkpoints_dir, model_name='inception_v1.ckpt'):
    """Returns a function run by the chief worker to warm-start the training.
    """
    checkpoint_exclude_scopes = ["InceptionV1/Logits", "InceptionV1/AuxLogits"]
    
    exclusions = [scope.strip() for scope in checkpoint_exclude_scopes]

    variables_to_restore = []
    for var in slim.get_model_variables():
        excluded = False
        for exclusion in exclusions:
            if var.op.name.startswith(exclusion):
                excluded = True
                break
        if not excluded:
            variables_to_restore.append(var)

    return slim.assign_from_checkpoint_fn(
        os.path.join(checkpoints_dir, model_name),
        variables_to_restore) 
Example #4
Source File: train.py    From AAMS with MIT License
def _get_init_fn():
    vgg_checkpoint_path = "vgg_19.ckpt"
    if tf.gfile.IsDirectory(vgg_checkpoint_path):
        checkpoint_path = tf.train.latest_checkpoint(vgg_checkpoint_path)
    else:
        checkpoint_path = vgg_checkpoint_path

    variables_to_restore = []
    for var in slim.get_model_variables():
        tf.logging.info('model_var: %s' % var)
        excluded = False
        for exclusion in ['vgg_19/fc']:
            if var.op.name.startswith(exclusion):
                excluded = True
                tf.logging.info('exclude:%s' % exclusion)
                break
        if not excluded:
            variables_to_restore.append(var)

    tf.logging.info('Fine-tuning from %s' % checkpoint_path)
    return slim.assign_from_checkpoint_fn(
        checkpoint_path,
        variables_to_restore,
        ignore_missing_vars=True) 
Example #5
Source File: tadam.py    From am3 with Apache License 2.0
def __init__(self, model_path, batch_size):
        self.batch_size = batch_size

        latest_checkpoint = tf.train.latest_checkpoint(checkpoint_dir=os.path.join(model_path, 'train'))
        step = int(os.path.basename(latest_checkpoint).split('-')[1])

        flags = Namespace(load_and_save_params(default_params=dict(), exp_dir=model_path))
        image_size = get_image_size(flags.data_dir)

        with tf.Graph().as_default():
            pretrain_images_pl, pretrain_labels_pl = placeholder_inputs(
                batch_size=batch_size, image_size=image_size, scope='inputs/pretrain')
            logits = build_feat_extract_pretrain_graph(pretrain_images_pl, flags, is_training=False)

            self.pretrain_images_pl = pretrain_images_pl
            self.pretrain_labels_pl = pretrain_labels_pl

            init_fn = slim.assign_from_checkpoint_fn(
                latest_checkpoint,
                slim.get_model_variables('Model'))

            config = tf.ConfigProto(allow_soft_placement=True)
            config.gpu_options.allow_growth = True
            self.sess = tf.Session(config=config)

            # Run init before loading the weights
            self.sess.run(tf.global_variables_initializer())
            # Load weights
            init_fn(self.sess)

            self.flags = flags
            self.logits = logits
            self.logits_size = self.logits.get_shape().as_list()[-1]
            self.step = step 
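The ordering in __init__ matters: tf.global_variables_initializer() runs first so that any variables not covered by the checkpoint start with valid values, and init_fn then overwrites only the variables under the 'Model' scope with the checkpointed weights.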
Example #6
Source File: train_utils.py    From mobile-segmentation with Apache License 2.0
def get_model_init_fn(train_logdir,
                      tf_initial_checkpoint,
                      initialize_last_layer,
                      last_layers,
                      ignore_missing_vars=False):
    """Gets the function initializing model variables from a checkpoint.

    Args:
      train_logdir: Log directory for training.
      tf_initial_checkpoint: TensorFlow checkpoint for initialization.
      initialize_last_layer: Initialize last layer or not.
      last_layers: Last layers of the model.
      ignore_missing_vars: Ignore missing variables in the checkpoint.

    Returns:
      Initialization function.
    """
    if tf_initial_checkpoint is None:
        tf.logging.info('Not initializing the model from a checkpoint.')
        return None

    if tf.train.latest_checkpoint(train_logdir):
        tf.logging.info('Ignoring initialization; other checkpoint exists')
        return None

    tf.logging.info('Initializing model from path: %s', tf_initial_checkpoint)

    # Variables that will not be restored.
    exclude_list = ['global_step']
    if not initialize_last_layer:
        exclude_list.extend(last_layers)

    variables_to_restore = slim.get_variables_to_restore(exclude=exclude_list)

    if variables_to_restore:
        return slim.assign_from_checkpoint_fn(
            tf_initial_checkpoint,
            variables_to_restore,
            ignore_missing_vars=ignore_missing_vars)
    return None 
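As a usage sketch of the function above (all paths and scope names below are invented for illustration), fine-tuning while re-initializing the classification head might look like:

init_fn = get_model_init_fn(
    train_logdir='/tmp/deeplab_train',        # placeholder log directory
    tf_initial_checkpoint='xception65.ckpt',  # placeholder checkpoint
    initialize_last_layer=False,              # keep the new task's head random
    last_layers=['logits'],                   # placeholder scope names for the head
    ignore_missing_vars=True)
if init_fn is not None:
    init_fn(sess)  # assumes an existing tf.Session named `sess`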
Example #7
Source File: AM3_TADAM.py    From am3 with Apache License 2.0
def __init__(self, model_path, batch_size):
        self.batch_size = batch_size

        latest_checkpoint = tf.train.latest_checkpoint(checkpoint_dir=os.path.join(model_path, 'train'))
        step = int(os.path.basename(latest_checkpoint).split('-')[1])
        default_params = get_arguments()
        #flags = Namespace(load_and_save_params(vars(default_params), model_path))
        flags = Namespace(load_and_save_params(default_params=dict(), exp_dir=model_path))
        image_size = get_image_size(flags.data_dir)

        with tf.Graph().as_default():
            pretrain_images_pl, pretrain_labels_pl = placeholder_inputs(
                batch_size=batch_size, image_size=image_size, scope='inputs/pretrain')
            logits = build_feat_extract_pretrain_graph(pretrain_images_pl, flags, is_training=False)

            self.pretrain_images_pl = pretrain_images_pl
            self.pretrain_labels_pl = pretrain_labels_pl

            init_fn = slim.assign_from_checkpoint_fn(
                latest_checkpoint,
                slim.get_model_variables('Model'))

            config = tf.ConfigProto(allow_soft_placement=True)
            config.gpu_options.allow_growth = True
            self.sess = tf.Session(config=config)

            # Run init before loading the weights
            self.sess.run(tf.global_variables_initializer())
            # Load weights
            init_fn(self.sess)

            self.flags = flags
            self.logits = logits
            self.logits_size = self.logits.get_shape().as_list()[-1]
            self.step = step 
Example #8
Source File: pretrained.py    From SSD_tensorflow_VOC with Apache License 2.0
def use_vgg16(self):
        
        with tf.Graph().as_default():
            image_size = vgg.vgg_16.default_image_size
            img_path = "../../data/misec_images/First_Student_IC_school_bus_202076.jpg"
            checkpoint_path = "../../data/trained_models/vgg16/vgg_16.ckpt"
            
            image_string = tf.read_file(img_path)
            image = tf.image.decode_jpeg(image_string, channels=3)
            processed_image = vgg_preprocessing.preprocess_image(image, image_size, image_size, is_training=False)
            processed_images  = tf.expand_dims(processed_image, 0)
            
            # Create the model, use the default arg scope to configure the batch norm parameters.
            with slim.arg_scope(vgg.vgg_arg_scope()):
                # 1000 classes instead of 1001.
                logits, _ = vgg.vgg_16(processed_images, num_classes=1000, is_training=False)
                probabilities = tf.nn.softmax(logits)
                
                init_fn = slim.assign_from_checkpoint_fn(
                    checkpoint_path,
                    slim.get_model_variables('vgg_16'))
                
                with tf.Session() as sess:
                    init_fn(sess)
                    np_image, probabilities = sess.run([image, probabilities])
                    probabilities = probabilities[0, 0:]
                    sorted_inds = [i[0] for i in sorted(enumerate(-probabilities), key=lambda x:x[1])]
                    self.disp_names(sorted_inds,probabilities,include_background=False)
                    
                plt.figure()
                plt.imshow(np_image.astype(np.uint8))
                plt.axis('off')
                plt.title(img_path)
                plt.show()
        return 
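The num_classes=1000 comment reflects that the slim VGG checkpoints were exported without an extra background class, whereas the Inception checkpoints (see Example #9) use 1001 classes with class 0 as background; this is the same mismatch that label_offset compensates for in Example #1.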
Example #9
Source File: pretrained.py    From SSD_tensorflow_VOC with Apache License 2.0
def use_inceptionv4(self):
        image_size = inception.inception_v4.default_image_size
        img_path = "../../data/misec_images/EnglishCockerSpaniel_simon.jpg"
        checkpoint_path = "../../data/trained_models/inception_v4/inception_v4.ckpt"

        with tf.Graph().as_default():
           
            image_string = tf.read_file(img_path)
            image = tf.image.decode_jpeg(image_string, channels=3)
            processed_image = inception_preprocessing.preprocess_image(image, image_size, image_size, is_training=False)
            processed_images  = tf.expand_dims(processed_image, 0)
            
            # Create the model, use the default arg scope to configure the batch norm parameters.
            with slim.arg_scope(inception.inception_v4_arg_scope()):
                logits, _ = inception.inception_v4(processed_images, num_classes=1001, is_training=False)
            probabilities = tf.nn.softmax(logits)
            
            init_fn = slim.assign_from_checkpoint_fn(
                checkpoint_path,
                slim.get_model_variables('InceptionV4'))
            
            with tf.Session() as sess:
                init_fn(sess)
                np_image, probabilities = sess.run([image, probabilities])
                probabilities = probabilities[0, 0:]
                sorted_inds = [i[0] for i in sorted(enumerate(-probabilities), key=lambda x:x[1])]
                self.disp_names(sorted_inds,probabilities)
                
            plt.figure()
            plt.imshow(np_image.astype(np.uint8))
            plt.axis('off')
            plt.title(img_path)
            plt.show()
        return
Example #10
Source File: detect_camera.py    From yolo-tf with GNU Lesser General Public License v3.0
def main():
    model = config.get('config', 'model')
    yolo = importlib.import_module('model.' + model)
    width = config.getint(model, 'width')
    height = config.getint(model, 'height')
    preprocess = getattr(importlib.import_module('detect'), args.preprocess)
    with tf.Session() as sess:
        ph_image = tf.placeholder(tf.float32, [1, height, width, 3], name='ph_image')
        builder = yolo.Builder(args, config)
        builder(ph_image)
        global_step = tf.contrib.framework.get_or_create_global_step()
        model_path = tf.train.latest_checkpoint(utils.get_logdir(config))
        tf.logging.info('load ' + model_path)
        slim.assign_from_checkpoint_fn(model_path, tf.global_variables())(sess)
        tf.logging.info('global_step=%d' % sess.run(global_step))
        tensors = [builder.model.conf, builder.model.xy_min, builder.model.xy_max]
        tensors = [tf.check_numerics(t, t.op.name) for t in tensors]
        cap = cv2.VideoCapture(0)
        try:
            while True:
                ret, image_bgr = cap.read()
                assert ret
                image_height, image_width, _ = image_bgr.shape
                scale = [image_width / builder.model.cell_width, image_height / builder.model.cell_height]
                image_rgb = cv2.cvtColor(image_bgr, cv2.COLOR_BGR2RGB)
                image_std = np.expand_dims(preprocess(cv2.resize(image_rgb, (width, height))).astype(np.float32), 0)
                feed_dict = {ph_image: image_std}
                conf, xy_min, xy_max = sess.run(tensors, feed_dict)
                boxes = utils.postprocess.non_max_suppress(conf[0], xy_min[0], xy_max[0], args.threshold, args.threshold_iou)
                for _conf, _xy_min, _xy_max in boxes:
                    index = np.argmax(_conf)
                    if _conf[index] > args.threshold:
                        _xy_min = (_xy_min * scale).astype(np.int)
                        _xy_max = (_xy_max * scale).astype(np.int)
                        cv2.rectangle(image_bgr, tuple(_xy_min), tuple(_xy_max), (255, 0, 255), 3)
                        cv2.putText(image_bgr, builder.names[index] + ' (%.1f%%)' % (_conf[index] * 100), tuple(_xy_min), cv2.FONT_HERSHEY_SIMPLEX, 0.6, (255, 255, 255), 2)
                cv2.imshow('detection', image_bgr)
                cv2.waitKey(1)
        finally:
            cv2.destroyAllWindows()
            cap.release() 
Example #11
Source File: detect.py    From yolo-tf with GNU Lesser General Public License v3.0
def main():
    model = config.get('config', 'model')
    yolo = importlib.import_module('model.' + model)
    width = config.getint(model, 'width')
    height = config.getint(model, 'height')
    with tf.Session() as sess:
        image = tf.placeholder(tf.float32, [1, height, width, 3], name='image')
        builder = yolo.Builder(args, config)
        builder(image)
        global_step = tf.contrib.framework.get_or_create_global_step()
        model_path = tf.train.latest_checkpoint(utils.get_logdir(config))
        tf.logging.info('load ' + model_path)
        slim.assign_from_checkpoint_fn(model_path, tf.global_variables())(sess)
        tf.logging.info('global_step=%d' % sess.run(global_step))
        path = os.path.expanduser(os.path.expandvars(args.path))
        if os.path.isfile(path):
            detect(sess, builder.model, builder.names, image, path)
            plt.show()
        else:
            for dirpath, _, filenames in os.walk(path):
                for filename in filenames:
                    if os.path.splitext(filename)[-1].lower() in args.exts:
                        _path = os.path.join(dirpath, filename)
                        print(_path)
                        detect(sess, builder.model, builder.names, image, _path)
                        plt.show() 
Example #12
Source File: model_wrappers.py    From nips-2017-adversarial with MIT License
def load_ckpt(ckpt_name, var_scope_name, scope, constructor, input_tensor, label_offset, load_weights, **kwargs):
    """ kwargs are is_training and create_aux_logits """
    print(var_scope_name)
    with slim.arg_scope(scope):
        logits, endpoints = constructor(
            input_tensor, num_classes=1000 + label_offset,
            scope=var_scope_name, **kwargs)
    if load_weights:
        init_fn = slim.assign_from_checkpoint_fn(
            ckpt_name, slim.get_model_variables(var_scope_name))
        init_fn(K.get_session())
    return logits, endpoints
Example #13
Source File: model_wrappers.py    From nips-2017-adversarial with MIT License
def load_ckpt(ckpt_name, var_scope_name, scope, constructor, input_tensor, label_offset, load_weights, **kwargs):
    """ kwargs are is_training and create_aux_logits """
    print(var_scope_name)
    with slim.arg_scope(scope):
        logits, endpoints = constructor(
            input_tensor, num_classes=1000 + label_offset,
            scope=var_scope_name, **kwargs)
    if load_weights:
        init_fn = slim.assign_from_checkpoint_fn(
            ckpt_name, slim.get_model_variables(var_scope_name))
        init_fn(K.get_session())
    return logits, endpoints
Example #14
Source File: model_wrappers.py    From nips-2017-adversarial with MIT License
def load_ckpt(ckpt_name, var_scope_name, scope, constructor, input_tensor, label_offset, load_weights, **kwargs):
    """ kwargs are is_training and create_aux_logits """
    print(var_scope_name)
    with slim.arg_scope(scope):
        logits, endpoints = constructor(
            input_tensor, num_classes=1000 + label_offset,
            scope=var_scope_name, **kwargs)
    if load_weights:
        init_fn = slim.assign_from_checkpoint_fn(
            ckpt_name, slim.get_model_variables(var_scope_name))
        init_fn(K.get_session())
    return logits, endpoints
Example #15
Source File: summary_utils.py    From sact with Apache License 2.0
def export_to_h5(checkpoint_dir, export_path, images, end_points, num_samples,
                 batch_size, sact):
  """Exports ponder cost maps and other useful info to an HDF5 file."""
  output_file = h5py.File(export_path, 'w')

  output_file.attrs['block_scopes'] = end_points['block_scopes']
  keys_to_tensors = {}
  for block_scope in end_points['block_scopes']:
    for k in ('{}/ponder_cost'.format(block_scope),
              '{}/num_units'.format(block_scope),
              '{}/halting_distribution'.format(block_scope),
              '{}/flops'.format(block_scope)):
      keys_to_tensors[k] = end_points[k]
  keys_to_tensors['images'] = images
  keys_to_tensors['flops'] = end_points['flops']

  if sact:
    keys_to_tensors['ponder_cost_map'] = sact_map(end_points, 'ponder_cost')
    keys_to_tensors['num_units_map'] = sact_map(end_points, 'num_units')

  keys_to_datasets = {}
  for key, tensor in keys_to_tensors.items():  # was .iteritems() (Python 2 only)
    sh = tensor.get_shape().as_list()
    sh[0] = num_samples
    print(key, sh)
    keys_to_datasets[key] = output_file.create_dataset(
        key, sh, compression='lzf')

  variables_to_restore = slim.get_model_variables()
  checkpoint_path = tf.train.latest_checkpoint(checkpoint_dir)
  assert checkpoint_path is not None
  init_fn = slim.assign_from_checkpoint_fn(checkpoint_path,
                                           variables_to_restore)

  sv = tf.train.Supervisor(
      graph=tf.get_default_graph(),
      logdir=None,
      summary_op=None,
      summary_writer=None,
      global_step=None,
      saver=None)

  assert num_samples % batch_size == 0
  num_batches = num_samples // batch_size

  with sv.managed_session('', start_standard_services=False) as sess:
    init_fn(sess)
    sv.start_queue_runners(sess)

    for i in range(num_batches):
      tf.logging.info('Evaluating batch %d/%d', i + 1, num_batches)
      end_points_out = sess.run(keys_to_tensors)
      for key, dataset in keys_to_datasets.items():  # was .iteritems() (Python 2 only)
        dataset[i * batch_size:(i + 1) * batch_size, ...] = end_points_out[key] 
Example #16
Source File: train_object_detector.py    From MobileNet with Apache License 2.0
def _get_init_fn():
  """Returns a function run by the chief worker to warm-start the training.

  Note that the init_fn is only run when initializing the model during the very
  first global step.

  Returns:
    An init function run by the supervisor.
  """
  if FLAGS.checkpoint_path is None:
    return None

  # Warn the user if a checkpoint exists in the train_dir. Then we'll be
  # ignoring the checkpoint anyway.
  if tf.train.latest_checkpoint(FLAGS.train_dir):
    tf.logging.info(
      'Ignoring --checkpoint_path because a checkpoint already exists in %s'
      % FLAGS.train_dir)
    return None

  exclusions = []
  if FLAGS.checkpoint_exclude_scopes:
    exclusions = [scope.strip()
                  for scope in FLAGS.checkpoint_exclude_scopes.split(',')]

  # TODO(sguada) variables.filter_variables()
  variables_to_restore = []
  for var in slim.get_model_variables():
    excluded = False
    for exclusion in exclusions:
      if var.op.name.startswith(exclusion):
        excluded = True
        break
    if not excluded:
      variables_to_restore.append(var)

  if tf.gfile.IsDirectory(FLAGS.checkpoint_path):
    checkpoint_path = tf.train.latest_checkpoint(FLAGS.checkpoint_path)
  else:
    checkpoint_path = FLAGS.checkpoint_path

  tf.logging.info('Fine-tuning from %s' % checkpoint_path)

  return slim.assign_from_checkpoint_fn(
    checkpoint_path,
    variables_to_restore,
    ignore_missing_vars=FLAGS.ignore_missing_vars) 
Example #17
Source File: demo_detect.py    From yolo-tf with GNU Lesser General Public License v3.0
def main():
    model = config.get('config', 'model')
    cachedir = utils.get_cachedir(config)
    with open(os.path.join(cachedir, 'names'), 'r') as f:
        names = [line.strip() for line in f]
    width = config.getint(model, 'width')
    height = config.getint(model, 'height')
    yolo = importlib.import_module('model.' + model)
    cell_width, cell_height = utils.calc_cell_width_height(config, width, height)
    tf.logging.info('(width, height)=(%d, %d), (cell_width, cell_height)=(%d, %d)' % (width, height, cell_width, cell_height))
    with tf.Session() as sess:
        paths = [os.path.join(cachedir, profile + '.tfrecord') for profile in args.profile]
        num_examples = sum(sum(1 for _ in tf.python_io.tf_record_iterator(path)) for path in paths)
        tf.logging.warn('num_examples=%d' % num_examples)
        image_rgb, labels = utils.data.load_image_labels(paths, len(names), width, height, cell_width, cell_height, config)
        image_std = tf.image.per_image_standardization(image_rgb)
        image_rgb = tf.cast(image_rgb, tf.uint8)
        ph_image = tf.placeholder(image_std.dtype, [1] + image_std.get_shape().as_list(), name='ph_image')
        global_step = tf.contrib.framework.get_or_create_global_step()
        builder = yolo.Builder(args, config)
        builder(ph_image)
        variables_to_restore = slim.get_variables_to_restore()
        ph_labels = [tf.placeholder(l.dtype, [1] + l.get_shape().as_list(), name='ph_' + l.op.name) for l in labels]
        with tf.name_scope('total_loss') as name:
            builder.create_objectives(ph_labels)
            total_loss = tf.losses.get_total_loss(name=name)
        tf.global_variables_initializer().run()
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess, coord)
        _image_rgb, _image_std, _labels = sess.run([image_rgb, image_std, labels])
        coord.request_stop()
        coord.join(threads)
        feed_dict = dict([(ph, np.expand_dims(d, 0)) for ph, d in zip(ph_labels, _labels)])
        feed_dict[ph_image] = np.expand_dims(_image_std, 0)
        logdir = utils.get_logdir(config)
        assert os.path.exists(logdir)
        model_path = tf.train.latest_checkpoint(logdir)
        tf.logging.info('load ' + model_path)
        slim.assign_from_checkpoint_fn(model_path, variables_to_restore)(sess)
        tf.logging.info('global_step=%d' % sess.run(global_step))
        tf.logging.info('total_loss=%f' % sess.run(total_loss, feed_dict))
        _ = Drawer(sess, names, builder.model.cell_width, builder.model.cell_height, _image_rgb, _labels, builder.model, feed_dict)
        plt.show() 
Example #18
Source File: train_classifier_mgr.py    From SSD_tensorflow_VOC with Apache License 2.0
def _get_init_fn():
    """Returns a function run by the chief worker to warm-start the training.

    Note that the init_fn is only run when initializing the model during the very
    first global step.

    Returns:
        An init function run by the supervisor.
    """
    if FLAGS.checkpoint_path is None:
        return None

    # Warn the user if a checkpoint exists in the train_dir. Then we'll be
    # ignoring the checkpoint anyway.
    if tf.train.latest_checkpoint(FLAGS.train_dir):
        tf.logging.info(
                'Ignoring --checkpoint_path because a checkpoint already exists in %s'
                % FLAGS.train_dir)
        return None

    exclusions = []
    if FLAGS.checkpoint_exclude_scopes:
        exclusions = [scope.strip()
                      for scope in FLAGS.checkpoint_exclude_scopes.split(',')]

    # TODO(sguada) variables.filter_variables()
    variables_to_restore = []
    for var in slim.get_model_variables():
        excluded = False
        for exclusion in exclusions:
            if var.op.name.startswith(exclusion):
                excluded = True
                break
        if not excluded:
            variables_to_restore.append(var)

    if tf.gfile.IsDirectory(FLAGS.checkpoint_path):
        checkpoint_path = tf.train.latest_checkpoint(FLAGS.checkpoint_path)
    else:
        checkpoint_path = FLAGS.checkpoint_path

    tf.logging.info('Fine-tuning from %s' % checkpoint_path)

    return slim.assign_from_checkpoint_fn(
            checkpoint_path,
            variables_to_restore,
            ignore_missing_vars=FLAGS.ignore_missing_vars) 
Example #19
Source File: slim_train_test.py    From SSD_tensorflow_VOC with Apache License 2.0
def __get_init_fn(self):
        """Returns a function run by the chief worker to warm-start the training.
    
        Note that the init_fn is only run when initializing the model during the very
        first global step.
    
        Returns:
            An init function run by the supervisor.
        """
        
        if self.checkpoint_path is None:
            return None
    
        # Warn the user if a checkpoint exists in the train_dir. Then we'll be
        # ignoring the checkpoint anyway.
        if tf.train.latest_checkpoint(self.train_dir):
            tf.logging.info(
                    'Ignoring --checkpoint_path because a checkpoint already exists in %s'
                    % self.train_dir)
            return None
    
        exclusions = []
        if self.checkpoint_exclude_scopes:
            exclusions = [scope.strip()
                          for scope in self.checkpoint_exclude_scopes.split(',')]
    
        # TODO(sguada) variables.filter_variables()
        variables_to_restore = []
        for var in slim.get_model_variables():
            excluded = False
            for exclusion in exclusions:
                if var.op.name.startswith(exclusion):
                    excluded = True
                    break
            if not excluded:
                variables_to_restore.append(var)
    
        if tf.gfile.IsDirectory(self.checkpoint_path):
            checkpoint_path = tf.train.latest_checkpoint(self.checkpoint_path)
        else:
            checkpoint_path = self.checkpoint_path
    
        tf.logging.info('Fine-tuning from %s' % checkpoint_path)
    
        return slim.assign_from_checkpoint_fn(
                checkpoint_path,
                variables_to_restore,
                ignore_missing_vars=self.ignore_missing_vars) 
Example #20
Source File: bgsCNN_v5.py    From bgsCNN with GNU General Public License v3.0
def train(self):
        img_size = [self.image_height, self.image_width, self.image_depth]
        train_batch = tf.train.shuffle_batch([read_tfrecord(self.train_file, img_size)],
                    batch_size = self.train_batch_size,
                    capacity = 2000,
                    num_threads = 2,
                    min_after_dequeue = 1000)
        test_batch = tf.train.shuffle_batch([read_tfrecord(self.test_file, img_size)],
                    batch_size = self.test_batch_size,
                    capacity = 500,
                    num_threads = 2,
                    min_after_dequeue = 300)
        init = tf.global_variables_initializer()
        init_fn = slim.assign_from_checkpoint_fn("vgg_16.ckpt", slim.get_model_variables('vgg_16'))
        saver = tf.train.Saver()
        with tf.Session() as sess:
            sess.run(init)
            init_fn(sess)
            train_writer = tf.summary.FileWriter(self.log_dir + "/train", sess.graph)
            test_writer  = tf.summary.FileWriter(self.log_dir + "/test", sess.graph)
            coord = tf.train.Coordinator()
            threads = tf.train.start_queue_runners(sess=sess, coord=coord)
            inputs_test, outputs_gt_test = build_img_pair(sess.run(test_batch))
            for iter in range(self.max_iteration):
                inputs_train, outputs_gt_train = build_img_pair(sess.run(train_batch))
                # train with dynamic learning rate
                if iter <= 500:
                    self.train_step.run({self.input_data:inputs_train, self.gt:outputs_gt_train,
                                    self.learning_rate:1e-4, self.is_training:True})
                elif iter <= self.max_iteration - 1000:
                    self.train_step.run({self.input_data:inputs_train, self.gt:outputs_gt_train,
                                    self.learning_rate:0.5e-4, self.is_training:True})
                else:
                    self.train_step.run({self.input_data:inputs_train, self.gt:outputs_gt_train,
                                    self.learning_rate:1e-5, self.is_training:True})
                # print training loss and test loss
                if iter%10 == 0:
                    summary_train = sess.run(self.summary, {self.input_data:inputs_train, self.gt:outputs_gt_train, self.is_training:False})
                    train_writer.add_summary(summary_train, iter)
                    train_writer.flush()
                    summary_test = sess.run(self.summary, {self.input_data:inputs_test, self.gt:outputs_gt_test, self.is_training:False})
                    test_writer.add_summary(summary_test, iter)
                    test_writer.flush()
                # record training loss and test loss
                if iter%10 == 0:
                    train_loss  = self.cross_entropy.eval({self.input_data:inputs_train, self.gt:outputs_gt_train, self.is_training:False})
                    test_loss   = self.cross_entropy.eval({self.input_data:inputs_test, self.gt:outputs_gt_test, self.is_training:False})
                    print("iter step %d trainning batch loss %f"%(iter, train_loss))
                    print("iter step %d test loss %f\n"%(iter, test_loss))
                # record model
                if iter%100 == 0:
                    saver.save(sess, self.log_dir + "/model.ckpt", global_step=iter)
            coord.request_stop()
            coord.join(threads) 
Example #21
Source File: bgsCNN_v3.py    From bgsCNN with GNU General Public License v3.0
def train(self):
        img_size = [self.image_height, self.image_width, self.image_depth]
        train_batch = tf.train.shuffle_batch([read_tfrecord(self.train_file, img_size)],
                    batch_size = self.train_batch_size,
                    capacity = 3000,
                    num_threads = 2,
                    min_after_dequeue = 1000)
        test_batch = tf.train.shuffle_batch([read_tfrecord(self.test_file, img_size)],
                    batch_size = self.test_batch_size,
                    capacity = 500,
                    num_threads = 2,
                    min_after_dequeue = 300)
        init = tf.global_variables_initializer()
        init_fn = slim.assign_from_checkpoint_fn("resnet_v2_50.ckpt", slim.get_model_variables('resnet_v2'))
        saver = tf.train.Saver()
        with tf.Session() as sess:
            sess.run(init)
            init_fn(sess)
            train_writer = tf.summary.FileWriter(self.log_dir + "/train", sess.graph)
            test_writer  = tf.summary.FileWriter(self.log_dir + "/test", sess.graph)
            coord = tf.train.Coordinator()
            threads = tf.train.start_queue_runners(sess=sess, coord=coord)
            inputs_test, outputs_gt_test = build_img_pair(sess.run(test_batch))
            for iter in range(self.max_iteration):
                inputs_train, outputs_gt_train = build_img_pair(sess.run(train_batch))
                # train with dynamic learning rate
                if iter <= 500:
                    self.train_step.run({self.input_data:inputs_train, self.gt:outputs_gt_train,
                                    self.learning_rate:1e-3, self.batch_size:self.train_batch_size})
                elif iter <= self.max_iteration - 1000:
                    self.train_step.run({self.input_data:inputs_train, self.gt:outputs_gt_train,
                                    self.learning_rate:0.5e-3, self.batch_size:self.train_batch_size})
                else:
                    self.train_step.run({self.input_data:inputs_train, self.gt:outputs_gt_train,
                                    self.learning_rate:1e-4, self.batch_size:self.train_batch_size})
                # print training loss and test loss
                if iter%10 == 0:
                    summary_train = sess.run(self.summary, {self.input_data:inputs_train, self.gt:outputs_gt_train,
                                             self.batch_size:self.train_batch_size})
                    train_writer.add_summary(summary_train, iter)
                    train_writer.flush()
                    summary_test = sess.run(self.summary, {self.input_data:inputs_test, self.gt:outputs_gt_test,
                                             self.batch_size:self.test_batch_size})
                    test_writer.add_summary(summary_test, iter)
                    test_writer.flush()
                # record training loss and test loss
                if iter%10 == 0:
                    train_loss  = self.cross_entropy.eval({self.input_data:inputs_train, self.gt:outputs_gt_train,
                                                    self.batch_size:self.train_batch_size})
                    test_loss   = self.cross_entropy.eval({self.input_data:inputs_test, self.gt:outputs_gt_test,
                                                    self.batch_size:self.test_batch_size})
                    print("iter step %d trainning batch loss %f"%(iter, train_loss))
                    print("iter step %d test loss %f\n"%(iter, test_loss))
                # record model
                if iter%100 == 0:
                    saver.save(sess, self.log_dir + "/model.ckpt", global_step=iter)
            coord.request_stop()
            coord.join(threads) 
Example #22
Source File: bgsCNN_v4.py    From bgsCNN with GNU General Public License v3.0
def train(self):
        img_size = [self.image_height, self.image_width, self.image_depth]
        train_batch = tf.train.shuffle_batch([read_tfrecord(self.train_file, img_size)],
                    batch_size = self.train_batch_size,
                    capacity = 2000,
                    num_threads = 2,
                    min_after_dequeue = 1000)
        test_batch = tf.train.shuffle_batch([read_tfrecord(self.test_file, img_size)],
                    batch_size = self.test_batch_size,
                    capacity = 500,
                    num_threads = 2,
                    min_after_dequeue = 300)
        init = tf.global_variables_initializer()
        init_fn = slim.assign_from_checkpoint_fn("vgg_16.ckpt", slim.get_model_variables('vgg_16'))
        saver = tf.train.Saver()
        with tf.Session() as sess:
            sess.run(init)
            init_fn(sess)
            train_writer = tf.summary.FileWriter(self.log_dir + "/train", sess.graph)
            test_writer  = tf.summary.FileWriter(self.log_dir + "/test", sess.graph)
            coord = tf.train.Coordinator()
            threads = tf.train.start_queue_runners(sess=sess, coord=coord)
            inputs_test, outputs_gt_test = build_img_pair(sess.run(test_batch))
            for iter in range(self.max_iteration):
                inputs_train, outputs_gt_train = build_img_pair(sess.run(train_batch))
                # train with dynamic learning rate
                if iter <= 500:
                    self.train_step.run({self.input_data:inputs_train, self.gt:outputs_gt_train, self.is_training:True,
                                    self.learning_rate:1e-4})
                elif iter <= self.max_iteration - 1000:
                    self.train_step.run({self.input_data:inputs_train, self.gt:outputs_gt_train, self.is_training:True,
                                    self.learning_rate:0.5e-4})
                else:
                    self.train_step.run({self.input_data:inputs_train, self.gt:outputs_gt_train, self.is_training:True,
                                    self.learning_rate:1e-5})
                # print training loss and test loss
                if iter%10 == 0:
                    summary_train = sess.run(self.summary, {self.input_data:inputs_train, self.gt:outputs_gt_train,
                                             self.is_training:False})
                    train_writer.add_summary(summary_train, iter)
                    train_writer.flush()
                    summary_test = sess.run(self.summary, {self.input_data:inputs_test, self.gt:outputs_gt_test,
                                            self.is_training:False})
                    test_writer.add_summary(summary_test, iter)
                    test_writer.flush()
                # record training loss and test loss
                if iter%10 == 0:
                    train_loss  = self.cross_entropy.eval({self.input_data:inputs_train, self.gt:outputs_gt_train,
                                                    self.is_training:False})
                    test_loss   = self.cross_entropy.eval({self.input_data:inputs_test, self.gt:outputs_gt_test,
                                                    self.is_training:False})
                    print("iter step %d trainning batch loss %f"%(iter, train_loss))
                    print("iter step %d test loss %f\n"%(iter, test_loss))
                # record model
                if iter%100 == 0:
                    saver.save(sess, self.log_dir + "/model.ckpt", global_step=iter)
            coord.request_stop()
            coord.join(threads) 
Example #23
Source File: train_model.py    From SSD_tensorflow_VOC with Apache License 2.0
def __get_init_fn(self):
        """Returns a function run by the chief worker to warm-start the training.
    
        Note that the init_fn is only run when initializing the model during the very
        first global step.
    
        Returns:
            An init function run by the supervisor.
        """  
        
        if self.checkpoint_path is None:
            return None
    
        # Warn the user if a checkpoint exists in the train_dir. Then we'll be
        # ignoring the checkpoint anyway.
        if tf.train.latest_checkpoint(self.train_dir):
            tf.logging.info(
                    'Ignoring --checkpoint_path because a checkpoint already exists in %s'
                    % self.train_dir)
            return None
    
        exclusions = []
        if self.checkpoint_exclude_scopes:
            exclusions = [scope.strip()
                          for scope in self.checkpoint_exclude_scopes.split(',')]
    
        # TODO(sguada) variables.filter_variables()
        variables_to_restore = []
        all_variables = slim.get_model_variables()
        if self.fine_tune_vgg16:
            global_step = slim.get_or_create_global_step()
            all_variables.append(global_step)
        for var in all_variables:
            excluded = False
            
            for exclusion in exclusions:
                if var.op.name.startswith(exclusion):
                    excluded = True
                    break
            if not excluded:
                variables_to_restore.append(var)
    
        if tf.gfile.IsDirectory(self.checkpoint_path):
            checkpoint_path = tf.train.latest_checkpoint(self.checkpoint_path)
        else:
            checkpoint_path = self.checkpoint_path
    
        tf.logging.info('Fine-tuning from %s' % checkpoint_path)
    
        return slim.assign_from_checkpoint_fn(
                checkpoint_path,
                variables_to_restore,
                ignore_missing_vars=self.ignore_missing_vars) 
Example #24
Source File: generation_builder.py    From DAGAN with MIT License
def run_experiment(self):
        with tf.Session(config=tf.ConfigProto(
                allow_soft_placement=True)) as sess:
            sess.run(self.init)
            self.writer = tf.summary.FileWriter(self.log_path, graph=tf.get_default_graph())
            self.saver = tf.train.Saver()
            if self.continue_from_epoch != -1:
                checkpoint = "{}/{}_{}.ckpt".format(self.saved_models_filepath, self.experiment_name,
                                                    self.continue_from_epoch)
                variables_to_restore = []
                for var in tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES):
                    print(var)
                    variables_to_restore.append(var)

                tf.logging.info('Fine-tuning from %s' % checkpoint)

                fine_tune = slim.assign_from_checkpoint_fn(
                    checkpoint,
                    variables_to_restore,
                    ignore_missing_vars=True)
                fine_tune(sess)
            if self.spherical_interpolation:
                z_vectors = interpolations.create_mine_grid(rows=self.num_generations, cols=self.num_generations,
                                                            dim=100, space=3, anchors=None,
                                                            spherical=True, gaussian=True)
            else:
                z_vectors = np.random.normal(size=(self.num_generations * self.num_generations, self.z_dim))

            with tqdm.tqdm(total=self.total_gen_batches) as pbar_samp:
                for i in range(self.total_gen_batches):
                    x_gen_a = self.data.get_gen_batch()
                    sample_two_dimensions_generator(sess=sess,
                                                    same_images=self.same_images,
                                                    inputs=x_gen_a,
                                                    data=self.data, batch_size=self.batch_size, z_input=self.z_input,
                                                    file_name="{}/generation_z_spherical_{}".format(self.save_image_path,
                                                                                                  self.experiment_name),
                                                    input_a=self.input_x_i, training_phase=self.training_phase,
                                                    dropout_rate=self.dropout_rate,
                                                    dropout_rate_value=self.dropout_rate_value,
                                                    z_vectors=z_vectors)
                    pbar_samp.update(1) 
Example #25
Source File: bgsCNN_v2.py    From bgsCNN with GNU General Public License v3.0
def train(self):
        img_size = [self.image_height, self.image_width, self.image_depth]
        train_batch = tf.train.shuffle_batch([read_tfrecord(self.train_file, img_size)],
                    batch_size = self.train_batch_size,
                    capacity = 3000,
                    num_threads = 2,
                    min_after_dequeue = 1000)
        test_batch = tf.train.shuffle_batch([read_tfrecord(self.test_file, img_size)],
                    batch_size = self.test_batch_size,
                    capacity = 500,
                    num_threads = 2,
                    min_after_dequeue = 300)
        init = tf.global_variables_initializer()
        init_fn = slim.assign_from_checkpoint_fn("resnet_v2_50.ckpt", slim.get_model_variables('resnet_v2'))
        saver = tf.train.Saver()
        with tf.Session() as sess:
            sess.run(init)
            init_fn(sess)
            train_writer = tf.summary.FileWriter(self.log_dir + "/train", sess.graph)
            test_writer  = tf.summary.FileWriter(self.log_dir + "/test", sess.graph)
            coord = tf.train.Coordinator()
            threads = tf.train.start_queue_runners(sess=sess, coord=coord)
            inputs_test, outputs_gt_test = build_img_pair(sess.run(test_batch))
            for iter in range(self.max_iteration):
                inputs_train, outputs_gt_train = build_img_pair(sess.run(train_batch))
                # train with dynamic learning rate
                if iter <= 500:
                    self.train_step.run({self.input_data:inputs_train, self.gt:outputs_gt_train,
                                    self.learning_rate:1e-3, self.batch_size:self.train_batch_size})
                elif iter <= self.max_iteration - 1000:
                    self.train_step.run({self.input_data:inputs_train, self.gt:outputs_gt_train,
                                    self.learning_rate:0.5e-3, self.batch_size:self.train_batch_size})
                else:
                    self.train_step.run({self.input_data:inputs_train, self.gt:outputs_gt_train,
                                    self.learning_rate:1e-4, self.batch_size:self.train_batch_size})
                # print training loss and test loss
                if iter%10 == 0:
                    summary_train = sess.run(self.summary, {self.input_data:inputs_train, self.gt:outputs_gt_train,
                                             self.batch_size:self.train_batch_size})
                    train_writer.add_summary(summary_train, iter)
                    train_writer.flush()
                    summary_test = sess.run(self.summary, {self.input_data:inputs_test, self.gt:outputs_gt_test,
                                             self.batch_size:self.test_batch_size})
                    test_writer.add_summary(summary_test, iter)
                    test_writer.flush()
                # record training loss and test loss
                if iter%10 == 0:
                    train_loss  = self.cross_entropy.eval({self.input_data:inputs_train, self.gt:outputs_gt_train,
                                                    self.batch_size:self.train_batch_size})
                    test_loss   = self.cross_entropy.eval({self.input_data:inputs_test, self.gt:outputs_gt_test,
                                                    self.batch_size:self.test_batch_size})
                    print("iter step %d trainning batch loss %f"%(iter, train_loss))
                    print("iter step %d test loss %f\n"%(iter, test_loss))
                # record model
                if iter%100 == 0:
                    saver.save(sess, self.log_dir + "/model.ckpt", global_step=iter)
            coord.request_stop()
            coord.join(threads) 
Example #26
Source File: bgsCNN_v1.py    From bgsCNN with GNU General Public License v3.0
def train(self):
        img_size = [self.image_height, self.image_width, self.image_depth]
        train_batch = tf.train.shuffle_batch([read_tfrecord(self.train_file, img_size)],
                    batch_size = self.train_batch_size,
                    capacity = 3000,
                    num_threads = 2,
                    min_after_dequeue = 1000)
        test_batch = tf.train.shuffle_batch([read_tfrecord(self.test_file, img_size)],
                    batch_size = self.test_batch_size,
                    capacity = 500,
                    num_threads = 2,
                    min_after_dequeue = 300)
        init = tf.global_variables_initializer()
        init_fn = slim.assign_from_checkpoint_fn("resnet_v2_50.ckpt", slim.get_model_variables('resnet_v2'))
        saver = tf.train.Saver()
        with tf.Session() as sess:
            sess.run(init)
            init_fn(sess)
            train_writer = tf.summary.FileWriter(self.log_dir + "/train", sess.graph)
            test_writer  = tf.summary.FileWriter(self.log_dir + "/test", sess.graph)
            coord = tf.train.Coordinator()
            threads = tf.train.start_queue_runners(sess=sess, coord=coord)
            inputs_test, outputs_gt_test = build_img_pair(sess.run(test_batch))
            for iter in range(self.max_iteration):
                inputs_train, outputs_gt_train = build_img_pair(sess.run(train_batch))
                # train with dynamic learning rate
                if iter <= 500:
                    self.train_step.run({self.input_data:inputs_train, self.gt:outputs_gt_train,
                                    self.learning_rate:1e-3, self.batch_size:self.train_batch_size})
                elif iter <= self.max_iteration - 1000:
                    self.train_step.run({self.input_data:inputs_train, self.gt:outputs_gt_train,
                                    self.learning_rate:0.5e-3, self.batch_size:self.train_batch_size})
                else:
                    self.train_step.run({self.input_data:inputs_train, self.gt:outputs_gt_train,
                                    self.learning_rate:1e-4, self.batch_size:self.train_batch_size})
                # print training loss and test loss
                if iter%10 == 0:
                    summary_train = sess.run(self.summary, {self.input_data:inputs_train, self.gt:outputs_gt_train,
                                             self.batch_size:self.train_batch_size})
                    train_writer.add_summary(summary_train, iter)
                    train_writer.flush()
                    summary_test = sess.run(self.summary, {self.input_data:inputs_test, self.gt:outputs_gt_test,
                                             self.batch_size:self.test_batch_size})
                    test_writer.add_summary(summary_test, iter)
                    test_writer.flush()
                # record training loss and test loss
                if iter%10 == 0:
                    train_loss  = self.cross_entropy.eval({self.input_data:inputs_train, self.gt:outputs_gt_train,
                                                    self.batch_size:self.train_batch_size})
                    test_loss   = self.cross_entropy.eval({self.input_data:inputs_test, self.gt:outputs_gt_test,
                                                    self.batch_size:self.test_batch_size})
                    print("iter step %d trainning batch loss %f"%(iter, train_loss))
                    print("iter step %d test loss %f\n"%(iter, test_loss))
                # record model
                if iter%100 == 0:
                    saver.save(sess, self.log_dir + "/model.ckpt", global_step=iter)
            coord.request_stop()
            coord.join(threads)