Python dataset.Dataset() Examples

The following are 30 code examples of dataset.Dataset(), drawn from open-source projects. You can go to the original project or source file by following the links above each example. You may also want to check out all available functions/classes of the module dataset, or try the search function.
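Note that dataset.Dataset is not a single standard API: each project below defines its own Dataset class, so the constructor arguments differ from example to example. As a minimal, hypothetical sketch of the recurring pattern (construct a Dataset, then hand it to a training or evaluation routine), assuming an im2avatar-style constructor that takes base_path, cat_id and data_list_path:

import dataset

def build_test_dataset(base_dir, cat_id, data_list_path):
    # Hypothetical helper: the keyword arguments below follow the im2avatar
    # examples in this listing and will differ for other projects.
    ds = dataset.Dataset(base_path=base_dir,
                         cat_id=cat_id,
                         data_list_path=data_list_path)
    return ds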
Example #1
Source File: train.py    From GeneGAN with GNU General Public License v3.0
def main():
    parser = argparse.ArgumentParser(description='test', formatter_class=argparse.RawTextHelpFormatter)
    parser.add_argument(
        '-a', '--attribute', 
        default='Smiling',
        type=str,
        help='Specify attribute name for training. \ndefault: %(default)s. \nAll attributes can be found in list_attr_celeba.txt'
    )
    parser.add_argument(
        '-g', '--gpu', 
        default='0',
        type=str,
        help='Specify GPU id. \ndefault: %(default)s. \nUse comma to separate several ids, for example: 0,1'
    )
    args = parser.parse_args()

    celebA = Dataset(args.attribute)
    GeneGAN = Model(is_train=True)
    run(config, celebA, GeneGAN, gpu=args.gpu) 
Example #2
Source File: train.py    From DNA-GAN with MIT License
def main():
    parser = argparse.ArgumentParser(description='test', formatter_class=argparse.RawTextHelpFormatter)
    parser.add_argument(
        '-a', '--attributes',
        nargs='+',
        type=str,
        help='Specify attribute name for training. \nAll attributes can be found in list_attr_celeba.txt'
    )
    parser.add_argument(
        '-g', '--gpu',
        default='0',
        type=str,
        help='Specify GPU id. \ndefault: %(default)s. \nUse comma to separate several ids, for example: 0,1'
    )
    args = parser.parse_args()

    celebA = Dataset(args.attributes)
    DNA_GAN = Model(args.attributes, is_train=True)
    run(config, celebA, DNA_GAN, gpu=args.gpu) 
Example #3
Source File: user.py    From OpenUBA with GNU General Public License v3.0
def extract_users(dataset_session: DatasetSession, log_metadata_obj: dict) -> List:
        ############## TESTS
        # get dataset
        log_file_dataset: Dataset = dataset_session.get_csv_dataset()
        # get core dataframe
        log_file_core_dataframe: CoreDataFrame = log_file_dataset.get_dataframe()
        # get data frame (.data)
        log_file_dataframe: pd.DataFrame = log_file_core_dataframe.data
        # test: get shape
        log_file_shape: Tuple = log_file_dataframe.shape
        logging.warning("execute(): dataframe shape: "+str(log_file_shape))
        ############

        logging.info("ExtractAllUsersCSV: extract_users log_file_data.columns: - "+str(log_file_dataframe.columns))
        logging.info("ExtractAllUsersCSV: extract_users log_metadata_obj: - "+str(log_metadata_obj))


        id_column: pd.Series = log_file_dataframe[ log_metadata_obj["id_feature"]]
        logging.info( "ExtractAllUsersCSV, extract_users, id_column, len of column: "+str(len(id_column)) )
        user_set: List = np.unique( log_file_dataframe[ log_metadata_obj["id_feature"] ].fillna("NA") )
        logging.info( "ExtractAllUsersCSV, extract_users, user_set len of column: "+str(len(user_set)) )
        logging.error(user_set)
        return user_set 
Example #4
Source File: holoclean.py    From HoloClean-Legacy-deprecated with Apache License 2.0
def __init__(self, holo_env, name="session"):
        """
        Constructor for Holoclean session
        :param holo_env: Holoclean object
        :param name: Name for the Holoclean session
        """
        logging.basicConfig()

        # Initialize members
        self.name = name
        self.holo_env = holo_env
        self.Denial_constraints = []  # Denial Constraint strings
        self.dc_objects = {}  # Denial Constraint Objects
        self.featurizers = []
        self.error_detectors = []
        self.cv = None
        self.pruning = None
        self.dataset = Dataset()
        self.parser = ParserInterface(self)
        self.inferred_values = None
        self.feature_count = 0 
Example #5
Source File: compute_one.py    From wasserstein-dist with Apache License 2.0
def main(unused_argv):
  # tf.logging.set_verbosity(tf.logging.INFO)

  # load two copies of the dataset
  print('Loading datasets...')
  subset1 = Dataset(bs=FLAGS.batch_size, filepattern=FLAGS.filepattern)
  subset2 = Dataset(bs=FLAGS.batch_size, filepattern=FLAGS.filepattern)

  print('Computing Wasserstein distance...')
  with tf.Graph().as_default():
    # compute Wasserstein distance between two sets of examples
    wasserstein = Wasserstein(subset1, subset2)
    loss = wasserstein.dist(C=.1, nsteps=FLAGS.loss_steps)
    with tf.Session() as sess:
      sess.run(tf.global_variables_initializer())
      res = sess.run(loss)
      print('result: %f\n' % res) 
Example #6
Source File: compute_all.py    From wasserstein-dist with Apache License 2.0
def main(unused_argv):
  # tf.logging.set_verbosity(tf.logging.INFO)

  # load two copies of the dataset
  print('Loading datasets...')
  dataset = [Dataset(bs=FLAGS.batch_size, filepattern=FLAGS.filepattern,
                     label=i) for i in range(10)]

  print('Computing Wasserstein distance(s)...')
  for i in range(10):
    for j in range(10):
      with tf.Graph().as_default():
        # compute Wasserstein distance between sets of labels i and j
        wasserstein = Wasserstein(dataset[i], dataset[j])
        loss = wasserstein.dist(C=.1, nsteps=FLAGS.loss_steps)
        with tf.Session() as sess:
          sess.run(tf.global_variables_initializer())
          res = sess.run(loss)
          print_flush('%f ' % res)
    print_flush('\n') 
Example #7
Source File: data.py    From adversarial-object-removal with MIT License
def split_train_val_test(data_dir, img_size=256):
    df = pd.read_csv(
        join(data_dir, 'list_eval_partition.txt'),
        delim_whitespace=True, header=None
    )
    filenames, labels = df.values[:, 0], df.values[:, 1]

    train_filenames = filenames[labels == 0]
    valid_filenames = filenames[labels == 1]
    test_filenames  = filenames[labels == 2]

    train_set = Dataset(
        data_dir, train_filenames, input_transform_augment(178, img_size),
        target_transform(), target_transform_binary()
    )
    valid_set = Dataset(
        data_dir, valid_filenames, input_transform(178, img_size),
        target_transform(), target_transform_binary()
    )
    test_set = Dataset(
        data_dir, test_filenames, input_transform(178, img_size),
        target_transform(), target_transform_binary()
    )

    return train_set, valid_set, test_set 
Example #8
Source File: inference_color.py    From im2avatar with MIT License
def main():
  test_dataset = dataset.Dataset(base_path=FLAGS.base_dir, 
                                  cat_id=FLAGS.cat_id, 
                                  data_list_path=FLAGS.data_list_path)
  inference(test_dataset) 
Example #9
Source File: train_color.py    From im2avatar with MIT License
def main():
  train_dataset = dataset.Dataset(base_path=FLAGS.base_dir, 
                                  cat_id=FLAGS.cat_id, 
                                  data_list_path=FLAGS.data_list_path)

  train(train_dataset) 
Example #10
Source File: tracker_init.py    From rpg_feature_tracking_analysis with MIT License
def init_on_track(self, root, config, dataset_yaml):
        print("[1/3] Loading tracks in %s." % os.path.basename(config["tracks_csv"]))
        tracks = np.genfromtxt(self.config["tracks_csv"])

        # check that all features start at the same timestamp, if not, discard features that occur later
        first_len_tracks = len(tracks)
        valid_ids, tracks = filter_first_tracks(tracks, filter_too_short=True)

        if len(tracks) < first_len_tracks:
            print("WARNING: This package only supports evaluation of tracks which have been initialized at the same"
                  " time. All tracks except the first have been discarded.")

        tracks_dict = {i: tracks[tracks[:, 0] == i, 1:] for i in valid_ids}

        print("[2/3] Loading frame dataset to find positions of initial tracks.")
        frame_dataset = Dataset(root, dataset_yaml, dataset_type="frames")

        # find dataset indices for each start
        tracks_init = {}
        print("[3/3] Initializing tracks")
        for track_id, track in tracks_dict.items():
            frame_dataset.set_to_first_after(track[0,0])
            t_dataset, _ = frame_dataset.current()

            x_interp = np.interp(t_dataset, track[:, 0], track[:, 1])
            y_interp = np.interp(t_dataset, track[:, 0], track[:, 2])

            tracks_init[track_id] = np.array([[t_dataset, x_interp, y_interp]])

        tracks_obj = Tracks(tracks_init)

        return tracks_obj, {"frame_dataset": frame_dataset, "reference_track": tracks} 
Example #11
Source File: inference_shape.py    From im2avatar with MIT License
def main():
  test_dataset = dataset.Dataset(base_path=FLAGS.base_dir, 
                                  cat_id=FLAGS.cat_id, 
                                  data_list_path=FLAGS.data_list_path)
  inference(test_dataset) 
Example #12
Source File: test.py    From PoseFix_RELEASE with MIT License
def test(test_model):
    
    # annotation load
    d = Dataset()
    annot = d.load_annot(cfg.testset)
    
    # input pose load
    input_pose = d.input_pose_load(annot, cfg.testset)

    # job assign (multi-gpu)
    from tfflat.mp_utils import MultiProc
    img_start = 0
    ranges = [0]
    img_num = len(np.unique([i['image_id'] for i in input_pose]))
    images_per_gpu = int(img_num / len(args.gpu_ids.split(','))) + 1
    for run_img in range(img_num):
        img_end = img_start + 1
        while img_end < len(input_pose) and input_pose[img_end]['image_id'] == input_pose[img_start]['image_id']:
            img_end += 1
        if (run_img + 1) % images_per_gpu == 0 or (run_img + 1) == img_num:
            ranges.append(img_end)
        img_start = img_end

    def func(gpu_id):
        cfg.set_args(args.gpu_ids.split(',')[gpu_id])
        tester = Tester(Model(), cfg)
        tester.load_weights(test_model)
        range = [ranges[gpu_id], ranges[gpu_id + 1]]
        return test_net(tester, input_pose, range, gpu_id)

    MultiGPUFunc = MultiProc(len(args.gpu_ids.split(',')), func)
    result = MultiGPUFunc.work()

    # evaluation
    d.evaluation(result, annot, cfg.result_dir, cfg.testset) 
Example #13
Source File: train.py    From Keras-progressive_growing_of_gans with MIT License
def load_dataset(dataset_spec=None, verbose=True, **spec_overrides):
    if verbose: print('Loading dataset...')
    if dataset_spec is None: dataset_spec = config.dataset
    dataset_spec = dict(dataset_spec) # take a copy of the dict before modifying it
    dataset_spec.update(spec_overrides)
    dataset_spec['h5_path'] = os.path.join(config.data_dir, dataset_spec['h5_path'])
    if 'label_path' in dataset_spec: dataset_spec['label_path'] = os.path.join(config.data_dir, dataset_spec['label_path'])
    training_set = dataset.Dataset(**dataset_spec)
    if verbose: print('Dataset shape =', np.int32(training_set.shape).tolist())
    drange_orig = training_set.get_dynamic_range()
    if verbose: print('Dynamic range =', drange_orig)
    return training_set, drange_orig 
Example #14
Source File: train_shape.py    From im2avatar with MIT License
def main():
  train_dataset = dataset.Dataset(base_path=FLAGS.base_dir, 
                                  cat_id=FLAGS.cat_id, 
                                  data_list_path=FLAGS.data_list_path)

  train(train_dataset) 
Example #15
Source File: analysis.py    From PCWG with MIT License
def define_columns(self):

        self.nameColumn = "Dataset Name"
        self.windSpeedBin = "Wind Speed Bin"
        self.dataCount = "Data Count"
        self.powerStandDev = "Power Standard Deviation"
        self.powerCoeff = "Power Coefficient"
        self.measuredTurbulencePower = 'Measured TI Corrected Power'
        self.measuredTurbPowerCurveInterp = 'Measured TI Corrected Power Curve Interp'
        self.measuredPowerCurveInterp = 'All Measured Power Curve Interp'
        self.baseline_wind_speed = "Baseline Wind Speed" 
Example #16
Source File: analysis.py    From PCWG with MIT License
def load_dataset(self, dataset_config):
        return dataset.Dataset(dataset_config) 
Example #17
Source File: test_final_stage.py    From chainRec with Apache License 2.0
def test_dataset(DATA_NAME, n_stage, method, embed_size, lbda):
    myData = Dataset(DATA_NAME, n_stage)
    myData.split_train_test(seed=1234, max_validation_test_samples=100000)

    validation_samples = myData.sampling_validation()
    
    if method == "chainRec_uniform":
        training_samples = myData.sampling_training(method="edgeOpt_uniform")
        
        myModel = chainRec(myData.n_user, myData.n_item, myData.n_stage, myData.DATA_NAME)
        myModel.config_global(MODEL_CLASS="chainRec_uniform", HIDDEN_DIM=embed_size, 
                              LAMBDA=lbda, LEARNING_RATE=0.001, BATCH_SIZE=256,
                              target_stage_id=(n_stage-1))
        myModel.load_samples(training_samples, validation_samples)
        myModel.train_edgeOpt()
        
        myModel.evaluation(myData.data_test, myData.user_item_map, topK=10)
        
    elif method == "chainRec_stage":
        training_samples = myData.sampling_training(method="edgeOpt_stage")
        
        myModel = chainRec(myData.n_user, myData.n_item, myData.n_stage, myData.DATA_NAME)
        myModel.config_global(MODEL_CLASS="edgeOpt_stage", HIDDEN_DIM=embed_size, 
                              LAMBDA=lbda, LEARNING_RATE=0.001, BATCH_SIZE=256,
                              target_stage_id=(n_stage-1))
        myModel.load_samples(training_samples, validation_samples)
        myModel.train_edgeOpt()
        
        myModel.evaluation(myData.data_test, myData.user_item_map, topK=10)       
        
    elif method == "bprMF":
        training_samples = myData.sampling_training(method="sliceOpt")
        
        myModel = bprMF(myData.n_user, myData.n_item, myData.n_stage, myData.DATA_NAME)
        myModel.config_global(MODEL_CLASS="bprMF", HIDDEN_DIM=embed_size, 
                              LAMBDA=lbda, LEARNING_RATE=0.001, BATCH_SIZE=256,
                              target_stage_id=(n_stage-1))
        myModel.load_samples(training_samples, validation_samples)
        myModel.train_sliceOpt()
        
        myModel.evaluation(myData.data_test, myData.user_item_map, topK=10) 
Example #18
Source File: main.py    From -Learn-Artificial-Intelligence-with-TensorFlow with MIT License
def input_fn(hparams, mode):
    return dataset.Dataset(hparams.processed_data_dir, hparams).make_batch(mode) 
Example #19
Source File: main.py    From snip-public with MIT License
def main():
    args = parse_arguments()

    # Dataset
    dataset = Dataset(**vars(args))

    # Reset the default graph and set a graph-level seed
    tf.reset_default_graph()
    tf.set_random_seed(9)

    # Model
    model = Model(num_classes=dataset.num_classes, **vars(args))
    model.construct_model()

    # Session
    sess = tf.InteractiveSession()
    tf.global_variables_initializer().run()
    tf.local_variables_initializer().run()

    # Prune
    prune.prune(args, model, sess, dataset)

    # Train and test
    train.train(args, model, sess, dataset)
    test.test(args, model, sess, dataset)

    sess.close()
    sys.exit() 
Example #20
Source File: 3_5_classification_part_two.py    From -Learn-Artificial-Intelligence-with-TensorFlow with MIT License
def input_fn(hparams, mode):
    with tf.variable_scope('input_fn'):
        return dataset.Dataset(hparams.processed_data_dir, hparams).make_batch(mode) 
Example #21
Source File: 3_1_embeddings.py    From -Learn-Artificial-Intelligence-with-TensorFlow with MIT License
def input_fn(data_dir, params):
    text_batch, _ = dataset.Dataset(data_dir, params).make_batch(params.batch_size)
    return text_batch, text_batch 
Example #22
Source File: 3_4_classification_part_one.py    From -Learn-Artificial-Intelligence-with-TensorFlow with MIT License
def input_fn(hparams, mode):
    with tf.variable_scope('input_fn'):
        return dataset.Dataset(hparams.processed_data_dir, hparams).make_batch(mode) 
Example #23
Source File: process.py    From OpenUBA with GNU General Public License v3.0
def process_data(self, data_folder: str, log_data_obj: dict) -> DatasetSession:

        logging.warning("Processing Data for : "+str(data_folder))

        log_name = log_data_obj["log_name"]
        log_type = log_data_obj["type"]
        delimiter = log_data_obj["delimiter"]
        location_type = log_data_obj["location_type"]
        folder = log_data_obj["folder"]
        id_feature = log_data_obj["id_feature"]

        dataset_session: DatasetSession = DatasetSession(log_type)

        #read dataset, if any new
        if log_type == DataSourceFileType.CSV.value:

            # invoke datasetsession to read the csv
            dataset_session.read_csv(data_folder, folder, location_type, delimiter) # load
            print( "isinstance(dataset_session.dataset, Dataset): "+str(isinstance(dataset_session.csv_dataset, Dataset)) )
            dataset_size: Tuple = dataset_session.get_csv_size()
            logging.info( "Dataset Session size: "+str(dataset_size) )

        elif log_type == DataSourceFileType.FLAT.value:
            pass
        elif log_type == DataSourceFileType.PARQUET.value:
            pass
        elif log_type == DataSourceFileType.JSON.value:
            pass

        return dataset_session 
Example #24
Source File: process.py    From UBA with GNU General Public License v3.0
def process_data(self, data_folder: str, log_data_obj: dict):

        logging.warning("Processing Data for : "+str(data_folder))

        log_name = log_data_obj["log_name"]
        log_type = log_data_obj["type"]
        location_type = log_data_obj["location_type"]
        folder = log_data_obj["folder"]

        dataset_session = DatasetSession(log_type)

        '''
         STEP1: check for new datasets
         from folder directory
        '''

        #read dataset, if any new
        if log_type == "csv":
            dataset_session.read_csv(data_folder, folder, location_type) # load
            print( "isinstance(dataset_session.dataset, Dataset): "+str(isinstance(dataset_session.dataset, Dataset)) )
            dataset_size: Tuple = dataset_session.get_size()
            logging.warning( "Dataset Session size: "+str(dataset_size) )

        # after reading the data, perform entity analysis using the Entity class

        # adjust risk per entity 
Example #25
Source File: main.py    From SHN-based-2D-face-alignment with MIT License
def main(config):
    # For fast training.
    cudnn.benchmark = True

    # Create directories if not exist.
    if not os.path.exists(config.log_dir):
        os.makedirs(config.log_dir)
    if not os.path.exists(config.model_save_dir):
        os.makedirs(config.model_save_dir)

    imgdirs_train = ['data/afw/', 'data/helen/trainset/', 'data/lfpw/trainset/']
    imgdirs_test_commomset = ['data/helen/testset/', 'data/lfpw/testset/']

    # Dataset and Dataloader
    if config.phase == 'test':
        trainset=None
        train_loader = None
    else:
        trainset = Dataset(imgdirs_train, config.phase, 'train', config.rotFactor, config.res, config.gamma)
        train_loader = data.DataLoader(trainset,
                                       batch_size=config.batch_size,
                                       shuffle=True,
                                       num_workers=config.num_workers,
                                       pin_memory=True)
    testset = Dataset(imgdirs_test_commomset, 'test', config.attr, config.rotFactor, config.res, config.gamma)
    test_loader = data.DataLoader(testset,
                                  batch_size=config.batch_size,
                                  num_workers=config.num_workers,
                                  pin_memory=True)
    
    # Solver for training and testing.
    solver = Solver(train_loader, test_loader, config)
    if config.phase == 'train':
        solver.train()
    else:
        solver.load_state_dict(config.best_model)
        solver.test() 
Example #26
Source File: test.py    From cgd with Apache License 2.0
def test(opt):
    # Load dataset
    dataset = Dataset(opt.data_dir, opt.train_txt, opt.test_txt, opt.bbox_txt)
    dataset.print_stats()

    # Load image transform
    test_transform = transforms.Compose([
        transforms.Resize((opt.image_width, opt.image_height)),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])

    # Load data loader
    test_loader = mx.gluon.data.DataLoader(
        dataset=ImageData(dataset.test, test_transform),
        batch_size=opt.batch_size,
        num_workers=opt.num_workers
    )

    # Load model
    model = Model(opt)

    # Load evaluator
    evaluator = Evaluator(model, test_loader, opt.ctx)

    # Evaluate
    recalls = evaluator.evaluate(ranks=opt.recallk)
    for recallk, recall in zip(opt.recallk, recalls):
        print("R@{:4d}: {:.4f}".format(recallk, recall)) 
Example #27
Source File: check.py    From rl-attack-detection with MIT License
def main(args):
    with tf.Graph().as_default() as graph:
        # Create dataset
        logging.info('Create data flow from %s' % args.train)
        train_data = Dataset(directory=args.train, mean_path=args.mean, batch_size=args.batch_size, num_threads=2, capacity=10000)
    
        # Create initializer
        init = tf.group(tf.global_variables_initializer(), tf.local_variables_initializer())
         
        # Config session
        config = get_config(args)
        
        # Setup summary
        check_summary_writer = tf.summary.FileWriter(os.path.join(args.log, 'check'), graph)

        check_op = tf.cast(train_data()['x_t_1'] * 255.0 + train_data()['mean'], tf.uint8)
 
        tf.summary.image('x_t_1_batch_restore', check_op, collections=['check'])
        check_summary_op = tf.summary.merge_all('check')

        # Start session
        with tf.Session(config=config) as sess:
            coord = tf.train.Coordinator()
            sess.run(init)
            threads = tf.train.start_queue_runners(sess=sess, coord=coord)
            for i in range(10):
                x_t_1_batch, summary = sess.run([check_op, check_summary_op])
                check_summary_writer.add_summary(summary, i)
            coord.request_stop()
            coord.join(threads) 
Example #28
Source File: cmdline.py    From geoinference with BSD 3-Clause "New" or "Revised" License
def test_build_dataset(self):
		dirname = os.path.dirname(__file__)
		dataset_dir = os.path.join(dirname,'__tmp_dataset')
		posts_file = os.path.join(dirname,'posts.json.gz')

		app.build_dataset([dataset_dir,posts_file,'user_id','mentions'])

		# check if it's there
		ds = dataset.Dataset(dataset_dir)
		self.assertEquals(len(list(ds.post_iter())),15)

		os.system('rm -rf %s' % dataset_dir) 
Example #29
Source File: holoclean.py    From holoclean with Apache License 2.0
def __init__(self, env, name="session"):
        """
        Constructor for Holoclean session
        :param env: Holoclean environment
        :param name: Name for the Holoclean session
        """
        # use DEBUG logging level if verbose enabled
        if env['verbose']:
            root_logger.setLevel(logging.DEBUG)
            gensim_logger.setLevel(logging.DEBUG)

        logging.debug('initiating session with parameters: %s', env)

        # Initialize random seeds.
        random.seed(env['seed'])
        torch.manual_seed(env['seed'])
        np.random.seed(seed=env['seed'])

        # Initialize members
        self.name = name
        self.env = env
        self.ds = Dataset(name, env)
        self.dc_parser = Parser(env, self.ds)
        self.domain_engine = DomainEngine(env, self.ds)
        self.detect_engine = DetectEngine(env, self.ds)
        self.repair_engine = RepairEngine(env, self.ds)
        self.eval_engine = EvalEngine(env, self.ds) 
Example #30
Source File: pred.py    From atec-nlp with MIT License
def main(input_file, output_file):
    print("\nPredicting...\n")
    graph = tf.Graph()
    with graph.as_default():  # with tf.Graph().as_default() as g:
        sess = tf.Session()
        with sess.as_default():
            # Load the saved meta graph and restore variables
            # saver = tf.train.Saver(tf.global_variables())
            meta_file = os.path.abspath(os.path.join(FLAGS.model_dir, 'checkpoints/model-3400.meta'))
            new_saver = tf.train.import_meta_graph(meta_file)
            new_saver.restore(sess, tf.train.latest_checkpoint(os.path.join(FLAGS.model_dir, 'checkpoints')))
            # graph = tf.get_default_graph()

            # Get the placeholders from the graph by name
            # input_x1 = graph.get_operation_by_name("input_x1").outputs[0]
            input_x1 = graph.get_tensor_by_name("input_x1:0")  # Tensor("input_x1:0", shape=(?, 15), dtype=int32)
            input_x2 = graph.get_tensor_by_name("input_x2:0")
            dropout_keep_prob = graph.get_tensor_by_name("dropout_keep_prob:0")
            # Tensors we want to evaluate
            y_pred = graph.get_tensor_by_name("metrics/y_pred:0")
            # vars = tf.get_collection('vars')
            # for var in vars:
            #     print(var)

            e = graph.get_tensor_by_name("cosine:0")

            # Generate batches for one epoch
            dataset = Dataset(data_file=input_file, is_training=False)
            data = dataset.process_data(data_file=input_file, sequence_length=FLAGS.max_document_length)
            batches = dataset.batch_iter(data, FLAGS.batch_size, 1, shuffle=False)
            with open(output_file, 'w') as fo:
                lineno = 1
                for batch in batches:
                    x1_batch, x2_batch, _, _ = zip(*batch)
                    y_pred_ = sess.run([y_pred], {input_x1: x1_batch, input_x2: x2_batch, dropout_keep_prob: 1.0})
                    for pred in y_pred_[0]:
                        fo.write('{}\t{}\n'.format(lineno, pred))
                        lineno += 1