Python utils.DataLoader() Examples
The following are 3 code examples of utils.DataLoader().
You can vote up the ones you like or vote down the ones you don't like,
and go to the original project or source file by following the links above each example.
You may also want to check out all available functions/classes of the module utils, or try the search function.
Example #1
Source File: model.py From Pix2Pix-Keras with MIT License | 6 votes |
def __init__(self):
    """Build the Pix2Pix discriminator, generator and combined model.

    Image geometry is fixed at 256x256 with 3 channels. The discriminator
    is compiled on its own (MSE loss, accuracy metric). It is then frozen
    and wired into a combined model whose outputs are the discriminator's
    verdict on the generated image and the generated image itself. That
    combined model is trained with losses ['mse', 'mae'] weighted [1, 100].
    """
    # Input image geometry: height, width, channels.
    self.nH = 256
    self.nW = 256
    self.nC = 3
    self.data_loader = DataLoader()

    self.image_shape = (self.nH, self.nW, self.nC)
    self.image_A = Input(shape=self.image_shape)
    self.image_B = Input(shape=self.image_shape)

    # Discriminator is compiled standalone so it can be trained directly.
    self.discriminator = self.creat_discriminator()
    self.discriminator.compile(
        loss='mse',
        optimizer=Adam(0.0002, 0.5),
        metrics=['accuracy'],
    )

    # The generator is conditioned on image_B and produces a fake image_A.
    self.generator = self.creat_generator()
    self.fake_A = self.generator(self.image_B)

    # Freeze D only AFTER its own compile() above. Keras reads `trainable`
    # at compile time, so the combined model below treats D as fixed while
    # the standalone D keeps its trainable weights.
    self.discriminator.trainable = False
    self.valid = self.discriminator([self.fake_A, self.image_B])

    self.combined = Model(
        inputs=[self.image_A, self.image_B],
        outputs=[self.valid, self.fake_A],
    )
    self.combined.compile(
        loss=['mse', 'mae'],
        loss_weights=[1, 100],
        optimizer=Adam(0.0002, 0.5),
    )

    # Output shape of D (PatchGAN): each spatial axis divided by 2**4.
    # NOTE(review): presumably D downsamples four times -- confirm against
    # creat_discriminator.
    patch_divisor = 2 ** 4
    self.disc_patch = (
        int(self.nH / patch_divisor),
        int(self.nW / patch_divisor),
        1,
    )
Example #2
Source File: train.py From Anime-Super-Resolution with MIT License | 6 votes |
def __init__(self, scale=4, num_res_blocks=32, pretrained_weights=None, name=None):
    """Create a WDSR-B super-resolution trainer.

    Args:
        scale: Upscaling factor, forwarded to both ``wdsr_b`` and the
            ``DataLoader``.
        num_res_blocks: Number of residual blocks passed to ``wdsr_b``.
        pretrained_weights: Optional argument for ``model.load_weights``
            (presumably a weights file path -- confirm with callers).
            ``None`` skips loading.
        name: Optional label stored on the instance as-is.
    """
    self.scale = scale
    self.num_res_blocks = num_res_blocks

    self.model = wdsr_b(scale=scale, num_res_blocks=num_res_blocks)
    self.model.compile(
        optimizer=AdamWithWeightsNormalization(lr=0.001),
        loss=self.mae,
        metrics=[self.psnr],
    )
    # Compare to None by identity (PEP 8): `!= None` dispatches to __ne__,
    # which arbitrary objects may override.
    if pretrained_weights is not None:
        self.model.load_weights(pretrained_weights)
        print("[OK] weights loaded.")

    self.data_loader = DataLoader(scale=scale, crop_size=256)
    self.pretrained_weights = pretrained_weights
    # With the defaults this is 'weights/wdsr-b-32-x4.h5'.
    self.default_weights_save_path = (
        'weights/wdsr-b-' + str(self.num_res_blocks)
        + '-x' + str(self.scale) + '.h5'
    )
    self.name = name
Example #3
Source File: train.py From Social_lstm_pedestrian_prediction with GNU General Public License v3.0 | 4 votes |
def train(args):
    """Train the vanilla LSTM on every dataset except ``args.leaveDataset``.

    Reads from ``args``: leaveDataset, batch_size, seq_length, num_epochs,
    learning_rate, decay_rate, save_every (the whole namespace is also
    passed to ``Model``).

    Side effects:
        - writes ``args`` to ``save_lstm/config.pkl``;
        - every ``args.save_every`` global steps (excluding step 0), saves
          a TensorFlow checkpoint under ``save_lstm/model.ckpt``;
        - prints per-batch loss and timing.
    """
    save_dir = 'save_lstm'

    # range() is an immutable sequence on Python 3 and has no .remove(),
    # so materialise it as a list before dropping the held-out dataset.
    datasets = list(range(4))
    datasets.remove(args.leaveDataset)

    # The data loader preprocesses the data into batches of
    # args.batch_size sequences, each of length args.seq_length.
    data_loader = DataLoader(args.batch_size, args.seq_length, datasets, forcePreProcess=True)

    # open() below fails if the output directory does not exist yet.
    if not os.path.isdir(save_dir):
        os.makedirs(save_dir)

    # Save the arguments in the config file.
    with open(os.path.join(save_dir, 'config.pkl'), 'wb') as f:
        pickle.dump(args, f)

    model = Model(args)

    with tf.Session() as sess:
        # NOTE(review): tf.initialize_all_variables / tf.all_variables are
        # deprecated in later TF 1.x releases in favour of
        # tf.global_variables_initializer / tf.global_variables; kept here
        # for compatibility with whatever TF version this project targets.
        sess.run(tf.initialize_all_variables())
        saver = tf.train.Saver(tf.all_variables())

        for e in range(args.num_epochs):
            # Exponential per-epoch learning-rate decay.
            sess.run(tf.assign(model.lr, args.learning_rate * (args.decay_rate ** e)))
            data_loader.reset_batch_pointer()
            # The LSTM state is carried across the batches within an epoch.
            state = sess.run(model.initial_state)

            for b in range(data_loader.num_batches):
                start = time.time()
                # x is the source data, y the target data.
                x, y = data_loader.next_batch()
                feed = {model.input_data: x, model.target_data: y, model.initial_state: state}
                train_loss, state, _ = sess.run(
                    [model.cost, model.final_state, model.train_op], feed)
                end = time.time()

                # Global step across all epochs.
                step = e * data_loader.num_batches + b
                print(
                    "{}/{} (epoch {}), train_loss = {:.3f}, time/batch = {:.3f}"
                    .format(
                        step,
                        args.num_epochs * data_loader.num_batches,
                        e, train_loss, end - start))

                # Checkpoint on the configured frequency, skipping step 0.
                if step % args.save_every == 0 and step > 0:
                    checkpoint_path = os.path.join(save_dir, 'model.ckpt')
                    saver.save(sess, checkpoint_path, global_step=step)
                    print("model saved to {}".format(checkpoint_path))