Python trainer.train() Examples

The following are 4 code examples of trainer.train(), extracted from open source projects. You can go to the original project or source file by following the link above each example.
Example #1
Source File: main.py    From VisualizingNDF with MIT License
def main():
    # logging configuration
    logging.basicConfig(
        level=logging.INFO,
        format="[%(asctime)s]: %(levelname)s: %(message)s"
    )
    
    # command line parser
    opt = parse.parse_arg()

    # GPU
    opt.cuda = opt.gpuid >= 0
    if opt.gpuid >= 0:
        torch.cuda.set_device(opt.gpuid)
    else:
        logging.info("WARNING: RUN WITHOUT GPU")
    
    # prepare dataset    
    db = dataset.prepare_db(opt)
    
    # initialize neural decision forest
    NDF = model.prepare_model(opt)
    
    # prepare optimizer
    optim, sche = optimizer.prepare_optim(NDF, opt)
    
    # train the neural decision forest
    best_metric = trainer.train(NDF, optim, sche, db, opt)
    logging.info('The best evaluation metric is %f', best_metric)
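The `trainer` module used here is project code whose source is not shown on this page. For orientation only, below is a minimal sketch of a function with a compatible signature. It is an assumption, not the VisualizingNDF implementation: the `db['train']`/`db['eval']` loaders, the `opt.epochs` field, and a model that returns log-probabilities are all invented for the sketch.

import torch
import torch.nn.functional as F

def train(model, optim, sche, db, opt):
    # Hypothetical sketch only; the real VisualizingNDF trainer differs.
    # `db['train']`/`db['eval']` (DataLoaders) and `opt.epochs` are assumed.
    best_metric = 0.0
    for epoch in range(opt.epochs):
        model.train()
        for data, target in db['train']:
            if opt.cuda:
                data, target = data.cuda(), target.cuda()
            optim.zero_grad()
            # assumes the model outputs log-probabilities
            loss = F.nll_loss(model(data), target)
            loss.backward()
            optim.step()
        sche.step()  # advance the learning-rate schedule once per epoch

        # evaluate accuracy and keep the best value seen so far
        model.eval()
        correct, total = 0, 0
        with torch.no_grad():
            for data, target in db['eval']:
                if opt.cuda:
                    data, target = data.cuda(), target.cuda()
                pred = model(data).argmax(dim=1)
                correct += (pred == target).sum().item()
                total += target.size(0)
        best_metric = max(best_metric, correct / total)
    return best_metric
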
Example #2
Source File: __main__.py    From plant-disease-classification with MIT License
def main(train, classify, help):
    if help:
        print(help_message)
        sys.exit(0)
    else:
        if train:
            iteration = click.prompt('Iteration count for training model', type=int)
            trainer.train(num_iteration=iteration)
        else:
            image_file_path = click.prompt('Image file path that is going to be classified', type=str)
            classifier.classify(file_path=image_file_path) 
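Example #2's snippet omits the click decorators that supply the `train`, `classify`, and `help` parameters. A plausible wiring is sketched below; the flag names are assumptions, not necessarily what the plant-disease-classification CLI actually uses.

import click

@click.command(add_help_option=False)  # the custom `help` flag replaces click's built-in one
@click.option('--train', is_flag=True, help='Train the model.')
@click.option('--classify', is_flag=True, help='Classify an image.')
@click.option('--help', is_flag=True, help='Print the usage message.')
def main(train, classify, help):
    ...
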
Example #3
Source File: main.py    From tiny-faces-pytorch with MIT License
def main():
    args = arguments()

    num_templates = 25  # aka the number of clusters

    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
    img_transforms = transforms.Compose([
        transforms.ToTensor(),
        normalize
    ])
    train_loader, _ = get_dataloader(args.traindata, args, num_templates,
                                     img_transforms=img_transforms)

    model = DetectionModel(num_objects=1, num_templates=num_templates)
    loss_fn = DetectionCriterion(num_templates)

    # directory where we'll store model weights
    weights_dir = "weights"
    if not osp.exists(weights_dir):
        os.mkdir(weights_dir)

    # check for CUDA
    if torch.cuda.is_available():
        device = torch.device('cuda:0')
    else:
        device = torch.device('cpu')

    optimizer = optim.SGD(model.learnable_parameters(args.lr), lr=args.lr,
                          momentum=args.momentum, weight_decay=args.weight_decay)
    # optimizer = optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)

    if args.resume:
        checkpoint = torch.load(args.resume)
        model.load_state_dict(checkpoint['model'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        # Set the start epoch if it has not been set explicitly
        if not args.start_epoch:
            args.start_epoch = checkpoint['epoch']

    scheduler = optim.lr_scheduler.StepLR(optimizer,
                                          step_size=20,
                                          last_epoch=args.start_epoch-1)

    # train and evaluate for `epochs`
    for epoch in range(args.start_epoch, args.epochs):
        trainer.train(model, loss_fn, optimizer, train_loader, epoch, device=device)
        scheduler.step()

        if (epoch+1) % args.save_every == 0:
            trainer.save_checkpoint({
                'epoch': epoch + 1,
                'batch_size': train_loader.batch_size,
                'model': model.state_dict(),
                'optimizer': optimizer.state_dict()
            }, filename="checkpoint_{0}.pth".format(epoch+1), save_path=weights_dir) 
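`trainer.save_checkpoint` is also project code. Given how it is called above, and how the `--resume` branch consumes the file with `torch.load`, a minimal sketch (an assumption, not the tiny-faces-pytorch source) is:

import os.path as osp

import torch

def save_checkpoint(state, filename="checkpoint.pth", save_path="weights"):
    # Persist the dict assembled in the training loop; torch.load() on this
    # file returns the 'model'/'optimizer'/'epoch' entries used when resuming.
    torch.save(state, osp.join(save_path, filename))
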
Example #4
Source File: resnet_trainer.py    From luna16 with BSD 2-Clause "Simplified" License
def train(self, generator_train, X_train, generator_val, X_val):
        #filenames_train, filenames_val = patch_sampling.get_filenames()
        #generator = partial(patch_sampling.extract_random_patches, patch_size=P.INPUT_SIZE, crop_size=OUTPUT_SIZE)

        # keep these as lists so len(), slicing, and shuffling work below
        train_true = [x for x in X_train if "True" in x]
        train_false = [x for x in X_train if "False" in x]

        print("N train true/false", len(train_true), len(train_false))
        print(X_train[:2])

        val_true = [x for x in X_val if "True" in x]
        val_false = [x for x in X_val if "False" in x]

        n_train_true = len(train_true)
        n_val_true = len(val_true)

        logging.info("Starting training...")
        for epoch in range(P.N_EPOCHS):
            self.pre_epoch()

            if epoch in LR_SCHEDULE:
                logging.info("Setting learning rate to {}".format(LR_SCHEDULE[epoch]))
                self.l_r.set_value(LR_SCHEDULE[epoch])

            np.random.shuffle(train_false)
            np.random.shuffle(val_false)

            train_epoch_data = train_true + train_false[:n_train_true]
            val_epoch_data = val_true + val_false[:n_val_true]

            np.random.shuffle(train_epoch_data)
            #np.random.shuffle(val_epoch_data)

            #Full pass over the training data
            train_gen = ParallelBatchIterator(generator_train, train_epoch_data, ordered=False,
                                                batch_size=P.BATCH_SIZE_TRAIN//3,
                                                multiprocess=P.MULTIPROCESS_LOAD_AUGMENTATION,
                                                n_producers=P.N_WORKERS_LOAD_AUGMENTATION)

            self.do_batches(self.train_fn, train_gen, self.train_metrics)

            # And a full pass over the validation data:
            val_gen = ParallelBatchIterator(generator_val, val_epoch_data, ordered=False,
                                                batch_size=P.BATCH_SIZE_VALIDATION//3,
                                                multiprocess=P.MULTIPROCESS_LOAD_AUGMENTATION,
                                                n_producers=P.N_WORKERS_LOAD_AUGMENTATION)

            self.do_batches(self.val_fn, val_gen, self.val_metrics)
            self.post_epoch()
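The reusable idea in this example is the per-epoch class balancing: each epoch the negatives (the "False" filenames) are reshuffled and truncated to the number of positives. Isolated into a helper (the name `balanced_epoch` is invented here), that logic is:

import numpy as np

def balanced_epoch(true_items, false_items, rng=np.random):
    # Undersample the majority ("False") class down to the minority ("True")
    # count, then shuffle the union so batches mix both classes.
    false_items = list(false_items)
    rng.shuffle(false_items)
    epoch_data = list(true_items) + false_items[:len(true_items)]
    rng.shuffle(epoch_data)
    return epoch_data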