Python options.train_options.TrainOptions() Examples

The following are 5 code examples of `options.train_options.TrainOptions()`, each taken from an open-source project. The source file, project, and license for each example are listed just above it. You may also want to look at the other functions and classes in the `options.train_options` module.
Example #1
Source File: train.py — from GANimation (GNU General Public License v3.0), 6 votes
def __init__(self):
        """Prepare the GANimation training run and kick it off.

        Parses the training options, builds one data loader for the training
        split and one for the test split, prints how many images each holds,
        creates the model selected by ``opt.model`` through ``ModelsFactory``
        and a ``TBVisualizer``, and finally calls ``self._train()``.  Because
        of that last call, construction does not return until ``_train``
        itself returns.
        """
        opt = TrainOptions().parse()
        self._opt = opt

        # Train loader is built first, then the test loader (same order as before).
        train_loader, test_loader = [
            CustomDatasetDataLoader(opt, is_for_train=for_train)
            for for_train in (True, False)
        ]

        self._dataset_train = train_loader.load_data()
        self._dataset_test = test_loader.load_data()

        self._dataset_train_size = len(train_loader)
        self._dataset_test_size = len(test_loader)
        print('#train images = %d' % self._dataset_train_size)
        print('#test images = %d' % self._dataset_test_size)

        self._model = ModelsFactory.get_by_name(opt.model, opt)
        self._tb_visualizer = TBVisualizer(opt)

        self._train()
Example #2
Source File: train.py — from DMIT (MIT License), 5 votes
def main():
    """Train a DMIT model, resuming from ``model.start_epoch``.

    Each epoch calls ``model.update_lr()`` first, then for every batch
    feeds it to ``model.prepare_data`` / ``model.update_model``.  Visuals
    are shown on the first batch of each epoch and every
    ``opt.display_freq`` samples.  Losses are printed (and optionally
    plotted) every ``opt.print_freq`` samples.  After every epoch,
    ``model.save_ckpt(epoch)`` and ``model.save_generator('latest')`` run.
    Every ``opt.save_epoch_freq`` epochs a generator snapshot tagged with
    the epoch number is also written.
    """
    opt = TrainOptions().parse()
    data_loader = CreateDataLoader(opt)
    # Samples per epoch, assuming len(data_loader) counts batches and every
    # batch is full (TODO confirm against CreateDataLoader).
    dataset_size = len(data_loader) * opt.batch_size
    visualizer = Visualizer(opt)
    model = create_model(opt)
    start_epoch = model.start_epoch
    # Re-derive the global sample counter so step-based frequencies stay
    # aligned after resuming from a checkpoint.
    total_steps = start_epoch * dataset_size
    for epoch in range(start_epoch + 1, opt.niter + opt.niter_decay + 1):
        epoch_start_time = time.time()
        model.update_lr()
        # Force a display (and HTML save) on the first batch of every epoch.
        save_result = True
        for data in data_loader:
            iter_start_time = time.time()
            total_steps += opt.batch_size
            epoch_iter = total_steps - dataset_size * (epoch - 1)
            model.prepare_data(data)
            model.update_model()
            if save_result or total_steps % opt.display_freq == 0:
                save_result = save_result or total_steps % opt.update_html_freq == 0
                visualizer.display_current_results(
                    model.get_current_visuals(), epoch, ncols=1, save_result=save_result)
                save_result = False
            if total_steps % opt.print_freq == 0:
                errors = model.get_current_errors()
                # Average wall-clock time per sample for this iteration.
                t = (time.time() - iter_start_time) / opt.batch_size
                visualizer.print_current_errors(epoch, epoch_iter, errors, t)
                if opt.display_id > 0:
                    visualizer.plot_current_errors(
                        epoch, float(epoch_iter) / dataset_size, opt, errors)
        print('epoch {} cost time {}'.format(epoch, time.time() - epoch_start_time))
        model.save_ckpt(epoch)
        model.save_generator('latest')
        if epoch % opt.save_epoch_freq == 0:
            print('saving the generator at the end of epoch {}, iters {}'.format(epoch, total_steps))
            model.save_generator(epoch)
Example #3
Source File: collect-rotation-ditri.py — from pose-adv-aug (Apache License 2.0), 4 votes
def main():
    """Collect rotation distributions of a pretrained pose model.

    Parses ``TrainOptions`` and builds a 2-stack hourglass network.  It
    loads that network's weights from the checkpoint named by
    ``opt.load_prefix_pose``, then calls ``collect_train_valid_data`` on
    the training split and then on the validation split.  The output file
    paths it passes live under
    ``<exp_dir>/<exp_id>/<sr_dir>-<load_prefix_pose without its last char>``.

    Exits with status 1 in three cases:
    - ``opt.sr_dir`` is empty;
    - ``opt.load_prefix_pose`` is empty;
    - ``opt.dataset`` is not ``'mpii'`` (the only dataset whose joint count
      is known here).
    """
    import sys  # local import keeps this script entry point self-contained

    opt = TrainOptions().parse()

    # Validate every required option before touching the filesystem or GPU.
    if opt.sr_dir == '':
        print('sr directory is null.')
        sys.exit(1)
    if opt.load_prefix_pose == '':
        # Previously this only warned and carried on, deriving both the output
        # directory and the checkpoint path from an empty prefix.
        print('please input the checkpoint name of the pose model')
        sys.exit(1)
    if opt.dataset == 'mpii':
        num_classes = 16
    else:
        # Previously any other dataset crashed later with an UnboundLocalError
        # on ``num_classes``.
        print('unsupported dataset: {}'.format(opt.dataset))
        sys.exit(1)

    # ``[0:-1]`` drops the final character of the prefix (presumably a
    # trailing separator -- TODO confirm the checkpoint naming convention).
    sr_pretrain_dir = os.path.join(opt.exp_dir, opt.exp_id,
                                   opt.sr_dir + '-' + opt.load_prefix_pose[0:-1])
    if not os.path.isdir(sr_pretrain_dir):
        os.makedirs(sr_pretrain_dir)

    train_distri_path = os.path.join(sr_pretrain_dir, 'train_rotations.txt')
    train_distri_path_2 = os.path.join(sr_pretrain_dir, 'train_rotations_copy.txt')
    val_distri_path = os.path.join(sr_pretrain_dir, 'val_rotations.txt')
    val_distri_path_2 = os.path.join(sr_pretrain_dir, 'val_rotations_copy.txt')

    checkpoint_hg = Checkpoint()
    os.environ['CUDA_VISIBLE_DEVICES'] = opt.gpu_id
    hg = model.create_hg(num_stacks=2, num_modules=1,
                         num_classes=num_classes, chan=256)
    hg = torch.nn.DataParallel(hg).cuda()
    checkpoint_hg.load_prefix = os.path.join(opt.exp_dir, opt.exp_id,
                                             opt.load_prefix_pose)[0:-1]
    checkpoint_hg.load_checkpoint(hg)

    # Single-argument print(...) prints the same text on Python 2 and 3.
    print('collecting training distributions ...\n')
    collect_train_valid_data(train_distri_path, train_distri_path_2,
                             hg, opt, is_train=True)

    print('collecting validation distributions ...\n')
    collect_train_valid_data(val_distri_path, val_distri_path_2,
                             hg, opt, is_train=False)
Example #4
Source File: collect-scale-ditri.py — from pose-adv-aug (Apache License 2.0), 4 votes
def main():
    """Collect scale distributions of a pretrained pose model.

    Parses ``TrainOptions`` and builds a 2-stack hourglass network.  It
    loads that network's weights from the checkpoint named by
    ``opt.load_prefix_pose``, then calls ``collect_train_valid_data`` on
    the training split and then on the validation split.  The output file
    paths it passes live under
    ``<exp_dir>/<exp_id>/<sr_dir>-<load_prefix_pose without its last char>``.

    Exits with status 1 in three cases:
    - ``opt.sr_dir`` is empty;
    - ``opt.load_prefix_pose`` is empty;
    - ``opt.dataset`` is not ``'mpii'`` (the only dataset whose joint count
      is known here).
    """
    import sys  # local import keeps this script entry point self-contained

    opt = TrainOptions().parse()

    # Validate every required option before touching the filesystem or GPU.
    if opt.sr_dir == '':
        print('sr directory is null.')
        sys.exit(1)
    if opt.load_prefix_pose == '':
        print('please input the checkpoint name of the pose model')
        sys.exit(1)
    if opt.dataset == 'mpii':
        num_classes = 16
    else:
        # Previously any other dataset crashed later with an UnboundLocalError
        # on ``num_classes``.
        print('unsupported dataset: {}'.format(opt.dataset))
        sys.exit(1)

    # ``[0:-1]`` drops the final character of the prefix (presumably a
    # trailing separator -- TODO confirm the checkpoint naming convention).
    sr_pretrain_dir = os.path.join(opt.exp_dir, opt.exp_id,
                                   opt.sr_dir + '-' + opt.load_prefix_pose[0:-1])
    if not os.path.isdir(sr_pretrain_dir):
        os.makedirs(sr_pretrain_dir)

    train_distri_path = os.path.join(sr_pretrain_dir, 'train_scales.txt')
    train_distri_path_2 = os.path.join(sr_pretrain_dir, 'train_scales_copy.txt')
    val_distri_path = os.path.join(sr_pretrain_dir, 'val_scales.txt')
    val_distri_path_2 = os.path.join(sr_pretrain_dir, 'val_scales_copy.txt')

    checkpoint_hg = Checkpoint()
    os.environ['CUDA_VISIBLE_DEVICES'] = opt.gpu_id
    hg = model.create_hg(num_stacks=2, num_modules=1,
                         num_classes=num_classes, chan=256)
    hg = torch.nn.DataParallel(hg).cuda()
    checkpoint_hg.load_prefix = os.path.join(opt.exp_dir, opt.exp_id,
                                             opt.load_prefix_pose)[0:-1]
    checkpoint_hg.load_checkpoint(hg)

    # Single-argument print(...) prints the same text on Python 2 and 3.
    print('collecting training distributions ...\n')
    collect_train_valid_data(train_distri_path, train_distri_path_2,
                             hg, opt, is_train=True)

    print('collecting validation distributions ...\n')
    collect_train_valid_data(val_distri_path, val_distri_path_2,
                             hg, opt, is_train=False)
Example #5
Source File: train.py — from SingleGAN (MIT License), 4 votes
def main():
    """Train a SingleGAN model from scratch.

    Runs epochs ``1 .. opt.niter + opt.niter_decay``.  Within an epoch,
    every batch is passed to ``model.update_model``:
    - Visuals are shown on the first batch and every ``opt.display_freq``
      samples.
    - Losses are printed (and optionally plotted) every ``opt.print_freq``
      samples.
    - The 'latest' checkpoint is saved every ``opt.save_latest_freq``
      samples.

    Every ``opt.save_epoch_freq`` epochs both 'latest' and an epoch-tagged
    checkpoint are written.  After each epoch beyond ``opt.niter``, the
    learning rate is lowered by ``opt.lr / opt.niter_decay`` and pushed to
    the model via ``update_lr``.
    """
    opt = TrainOptions().parse()
    data_loader = CreateDataLoader(opt)
    samples_per_epoch = len(data_loader) * opt.batchSize
    visualizer = Visualizer(opt)

    model = SingleGAN()
    model.initialize(opt)

    step = 0
    current_lr = opt.lr
    last_epoch = opt.niter + opt.niter_decay
    for epoch in range(1, last_epoch + 1):
        # The first batch of every epoch is always displayed (and saved).
        pending_save = True
        for batch_data in data_loader:
            iter_start = time.time()
            step += opt.batchSize
            epoch_iter = step - samples_per_epoch * (epoch - 1)
            model.update_model(batch_data)

            if pending_save or step % opt.display_freq == 0:
                pending_save = pending_save or step % opt.update_html_freq == 0
                print('mode:{} dataset:{}'.format(opt.mode, opt.name))
                visualizer.display_current_results(
                    model.get_current_visuals(), epoch, ncols=1,
                    save_result=pending_save)
                pending_save = False

            if step % opt.print_freq == 0:
                losses = model.get_current_errors()
                per_sample_time = (time.time() - iter_start) / opt.batchSize
                visualizer.print_current_errors(epoch, epoch_iter, losses, per_sample_time)
                if opt.display_id > 0:
                    progress = float(epoch_iter) / samples_per_epoch
                    visualizer.plot_current_errors(epoch, progress, opt, losses)

            if step % opt.save_latest_freq == 0:
                print('saving the latest model (epoch %d, total_steps %d)' % (epoch, step))
                model.save('latest')

        if epoch % opt.save_epoch_freq == 0:
            print('saving the model at the end of epoch %d, iters %d' % (epoch, step))
            model.save('latest')
            model.save(epoch)

        # Linear decay phase: lower the rate by a fixed amount after each
        # epoch past opt.niter.
        if epoch > opt.niter:
            current_lr -= opt.lr / opt.niter_decay
            model.update_lr(current_lr)