Python torchfile.load() Examples

The following are 30 code examples related to torchfile.load(); several of them use related loaders such as torch.load(), np.load() or json.load() instead. You can vote up the examples you like or vote down the ones you don't, and follow the links above each example to the original project or source file. You may also want to check out all available functions/classes of the module torchfile, or try the search function.
Example #1
Source File: data.py    From tagan with Apache License 2.0 6 votes vote down vote up
def _load_dataset(self, img_root, caption_root, classes_filename, word_embedding):
    """Build the caption dataset from per-class torchfile caption files.

    Reads one class directory name per line from
    ``caption_root/classes_filename``, loads every file in
    ``caption_root/<class>`` with ``torchfile.load`` and converts its
    ``char`` field with ``self._get_word_vectors``.

    Args:
        img_root: directory prepended to each datum's ``img`` path.
        caption_root: directory holding the class list and the class
            subdirectories of caption files.
        classes_filename: name of the class-list file inside ``caption_root``.
        word_embedding: passed through to ``self._get_word_vectors``.

    Returns:
        list of dicts with keys ``'img'`` (``img_root`` joined with the
        datum's image path, as ``str``), ``'desc'`` and ``'len_desc'`` (the
        two values returned by ``_get_word_vectors``).
    """
    with open(os.path.join(caption_root, classes_filename)) as f:
        # Read the class list up front so the file is closed before the
        # per-file loading loop.
        classes = [line.rstrip('\n') for line in f]

    output = []
    for cls in classes:
        if not cls:
            # A blank line would otherwise make os.listdir() scan caption_root
            # itself and try to torchfile.load() the class directories.
            continue
        class_dir = os.path.join(caption_root, cls)
        for filename in os.listdir(class_dir):
            datum = torchfile.load(os.path.join(class_dir, filename))
            raw_desc = datum.char
            desc, len_desc = self._get_word_vectors(raw_desc, word_embedding, self.max_word_length)
            output.append({
                # torchfile returns strings as bytes unless utf8-decoding is
                # requested; os.fsdecode accepts str or bytes, so the join
                # never mixes the two types (which raises TypeError).
                'img': os.path.join(os.fsdecode(img_root), os.fsdecode(datum.img)),
                'desc': desc,
                'len_desc': len_desc,
            })
    return output
Example #2
Source File: trainer.py    From multiple-objects-gan with MIT License 6 votes vote down vote up
def load_network_stageI(self):
    """Create the Stage-I generator/discriminator and optionally restore weights.

    Both networks are initialised with ``weights_init`` and printed before any
    checkpoint is loaded. The generator checkpoint at ``cfg.NET_G`` is
    expected to hold its weights under the ``"netG"`` key; the discriminator
    checkpoint at ``cfg.NET_D`` is used directly as a state dict. Empty paths
    skip loading. Both networks are moved to the GPU when ``cfg.CUDA`` is set.

    Returns:
        tuple ``(netG, netD)``.
    """
    from model import STAGE1_G, STAGE1_D

    def to_cpu(storage, loc):
        # map_location hook: return each storage unchanged, i.e. load onto CPU.
        return storage

    generator = STAGE1_G()
    generator.apply(weights_init)
    print(generator)
    discriminator = STAGE1_D()
    discriminator.apply(weights_init)
    print(discriminator)

    if cfg.NET_G != '':
        checkpoint = torch.load(cfg.NET_G, map_location=to_cpu)
        generator.load_state_dict(checkpoint["netG"])
        print('Load from: ', cfg.NET_G)
    if cfg.NET_D != '':
        checkpoint = torch.load(cfg.NET_D, map_location=to_cpu)
        discriminator.load_state_dict(checkpoint)
        print('Load from: ', cfg.NET_D)

    if cfg.CUDA:
        generator.cuda()
        discriminator.cuda()
    return generator, discriminator

    # ############# For training stageII GAN  ############# 
Example #3
Source File: mk_dataset.py    From im2recipe-Pytorch with MIT License 6 votes vote down vote up
def get_st(file):
    """Load the ``st`` vectors, recipe lengths and ids from a torchfile file.

    Args:
        file: path passed to ``torchfile.load``.

    Returns:
        dict with keys ``'encs'``, ``'rlens'``, ``'rbps'`` (taken unchanged
        from the loaded table) and ``'ids'`` (each id rebuilt as a ``str``
        from its sequence of character codes).

    Raises:
        KeyError: if the loaded table lacks one of the expected fields.
    """
    info = torchfile.load(file)

    def field(name):
        # torchfile.load() keeps table keys as bytes unless strings are
        # utf8-decoded, so look up the bytes key first and fall back to str.
        key = name.encode('utf-8')
        return info[key] if key in info else info[name]

    # Each id is stored as a sequence of character codes; rebuild the string.
    imids = [''.join(chr(code) for code in char_codes) for char_codes in field('ids')]

    st_vecs = {
        'encs': field('encs'),
        'rlens': field('rlens'),
        'rbps': field('rbps'),
        'ids': imids,
    }

    print(np.shape(st_vecs['encs']), len(st_vecs['rlens']), len(st_vecs['rbps']), len(st_vecs['ids']))
    return st_vecs

# ============================================================================= 
Example #4
Source File: vgg_face.py    From vgg-face.pytorch with MIT License 6 votes vote down vote up
def load_weights(self, path="pretrained/VGG_FACE.t7"):
    """Copy weights from a Lua-Torch ``.t7`` model into this module's layers.

    Every module of the loaded model whose ``weight`` is not None is mapped,
    in order, onto ``conv_<block>_<n>`` for blocks 1-5 (block ``b`` holds
    ``self.block_size[b - 1]`` convolutions) and then onto ``fc6``, ``fc7``, ...
    Weights and biases are reshaped with ``view_as`` to the target parameter.

    Args:
        path: path of the Lua-Torch file passed to ``torchfile.load``.
    """
    def copy_params(dst, src):
        dst.weight.data[...] = torch.tensor(src.weight).view_as(dst.weight)[...]
        dst.bias.data[...] = torch.tensor(src.bias).view_as(dst.bias)[...]

    model = torchfile.load(path)
    counter = 1
    block = 1
    for layer in model.modules:
        if layer.weight is None:
            # parameter-free layers (no weight) are skipped
            continue
        if block > 5:
            target = getattr(self, "fc%d" % (block))
            block += 1
        else:
            target = getattr(self, "conv_%d_%d" % (block, counter))
            counter += 1
            if counter > self.block_size[block - 1]:
                # last convolution of this block: move on to the next block
                counter = 1
                block += 1
        copy_params(target, layer)
Example #5
Source File: mk_dataset.py    From im2recipe with MIT License 6 votes vote down vote up
def get_st(file):
    """Load the ``st`` vectors, recipe lengths and ids from a torchfile file.

    Args:
        file: path passed to ``torchfile.load``.

    Returns:
        dict with keys ``'encs'``, ``'rlens'``, ``'rbps'`` (taken unchanged
        from the loaded table) and ``'ids'`` (each id rebuilt as a ``str``
        from its sequence of character codes).

    Raises:
        KeyError: if the loaded table lacks one of the expected fields.
    """
    info = torchfile.load(file)

    def field(name):
        # torchfile.load() keeps table keys as bytes unless strings are
        # utf8-decoded (Python 2 made no distinction), so try the bytes key
        # first and fall back to the str key.
        key = name.encode('utf-8')
        return info[key] if key in info else info[name]

    # Each id is stored as a sequence of character codes; rebuild the string.
    imids = [''.join(chr(code) for code in char_codes) for char_codes in field('ids')]

    st_vecs = {
        'encs': field('encs'),
        'rlens': field('rlens'),
        'rbps': field('rbps'),
        'ids': imids,
    }

    print(np.shape(st_vecs['encs']), len(st_vecs['rlens']), len(st_vecs['rbps']), len(st_vecs['ids']))
    return st_vecs
Example #6
Source File: train.py    From TFSegmentation with Apache License 2.0 6 votes vote down vote up
def load_overfit_data(self):
    """Load the training arrays and reuse them as the validation set.

    ``X_train.npy`` / ``Y_train.npy`` are read from ``self.args.data_dir``
    (the file name is concatenated directly onto it), and ``self.val_data``
    is set to the very same dict as ``self.train_data``. Each ``*_len``
    attribute is the sample count rounded down to a multiple of
    ``self.args.batch_size``; each per-epoch iteration count is that length
    divided by the batch size, rounded up.
    """
    print("Loading data..")
    data_dir = self.args.data_dir
    self.train_data = {'X': np.load(data_dir + "X_train.npy"),
                       'Y': np.load(data_dir + "Y_train.npy")}
    batch = self.args.batch_size

    def whole_batches(count):
        # drop the trailing partial batch
        return count - count % batch

    def ceil_batches(count):
        return (count + batch - 1) // batch

    train = self.train_data
    self.train_data_len = whole_batches(train['X'].shape[0])
    self.num_iterations_training_per_epoch = ceil_batches(self.train_data_len)
    print("Train-shape-x -- {}".format(train['X'].shape))
    print("Train-shape-y -- {}".format(train['Y'].shape))
    print("Num of iterations in one epoch -- {}".format(self.num_iterations_training_per_epoch))
    print("Overfitting data is loaded")

    print("Loading Validation data..")
    self.val_data = train
    self.val_data_len = whole_batches(train['X'].shape[0])
    self.num_iterations_validation_per_epoch = ceil_batches(self.val_data_len)
    print("Val-shape-x -- {} {}".format(train['X'].shape, self.val_data_len))
    print("Val-shape-y -- {}".format(train['Y'].shape))
    print("Num of iterations on validation data in one epoch -- {}".format(self.num_iterations_validation_per_epoch))
    print("Validation data is loaded")
Example #7
Source File: train.py    From TFSegmentation with Apache License 2.0 6 votes vote down vote up
def load_train_data_h5(self):
    """Open the HDF5 training set and load the validation arrays into memory.

    Training data stays in the HDF5 file ``args.data_dir + args.h5_train_file``
    (kept open on ``self.train_data``); its length comes from
    ``args.h5_train_len`` rather than from the file. Validation data is read
    from ``X_val.npy`` / ``Y_val.npy``; only the recorded length
    ``val_data_len`` is rounded down to a multiple of ``args.batch_size``,
    the arrays themselves are not truncated.
    """
    args = self.args

    print("Loading Training data..")
    self.train_data = h5py.File(args.data_dir + args.h5_train_file, 'r')
    self.train_data_len = args.h5_train_len
    batch = args.batch_size
    self.num_iterations_training_per_epoch = (self.train_data_len + batch - 1) // batch
    print("Train-shape-x -- {} {}".format(self.train_data['X'].shape, self.train_data_len))
    print("Train-shape-y -- {}".format(self.train_data['Y'].shape))
    print("Num of iterations on training data in one epoch -- {}".format(self.num_iterations_training_per_epoch))
    print("Training data is loaded")

    print("Loading Validation data..")
    self.val_data = {'X': np.load(args.data_dir + "X_val.npy"),
                     'Y': np.load(args.data_dir + "Y_val.npy")}
    n_val = self.val_data['X'].shape[0]
    batch = args.batch_size
    self.val_data_len = n_val - n_val % batch
    self.num_iterations_validation_per_epoch = (self.val_data_len + batch - 1) // batch
    print("Val-shape-x -- {} {}".format(self.val_data['X'].shape, self.val_data_len))
    print("Val-shape-y -- {}".format(self.val_data['Y'].shape))
    print("Num of iterations on validation data in one epoch -- {}".format(self.num_iterations_validation_per_epoch))
    print("Validation data is loaded")
Example #8
Source File: tests.py    From python-torchfile with BSD 3-Clause "New" or "Revised" License 5 votes vote down vote up
def test_classnames_never_decoded(self):
    """Class names must stay undecoded regardless of ``utf8_decode_strings``."""
    for decode in (True, False):
        obj = load('custom_class.t7', utf8_decode_strings=decode)
        self.assertNotIsInstance(obj.torch_typename(), unicode_type)
Example #9
Source File: trainer.py    From StackGAN-Pytorch with MIT License 5 votes vote down vote up
def load_network_stageII(self):
    """Build the Stage-II generator/discriminator pair, restoring saved weights.

    The generator wraps a fresh Stage-I generator. Its weights are loaded
    from ``cfg.NET_G`` (whole Stage-II generator) or, failing that, from
    ``cfg.STAGE1_G`` (embedded Stage-I generator only). If neither path is
    set, a message is printed and ``None`` is returned; a caller that
    unpacks the result will then fail. Discriminator weights are loaded from
    ``cfg.NET_D`` when it is non-empty.

    Returns:
        ``(netG, netD)``, moved to the GPU when ``cfg.CUDA`` is set, or None.
    """
    from model import STAGE1_G, STAGE2_G, STAGE2_D

    def on_cpu(storage, loc):
        # map_location hook: return each storage unchanged, i.e. load onto CPU.
        return storage

    generator = STAGE2_G(STAGE1_G())
    generator.apply(weights_init)
    print(generator)

    if cfg.NET_G != '':
        generator.load_state_dict(torch.load(cfg.NET_G, map_location=on_cpu))
        print('Load from: ', cfg.NET_G)
    elif cfg.STAGE1_G != '':
        generator.STAGE1_G.load_state_dict(torch.load(cfg.STAGE1_G, map_location=on_cpu))
        print('Load from: ', cfg.STAGE1_G)
    else:
        print("Please give the Stage1_G path")
        return

    discriminator = STAGE2_D()
    discriminator.apply(weights_init)
    if cfg.NET_D != '':
        discriminator.load_state_dict(torch.load(cfg.NET_D, map_location=on_cpu))
        print('Load from: ', cfg.NET_D)
    print(discriminator)

    if cfg.CUDA:
        generator.cuda()
        discriminator.cuda()
    return generator, discriminator
Example #10
Source File: tests.py    From python-torchfile with BSD 3-Clause "New" or "Revised" License 5 votes vote down vote up
def test_basic_tensors(self):
    """Double and float tensor fixtures deserialize to the expected values."""
    expected_f64 = np.array([[1, 2, 3, ], [4, 5, 6.9]], dtype=np.float64)
    f64 = load('doubletensor.t7')
    self.assertTrue((f64 == expected_f64).all())

    f32 = load('floattensor.t7')
    self.assertAlmostEqual(f32.sum(), 12.97241666913, delta=1e-5)
Example #11
Source File: tests.py    From python-torchfile with BSD 3-Clause "New" or "Revised" License 5 votes vote down vote up
def test_dict_accessors(self):
    """Item access uses str keys when decoding and bytes keys otherwise; dot access works in both."""
    for decode, key in ((True, 'hello'), (False, b'hello')):
        obj = load('hello=123.t7', use_int_heuristic=True, utf8_decode_strings=decode)
        self.assertIsInstance(obj[key], int)
        self.assertIsInstance(obj.hello, int)
Example #12
Source File: tests.py    From python-torchfile with BSD 3-Clause "New" or "Revised" License 5 votes vote down vote up
def test_recursive_class(self):
    """The fixture's field ``a`` refers back to the loaded object itself."""
    obj = load('recursive_class.t7')
    back_ref = obj.a
    self.assertEqual(back_ref, obj)
Example #13
Source File: tests.py    From python-torchfile with BSD 3-Clause "New" or "Revised" License 5 votes vote down vote up
def test_recursive_table(self):
    """A single-entry table whose key and value are both the table itself."""
    obj = load('recursive_kv_table.t7')
    [key] = obj.keys()  # exactly one entry expected
    self.assertEqual(key, obj)
    self.assertEqual(obj[key], obj)
Example #14
Source File: tests.py    From python-torchfile with BSD 3-Clause "New" or "Revised" License 5 votes vote down vote up
def test_hash(self):
    """The ``tds_hash.t7`` fixture loads as a 3-entry mapping."""
    obj = load('tds_hash.t7')
    self.assertEqual(len(obj), 3)
    for key, value in ((1, 2), (10, 11)):
        self.assertEqual(obj[key], value)
Example #15
Source File: tests.py    From python-torchfile with BSD 3-Clause "New" or "Revised" License 5 votes vote down vote up
def test_vec(self):
    """``tds_vec.t7`` loads as a plain list even with the list heuristic off."""
    expected = [123, 456]
    # Should not be affected by list heuristic at all
    self.assertEqual(load('tds_vec.t7', use_list_heuristic=False), expected)
Example #16
Source File: tests.py    From python-torchfile with BSD 3-Clause "New" or "Revised" License 5 votes vote down vote up
def test_int_heuristic(self):
    """``use_int_heuristic`` controls whether integral numbers come back as ``int``."""
    with_heuristic = load('hello=123.t7', use_int_heuristic=True)
    self.assertIsInstance(with_heuristic[b'hello'], int)

    without_heuristic = load('hello=123.t7', use_int_heuristic=False)
    self.assertNotIsInstance(without_heuristic[b'hello'], int)

    table = load('list_table.t7', use_list_heuristic=False, use_int_heuristic=False)
    expected = {1: b'hello', 2: b'world', 3: b'third item', 4: 123}
    self.assertEqual(dict(table), expected)
    first_key = next(iter(table.keys()))
    self.assertNotIsInstance(first_key, int)
Example #17
Source File: vis.py    From im2recipe with MIT License 5 votes vote down vote up
def load_layer(json_file):
    """Parse and return the JSON document stored in ``json_file``.

    The file is opened in binary mode so ``json.load`` detects the encoding
    itself (UTF-8, UTF-16 or UTF-32, with or without a BOM) instead of
    decoding with the platform's default text encoding, which differs
    between systems and rejects a UTF-8 BOM.

    Args:
        json_file: path of the JSON file.

    Returns:
        the decoded JSON value.
    """
    with open(json_file, 'rb') as f_layer:
        return json.load(f_layer)
Example #18
Source File: convert_t7.py    From crnn.pytorch with MIT License 5 votes vote down vote up
def torch_to_pytorch(model, t7_file, output):
    """Copy parameters from a Lua-Torch ``.t7`` model into ``model`` and save it.

    Both networks are flattened into leaf-layer lists (via ``py_layer_serial``
    and ``torch_layer_serial``) and matched in order. Each PyTorch layer
    consumes one t7 entry, except an LSTM, which consumes one entry per
    direction per stacked layer. The final ``model.state_dict()`` is written
    to ``output`` with ``torch.save``.

    Raises:
        RuntimeError: if a t7 layer type does not map (via ``layer_map``) to
            the PyTorch layer type at the same position.
    """
    py_layers = []
    for child in list(model.children()):
        py_layer_serial(child, py_layers)

    t7_layers = []
    for node in torchfile.load(t7_file):
        torch_layer_serial(node, t7_layers)

    cursor = 0
    for py_layer in py_layers:
        py_name = type(py_layer).__name__
        t7_name = t7_layers[cursor][0].split('.')[-1]
        if layer_map[t7_name] != py_name:
            raise RuntimeError('%s does not match %s' % (py_name, t7_name))

        if py_name == 'LSTM':
            span = (2 if py_layer.bidirectional else 1) * py_layer.num_layers
            source = t7_layers[cursor:cursor + span]
        else:
            span = 1
            source = t7_layers[cursor]
        cursor += span

        load_params(py_layer, source)

    torch.save(model.state_dict(), output)
Example #19
Source File: convert_t7.py    From crnn with MIT License 5 votes vote down vote up
def torch_to_pytorch(model, t7_file, output):
    """Copy parameters from a Lua-Torch ``.t7`` model into ``model`` and save it.

    Both networks are flattened into leaf-layer lists (via ``py_layer_serial``
    and ``torch_layer_serial``) and matched in order. Each PyTorch layer
    consumes one t7 entry, except an LSTM, which consumes one entry per
    direction per stacked layer. The final ``model.state_dict()`` is written
    to ``output`` with ``torch.save``.

    Raises:
        RuntimeError: if a t7 layer type does not map (via ``layer_map``) to
            the PyTorch layer type at the same position.
    """
    py_layers = []
    for child in list(model.children()):
        py_layer_serial(child, py_layers)

    t7_layers = []
    for node in torchfile.load(t7_file):
        torch_layer_serial(node, t7_layers)

    cursor = 0
    for py_layer in py_layers:
        py_name = type(py_layer).__name__
        t7_name = t7_layers[cursor][0].split('.')[-1]
        if layer_map[t7_name] != py_name:
            raise RuntimeError('%s does not match %s' % (py_name, t7_name))

        if py_name == 'LSTM':
            span = (2 if py_layer.bidirectional else 1) * py_layer.num_layers
            source = t7_layers[cursor:cursor + span]
        else:
            span = 1
            source = t7_layers[cursor]
        cursor += span

        load_params(py_layer, source)

    torch.save(model.state_dict(), output)
Example #20
Source File: convert_t7.py    From basicOCR with GNU General Public License v3.0 5 votes vote down vote up
def torch_to_pytorch(model, t7_file, output):
    """Copy parameters from a Lua-Torch ``.t7`` model into ``model`` and save it.

    Both networks are flattened into leaf-layer lists (via ``py_layer_serial``
    and ``torch_layer_serial``) and matched in order. Each PyTorch layer
    consumes one t7 entry, except an LSTM, which consumes one entry per
    direction per stacked layer. The final ``model.state_dict()`` is written
    to ``output`` with ``torch.save``.

    Raises:
        RuntimeError: if a t7 layer type does not map (via ``layer_map``) to
            the PyTorch layer type at the same position.
    """
    py_layers = []
    for child in list(model.children()):
        py_layer_serial(child, py_layers)

    t7_layers = []
    for node in torchfile.load(t7_file):
        torch_layer_serial(node, t7_layers)

    cursor = 0
    for py_layer in py_layers:
        py_name = type(py_layer).__name__
        t7_name = t7_layers[cursor][0].split('.')[-1]
        if layer_map[t7_name] != py_name:
            raise RuntimeError('%s does not match %s' % (py_name, t7_name))

        if py_name == 'LSTM':
            span = (2 if py_layer.bidirectional else 1) * py_layer.num_layers
            source = t7_layers[cursor:cursor + span]
        else:
            span = 1
            source = t7_layers[cursor]
        cursor += span

        load_params(py_layer, source)

    torch.save(model.state_dict(), output)
Example #21
Source File: data.py    From tagan with Apache License 2.0 5 votes vote down vote up
def convert_and_save(self, caption_root, word_embedding, max_word_length):
    """Convert every torchfile caption file to a word-vector ``.pth`` file.

    For each class listed in ``caption_root/allclasses.txt``, every file in
    ``caption_root/<class>`` is loaded with ``torchfile.load``. Its ``char``
    field is converted with ``self._get_word_vectors``, and
    ``{'img', 'word_vec', 'len_desc'}`` is saved with ``torch.save`` to
    ``<caption_root>_vec/<class>/<name>pth``. Here ``<name>`` is the file name
    minus its last two characters, so ``x.t7`` becomes ``x.pth``.

    Args:
        caption_root: directory holding ``allclasses.txt`` and the class
            subdirectories.
        word_embedding: passed through to ``self._get_word_vectors``.
        max_word_length: passed through to ``self._get_word_vectors``.
    """
    vec_root = caption_root + '_vec'
    with open(os.path.join(caption_root, 'allclasses.txt'), 'r') as f:
        # rstrip rather than slicing off the last character: the final line
        # may not end with a newline, and slicing would truncate its name.
        classes = [line.rstrip('\n') for line in f]

    for cls in classes:
        if not cls:
            # A blank line would otherwise convert caption_root itself.
            continue
        out_dir = os.path.join(vec_root, cls)
        # exist_ok lets an interrupted conversion be re-run.
        os.makedirs(out_dir, exist_ok=True)
        src_dir = os.path.join(caption_root, cls)
        for filename in os.listdir(src_dir):
            datum = torchfile.load(os.path.join(src_dir, filename))
            raw_desc = datum.char
            desc, len_desc = self._get_word_vectors(raw_desc, word_embedding, max_word_length)
            torch.save({'img': datum.img, 'word_vec': desc, 'len_desc': len_desc},
                       os.path.join(out_dir, filename[:-2] + 'pth'))
Example #22
Source File: data.py    From tagan with Apache License 2.0 5 votes vote down vote up
def _load_dataset(self, img_root, caption_root, classes_filename):
    """Load the pre-converted (``<caption_root>_vec``) caption dataset.

    Reads one class name per line from ``caption_root/classes_filename``
    (blank lines are skipped) and ``torch.load``s every file in
    ``<caption_root>_vec/<class>``.

    Args:
        img_root: image directory (str or bytes) joined with each stored
            ``img`` path.
        caption_root: caption directory; converted files live in
            ``caption_root + '_vec'``.
        classes_filename: name of the class-list file inside ``caption_root``.

    Returns:
        list of dicts with keys ``'img'`` (a bytes path), ``'word_vec'`` and
        ``'len_desc'`` (taken unchanged from each saved file).
    """
    vec_root = caption_root + '_vec'
    with open(os.path.join(caption_root, classes_filename)) as f:
        classes = [line.rstrip('\n') for line in f]

    output = []
    for cls in classes:
        if not cls:
            # A blank line would otherwise list vec_root itself and try to
            # torch.load() the class directories.
            continue
        class_dir = os.path.join(vec_root, cls)
        for filename in os.listdir(class_dir):
            # NOTE: torch.load unpickles the file; only use on trusted data.
            datum = torch.load(os.path.join(class_dir, filename))
            root = img_root if isinstance(img_root, bytes) else bytes(img_root, 'utf-8')
            img = datum['img']
            if isinstance(img, str):
                # Paths saved from torchfile are bytes; accept str as well so
                # os.path.join never mixes str and bytes.
                img = img.encode('utf-8')
            output.append({
                'img': os.path.join(root, img),
                'word_vec': datum['word_vec'],
                'len_desc': datum['len_desc'],
            })
    return output
Example #23
Source File: tests.py    From python-torchfile with BSD 3-Clause "New" or "Revised" License 5 votes vote down vote up
def test_dict(self):
    """The ``hello=123.t7`` fixture loads as a one-entry dict with a bytes key."""
    expected = {b'hello': 123}
    self.assertEqual(dict(load('hello=123.t7')), expected)
Example #24
Source File: trainer.py    From StackGAN-Pytorch with MIT License 5 votes vote down vote up
def load_network_stageI(self):
    """Create the Stage-I generator and discriminator, optionally loading weights.

    Both networks are initialised with ``weights_init`` and printed. When
    ``cfg.NET_G`` / ``cfg.NET_D`` is non-empty, the file at that path is
    loaded as the network's state dict. Both networks are moved to the GPU
    when ``cfg.CUDA`` is set.

    Returns:
        tuple ``(netG, netD)``.
    """
    from model import STAGE1_G, STAGE1_D

    nets = []
    for net_cls in (STAGE1_G, STAGE1_D):
        net = net_cls()
        net.apply(weights_init)
        print(net)
        nets.append(net)
    netG, netD = nets

    for path, net in ((cfg.NET_G, netG), (cfg.NET_D, netD)):
        if path != '':
            # map_location returning the storage unchanged loads onto CPU.
            net.load_state_dict(torch.load(path, map_location=lambda storage, loc: storage))
            print('Load from: ', path)

    if cfg.CUDA:
        for net in nets:
            net.cuda()
    return netG, netD

    # ############# For training stageII GAN  ############# 
Example #25
Source File: torch.py    From tensorflow-litterbox with Apache License 2.0 5 votes vote down vote up
def main():
    """Command-line entry point: load a Torch7 file and process it if it has modules.

    The single positional argument is the file path, loaded with
    ``torchfile.load(..., force_8bytes_long=True)``. ``process_obj`` is only
    called when the loaded object's ``modules`` is truthy.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('torch_file')
    torch_file = parser.parse_args().torch_file

    data = torchfile.load(torch_file, force_8bytes_long=True)
    if not data.modules:
        return
    process_obj(data)
Example #26
Source File: trainer.py    From multiple-objects-gan with MIT License 5 votes vote down vote up
def load_network_stageII(self):
    """Build the Stage-II generator/discriminator pair, restoring saved weights.

    Generator checkpoints are expected to hold their weights under the
    ``"netG"`` key. They are read from ``cfg.NET_G`` (whole Stage-II
    generator) or, failing that, ``cfg.STAGE1_G`` (embedded Stage-I generator
    only). If neither path is set, a message is printed and ``None`` is
    returned. Discriminator weights come from ``cfg.NET_D`` when it is
    non-empty.

    Returns:
        ``(netG, netD)``, moved to the GPU when ``cfg.CUDA`` is set, or None.
    """
    from model import STAGE1_G, STAGE2_G, STAGE2_D

    def to_cpu(storage, loc):
        # map_location hook: return each storage unchanged, i.e. load onto CPU.
        return storage

    def saved_generator(path):
        return torch.load(path, map_location=to_cpu)["netG"]

    netG = STAGE2_G(STAGE1_G())
    netG.apply(weights_init)
    print(netG)
    if cfg.NET_G != '':
        netG.load_state_dict(saved_generator(cfg.NET_G))
        print('Load from: ', cfg.NET_G)
    elif cfg.STAGE1_G != '':
        netG.STAGE1_G.load_state_dict(saved_generator(cfg.STAGE1_G))
        print('Load from: ', cfg.STAGE1_G)
    else:
        print("Please give the Stage1_G path")
        return

    netD = STAGE2_D()
    netD.apply(weights_init)
    if cfg.NET_D != '':
        netD.load_state_dict(torch.load(cfg.NET_D, map_location=to_cpu))
        print('Load from: ', cfg.NET_D)
    print(netD)

    if cfg.CUDA:
        netG.cuda()
        netD.cuda()
    return netG, netD
Example #27
Source File: convert_t7.py    From crnn-pytorch with MIT License 5 votes vote down vote up
def torch_to_pytorch(model, t7_file, output):
    """Copy parameters from a Lua-Torch ``.t7`` model into ``model`` and save it.

    Both networks are flattened into leaf-layer lists (via ``py_layer_serial``
    and ``torch_layer_serial``) and matched in order. Each PyTorch layer
    consumes one t7 entry, except an LSTM, which consumes one entry per
    direction per stacked layer. The final ``model.state_dict()`` is written
    to ``output`` with ``torch.save``.

    Raises:
        RuntimeError: if a t7 layer type does not map (via ``layer_map``) to
            the PyTorch layer type at the same position.
    """
    py_layers = []
    for child in list(model.children()):
        py_layer_serial(child, py_layers)

    t7_layers = []
    for node in torchfile.load(t7_file):
        torch_layer_serial(node, t7_layers)

    cursor = 0
    for py_layer in py_layers:
        py_name = type(py_layer).__name__
        t7_name = t7_layers[cursor][0].split('.')[-1]
        if layer_map[t7_name] != py_name:
            raise RuntimeError('%s does not match %s' % (py_name, t7_name))

        if py_name == 'LSTM':
            span = (2 if py_layer.bidirectional else 1) * py_layer.num_layers
            source = t7_layers[cursor:cursor + span]
        else:
            span = 1
            source = t7_layers[cursor]
        cursor += span

        load_params(py_layer, source)

    torch.save(model.state_dict(), output)
Example #28
Source File: load_t7.py    From SoundNet-tensorflow with MIT License 5 votes vote down vote up
def load(o, param_list):
    """ Get torch7 weights into numpy array

    Walks ``o['modules']`` depth-first.
    - Spatial convolutions (``nn.`` or ``cudnn.``), spatial full convolutions
      and volumetric full convolutions append
      ``{'weights': ..., 'biases': ...}`` to ``param_list``. Weights are
      permuted with ``transpose((2, 3, 1, 0))``, or ``(2, 3, 4, 1, 0)`` for
      the volumetric case; presumably this targets a TensorFlow kernel
      layout (TODO confirm).
    - Spatial/volumetric batch-norm layers add ``gamma``, ``beta``, ``mean``
      and ``var`` to the most recently appended entry.

    Args:
        o: a torchfile-loaded module; anything without a usable ``'modules'``
            entry is treated as a leaf.
        param_list: list extended in place.

    Raises:
        IndexError: if a batch-norm layer appears before any convolution.
    """
    try:
        children = o['modules']
        count = len(children)
    except Exception:
        # Leaf: no 'modules' entry, or one without a length (e.g. None).
        return

    conv_2d = ('nn.SpatialConvolution', 'cudnn.SpatialConvolution', 'nn.SpatialFullConvolution')
    batch_norms = ('nn.SpatialBatchNormalization', 'nn.VolumetricBatchNormalization')
    for i in range(count):
        mod = children[i]
        typename = mod._typename
        if isinstance(typename, bytes):
            # torchfile never utf8-decodes class names, so under Python 3 they
            # are bytes and would never equal the str literals below.
            typename = typename.decode('utf-8')
        # 2D conv / 2D deconv
        if typename in conv_2d:
            param_list.append({'weights': mod['weight'].transpose((2, 3, 1, 0)),
                               'biases': mod['bias']})
        # 3D conv
        elif typename == 'nn.VolumetricFullConvolution':
            param_list.append({'weights': mod['weight'].transpose((2, 3, 4, 1, 0)),
                               'biases': mod['bias']})
        # batch norm: attach to the preceding convolution's entry
        elif typename in batch_norms:
            last = param_list[-1]
            last['gamma'] = mod['weight']
            last['beta'] = mod['bias']
            last['mean'] = mod['running_mean']
            last['var'] = mod['running_var']

        load(mod, param_list)
Example #29
Source File: load_t7.py    From tf_videogan with MIT License 5 votes vote down vote up
def load(o, param_list):
    """Recursively gather full-convolution and batch-norm parameters from a Torch7 model.

    Walks ``o['modules']`` depth-first.
    - Each ``nn.SpatialFullConvolution`` / ``nn.VolumetricFullConvolution``
      appends ``{'weights': ..., 'biases': ...}`` to ``param_list``. Weights
      are permuted with ``transpose((2, 3, 1, 0))`` / ``transpose((2, 3, 4, 1, 0))``.
    - Each spatial/volumetric batch-norm layer adds ``gamma`` and ``beta`` to
      the most recently appended entry.

    Args:
        o: a torchfile-loaded module; anything without a usable ``'modules'``
            entry is treated as a leaf.
        param_list: list extended in place.

    Raises:
        IndexError: if a batch-norm layer appears before any convolution.
    """
    try:
        children = o['modules']
        count = len(children)
    except Exception:
        # Leaf: no 'modules' entry, or one without a length (e.g. None).
        return

    batch_norms = ('nn.SpatialBatchNormalization', 'nn.VolumetricBatchNormalization')
    for i in range(count):
        mod = children[i]
        typename = mod._typename
        if isinstance(typename, bytes):
            # torchfile never utf8-decodes class names, so under Python 3 they
            # are bytes and would never equal the str literals below.
            typename = typename.decode('utf-8')
        # 2D full convolution (deconvolution)
        if typename == 'nn.SpatialFullConvolution':
            param_list.append({'weights': mod['weight'].transpose((2, 3, 1, 0)),
                               'biases': mod['bias']})
        # 3D full convolution
        elif typename == 'nn.VolumetricFullConvolution':
            param_list.append({'weights': mod['weight'].transpose((2, 3, 4, 1, 0)),
                               'biases': mod['bias']})
        # batch norm: attach to the preceding convolution's entry
        elif typename in batch_norms:
            param_list[-1]['gamma'] = mod['weight']
            param_list[-1]['beta'] = mod['bias']

        load(mod, param_list)
Example #30
Source File: train.py    From TFSegmentation with Apache License 2.0 5 votes vote down vote up
def init_tfdata(self, batch_size, main_dir, resize_shape, mode='train'):
    """Create the training-data iterator and load validation arrays into memory.

    A new ``tf.Session`` is stored on ``self.data_session``. On the CPU
    device, a ``SegDataLoader`` is built, a re-initialisable ``Iterator`` is
    created from its ``data_tr`` dataset, and the initializer
    (``self.init_op``) is run. Validation arrays are then read from
    ``X_val.npy`` / ``Y_val.npy`` in ``self.args.data_dir``. Only whole
    batches of ``self.args.batch_size`` are counted for validation.

    Args:
        batch_size: batch size passed to ``SegDataLoader``.
        main_dir: dataset root passed to ``SegDataLoader``.
        resize_shape: its first two entries and the value itself are passed
            to ``SegDataLoader`` (exact meaning: TODO confirm).
        mode: not used by this method.

    Returns:
        ``(next_batch, segdl.data_len)``: the iterator's ``get_next()`` op and
        the loader's reported data length.
    """
    self.data_session = tf.Session()
    print("Creating the iterator for training data")
    with tf.device('/cpu:0'):
        loader = SegDataLoader(main_dir, batch_size, (resize_shape[0], resize_shape[1]), resize_shape,
                               'data/cityscapes_tfdata/train.txt')
        iterator = Iterator.from_structure(loader.data_tr.output_types, loader.data_tr.output_shapes)
        next_batch = iterator.get_next()

        self.init_op = iterator.make_initializer(loader.data_tr)
        self.data_session.run(self.init_op)

    print("Loading Validation data in memoryfor faster training..")
    data_dir = self.args.data_dir
    self.val_data = {'X': np.load(data_dir + "X_val.npy"),
                     'Y': np.load(data_dir + "Y_val.npy")}

    n_val = self.val_data['X'].shape[0]
    batch = self.args.batch_size
    self.val_data_len = n_val - n_val % batch
    self.num_iterations_validation_per_epoch = self.val_data_len // batch

    print("Val-shape-x -- {} {}".format(self.val_data['X'].shape, self.val_data_len))
    print("Val-shape-y -- {}".format(self.val_data['Y'].shape))
    print("Num of iterations on validation data in one epoch -- {}".format(self.num_iterations_validation_per_epoch))
    print("Validation data is loaded")

    return next_batch, loader.data_len