Python options.train_options.TrainOptions() Examples
The following are 5 code examples of options.train_options.TrainOptions().
You can vote up the ones you like or vote down the ones you don't like,
and go to the original project or source file by following the links above each example.
You may also want to check out all available functions/classes of the module options.train_options, or try the search function.
Example #1
Source File: train.py From GANimation with GNU General Public License v3.0 | 6 votes |
def __init__(self):
    """Parse options, build the train/test datasets, create the model and start training.

    Everything is stored on the instance (``_opt``, ``_dataset_train``,
    ``_dataset_test``, the two dataset sizes, ``_model``, ``_tb_visualizer``)
    and ``self._train()`` is invoked as the last step of construction.
    """
    self._opt = TrainOptions().parse()

    # Both loaders are constructed before either one loads its data.
    train_loader = CustomDatasetDataLoader(self._opt, is_for_train=True)
    test_loader = CustomDatasetDataLoader(self._opt, is_for_train=False)

    self._dataset_train = train_loader.load_data()
    self._dataset_test = test_loader.load_data()

    # Sizes are taken from the loaders, not from the loaded datasets.
    self._dataset_train_size = len(train_loader)
    self._dataset_test_size = len(test_loader)
    print('#train images = %d' % self._dataset_train_size)
    print('#test images = %d' % self._dataset_test_size)

    model_name = self._opt.model
    self._model = ModelsFactory.get_by_name(model_name, self._opt)
    self._tb_visualizer = TBVisualizer(self._opt)

    self._train()
Example #2
Source File: train.py From DMIT with MIT License | 5 votes |
def main():
    """Train the model created by ``create_model`` over the configured epochs.

    For each batch the model is updated; results are displayed and errors are
    printed/plotted at the frequencies given in the options. After every epoch
    a checkpoint and the 'latest' generator are saved, and every
    ``opt.save_epoch_freq`` epochs the generator is also saved under the
    epoch number.
    """
    opt = TrainOptions().parse()
    loader = CreateDataLoader(opt)
    samples_per_epoch = len(loader) * opt.batch_size
    visualizer = Visualizer(opt)
    model = create_model(opt)

    first_epoch = model.start_epoch
    # Step counter is expressed in samples, not batches.
    seen = first_epoch * samples_per_epoch

    for epoch in range(first_epoch + 1, opt.niter + opt.niter_decay + 1):
        epoch_t0 = time.time()
        model.update_lr()
        # Forces a display (and HTML save) on the first batch of each epoch.
        show_now = True

        for batch in loader:
            batch_t0 = time.time()
            seen += opt.batch_size

            model.prepare_data(batch)
            model.update_model()

            if show_now or seen % opt.display_freq == 0:
                write_html = show_now or seen % opt.update_html_freq == 0
                visuals = model.get_current_visuals()
                visualizer.display_current_results(visuals, epoch, ncols=1, save_result=write_html)
                show_now = False

            if seen % opt.print_freq != 0:
                continue

            # Samples consumed so far within the current epoch.
            within_epoch = seen - samples_per_epoch * (epoch - 1)
            errors = model.get_current_errors()
            per_sample = (time.time() - batch_t0) / opt.batch_size
            visualizer.print_current_errors(epoch, within_epoch, errors, per_sample)
            if opt.display_id > 0:
                progress = float(within_epoch) / samples_per_epoch
                visualizer.plot_current_errors(epoch, progress, opt, errors)

        print('epoch {} cost dime {}'.format(epoch, time.time() - epoch_t0))
        model.save_ckpt(epoch)
        model.save_generator('latest')
        if epoch % opt.save_epoch_freq == 0:
            print('saving the generator at the end of epoch {}, iters {}'.format(epoch, seen))
            model.save_generator(epoch)
Example #3
Source File: collect-rotation-ditri.py From pose-adv-aug with Apache License 2.0 | 4 votes |
def main():
    """Collect rotation distribution files for the training and validation splits.

    Parses the training options, creates the output directory, builds the
    ``hg`` network via ``model.create_hg``, loads its checkpoint and calls
    ``collect_train_valid_data`` once with ``is_train=True`` and once with
    ``is_train=False``.

    Raises:
        ValueError: if ``opt.dataset`` is not ``'mpii'``; no class count is
            defined for any other dataset.
    """
    opt = TrainOptions().parse()
    if opt.sr_dir == '':
        print('sr directory is null.')
        exit()

    # The last character of load_prefix_pose is dropped ([0:-1]) when naming
    # the output directory; the same slice is applied to the load prefix below.
    sr_pretrain_dir = os.path.join(
        opt.exp_dir, opt.exp_id,
        opt.sr_dir + '-' + opt.load_prefix_pose[0:-1])
    if not os.path.isdir(sr_pretrain_dir):
        os.makedirs(sr_pretrain_dir)

    # NOTE(review): train_history is not read anywhere in this function;
    # kept in case constructing it has side effects - confirm and remove.
    train_history = ASNTrainHistory()
    checkpoint_hg = Checkpoint()

    train_distri_path = sr_pretrain_dir + '/' + 'train_rotations.txt'
    train_distri_path_2 = sr_pretrain_dir + '/' + 'train_rotations_copy.txt'
    val_distri_path = sr_pretrain_dir + '/' + 'val_rotations.txt'
    val_distri_path_2 = sr_pretrain_dir + '/' + 'val_rotations_copy.txt'

    if opt.dataset == 'mpii':
        num_classes = 16
    else:
        # Previously num_classes was left unbound here and the failure only
        # surfaced later as an UnboundLocalError at the create_hg call.
        raise ValueError('unsupported dataset: %r' % (opt.dataset,))

    os.environ['CUDA_VISIBLE_DEVICES'] = opt.gpu_id
    hg = model.create_hg(num_stacks=2, num_modules=1,
                         num_classes=num_classes, chan=256)
    hg = torch.nn.DataParallel(hg).cuda()

    if opt.load_prefix_pose == '':
        # NOTE(review): only warns and then continues to load_checkpoint;
        # the exit() that followed this message was commented out upstream.
        print('please input the checkpoint name of the pose model')
    checkpoint_hg.load_prefix = os.path.join(
        opt.exp_dir, opt.exp_id, opt.load_prefix_pose)[0:-1]
    checkpoint_hg.load_checkpoint(hg)

    # Single-argument print(...) calls behave the same on Python 2 and 3,
    # unlike the print statements they replace.
    print('collecting training distributions ...\n')
    train_distri_list = collect_train_valid_data(
        train_distri_path, train_distri_path_2, hg, opt, is_train=True)
    print('collecting validation distributions ...\n')
    val_distri_list = collect_train_valid_data(
        val_distri_path, val_distri_path_2, hg, opt, is_train=False)
Example #4
Source File: collect-scale-ditri.py From pose-adv-aug with Apache License 2.0 | 4 votes |
def main():
    """Collect scale distribution files for the training and validation splits.

    Parses the training options, creates the output directory, builds the
    ``hg`` network via ``model.create_hg``, loads its checkpoint and calls
    ``collect_train_valid_data`` once with ``is_train=True`` and once with
    ``is_train=False``.

    Raises:
        ValueError: if ``opt.dataset`` is not ``'mpii'``; no class count is
            defined for any other dataset.
    """
    opt = TrainOptions().parse()
    if opt.sr_dir == '':
        print('sr directory is null.')
        exit()

    # The last character of load_prefix_pose is dropped ([0:-1]) when naming
    # the output directory; the same slice is applied to the load prefix below.
    sr_pretrain_dir = os.path.join(
        opt.exp_dir, opt.exp_id,
        opt.sr_dir + '-' + opt.load_prefix_pose[0:-1])
    if not os.path.isdir(sr_pretrain_dir):
        os.makedirs(sr_pretrain_dir)

    checkpoint_hg = Checkpoint()

    train_distri_path = sr_pretrain_dir + '/' + 'train_scales.txt'
    train_distri_path_2 = sr_pretrain_dir + '/' + 'train_scales_copy.txt'
    val_distri_path = sr_pretrain_dir + '/' + 'val_scales.txt'
    val_distri_path_2 = sr_pretrain_dir + '/' + 'val_scales_copy.txt'

    if opt.dataset == 'mpii':
        num_classes = 16
    else:
        # Previously num_classes was left unbound here and the failure only
        # surfaced later as an UnboundLocalError at the create_hg call.
        raise ValueError('unsupported dataset: %r' % (opt.dataset,))

    os.environ['CUDA_VISIBLE_DEVICES'] = opt.gpu_id
    hg = model.create_hg(num_stacks=2, num_modules=1,
                         num_classes=num_classes, chan=256)
    hg = torch.nn.DataParallel(hg).cuda()

    if opt.load_prefix_pose == '':
        print('please input the checkpoint name of the pose model')
        exit()
    checkpoint_hg.load_prefix = os.path.join(
        opt.exp_dir, opt.exp_id, opt.load_prefix_pose)[0:-1]
    checkpoint_hg.load_checkpoint(hg)

    # Single-argument print(...) calls behave the same on Python 2 and 3,
    # unlike the print statements they replace.
    print('collecting training distributions ...\n')
    train_distri_list = collect_train_valid_data(
        train_distri_path, train_distri_path_2, hg, opt, is_train=True)
    print('collecting validation distributions ...\n')
    val_distri_list = collect_train_valid_data(
        val_distri_path, val_distri_path_2, hg, opt, is_train=False)
Example #5
Source File: train.py From SingleGAN with MIT License | 4 votes |
def main():
    """Train a ``SingleGAN`` model over the configured epochs.

    Per batch: update the model, then display results, print/plot errors and
    save the 'latest' model at the frequencies given in the options. Per
    epoch: save the model every ``opt.save_epoch_freq`` epochs and, once
    ``epoch`` exceeds ``opt.niter``, lower the learning rate by
    ``opt.lr / opt.niter_decay`` and pass it to the model.
    """
    opt = TrainOptions().parse()
    loader = CreateDataLoader(opt)
    samples_per_epoch = len(loader) * opt.batchSize
    visualizer = Visualizer(opt)

    model = SingleGAN()
    model.initialize(opt)

    # Step counter is expressed in samples, not batches.
    seen = 0
    lr = opt.lr

    for epoch in range(1, opt.niter + opt.niter_decay + 1):
        epoch_t0 = time.time()
        # Forces a display (and HTML save) on the first batch of each epoch.
        show_now = True

        for batch in loader:
            batch_t0 = time.time()
            seen += opt.batchSize
            # Samples consumed so far within the current epoch.
            within_epoch = seen - samples_per_epoch * (epoch - 1)

            model.update_model(batch)

            if show_now or seen % opt.display_freq == 0:
                write_html = show_now or seen % opt.update_html_freq == 0
                print('mode:{} dataset:{}'.format(opt.mode, opt.name))
                visuals = model.get_current_visuals()
                visualizer.display_current_results(visuals, epoch, ncols=1, save_result=write_html)
                show_now = False

            if seen % opt.print_freq == 0:
                errors = model.get_current_errors()
                per_sample = (time.time() - batch_t0) / opt.batchSize
                visualizer.print_current_errors(epoch, within_epoch, errors, per_sample)
                if opt.display_id > 0:
                    progress = float(within_epoch) / samples_per_epoch
                    visualizer.plot_current_errors(epoch, progress, opt, errors)

            if seen % opt.save_latest_freq == 0:
                print('saving the latest model (epoch %d, total_steps %d)' % (epoch, seen))
                model.save('latest')

        if epoch % opt.save_epoch_freq == 0:
            print('saving the model at the end of epoch %d, iters %d' % (epoch, seen))
            model.save('latest')
            model.save(epoch)

        if epoch > opt.niter:
            lr -= opt.lr / opt.niter_decay
            model.update_lr(lr)