Python trainer.train() Examples
The following are 4 code examples of `trainer.train()`.
You can vote up the ones you like or vote down the ones you don't like,
and go to the original project or source file by following the links above each example.
You may also want to check out all available functions/classes of the module `trainer`, or try the search function.
Example #1
Source File: main.py From VisualizingNDF with MIT License | 5 votes |
def main():
    """Train a neural decision forest and log the best evaluation metric.

    Configures INFO-level logging, parses the command-line options, selects
    the GPU given by ``opt.gpuid`` (a negative id means CPU-only), builds the
    dataset, model, optimizer and scheduler through the project helpers, and
    hands everything to ``trainer.train``. The metric it returns is logged.
    """
    log_format = "[%(asctime)s]: %(levelname)s: %(message)s"
    logging.basicConfig(level=logging.INFO, format=log_format)

    # Command-line options drive every later step.
    opt = parse.parse_arg()

    # A non-negative GPU id selects that CUDA device; otherwise run on CPU.
    use_gpu = opt.gpuid >= 0
    opt.cuda = use_gpu
    if not use_gpu:
        logging.info("WARNING: RUN WITHOUT GPU")
    else:
        torch.cuda.set_device(opt.gpuid)

    db = dataset.prepare_db(opt)
    forest = model.prepare_model(opt)
    net_optim, lr_scheduler = optimizer.prepare_optim(forest, opt)

    best_metric = trainer.train(forest, net_optim, lr_scheduler, db, opt)
    summary = 'The best evaluation metric is %f' % best_metric
    logging.info(summary)
Example #2
Source File: __main__.py From plant-disease-classification with MIT License | 5 votes |
def main(train, classify, help):
    """Dispatch the command line: show help, train the model, or classify.

    ``help`` takes priority: when truthy, the help text is printed and the
    process exits with status 0. Otherwise a truthy ``train`` prompts for an
    iteration count and starts training. Any other combination prompts for an
    image path and classifies that image.

    NOTE(review): ``classify`` is accepted but never read. Classification is
    simply the fallback whenever ``train`` is falsy; confirm that is intended.
    """
    if help:
        print(help_message)
        sys.exit(0)

    if not train:
        image_file_path = click.prompt('Image file path that is going to be classified', type=str)
        classifier.classify(file_path=image_file_path)
        return

    iteration = click.prompt('Iteration count for training model', type=int)
    trainer.train(num_iteration=iteration)
Example #3
Source File: main.py From tiny-faces-pytorch with MIT License | 4 votes |
def main():
    """Train the detection model epoch by epoch, checkpointing periodically.

    Builds the training loader (tensor conversion plus per-channel
    normalization), the detection model and its criterion, an SGD optimizer
    and a StepLR scheduler. When ``args.resume`` is set, it restores model and
    optimizer state from that checkpoint. It then calls ``trainer.train`` once
    per epoch in ``range(args.start_epoch, args.epochs)`` and writes a
    checkpoint to the ``weights`` directory every ``args.save_every`` epochs.
    """
    args = arguments()
    num_templates = 25  # aka the number of clusters

    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
    img_transforms = transforms.Compose([transforms.ToTensor(), normalize])
    train_loader, _ = get_dataloader(args.traindata, args, num_templates,
                                     img_transforms=img_transforms)

    model = DetectionModel(num_objects=1, num_templates=num_templates)
    loss_fn = DetectionCriterion(num_templates)

    # Directory where the model weights are stored.
    weights_dir = "weights"
    if not osp.exists(weights_dir):
        os.mkdir(weights_dir)

    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

    optimizer = optim.SGD(model.learnable_parameters(args.lr),
                          lr=args.lr,
                          momentum=args.momentum,
                          weight_decay=args.weight_decay)

    if args.resume:
        checkpoint = torch.load(args.resume)
        model.load_state_dict(checkpoint['model'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        # Take the start epoch from the checkpoint unless one was given.
        if not args.start_epoch:
            args.start_epoch = checkpoint['epoch']

    # NOTE(review): with last_epoch >= 0 (i.e. args.start_epoch > 0) PyTorch's
    # StepLR expects an 'initial_lr' entry in every optimizer param group and
    # raises KeyError otherwise. For example, a non-zero args.start_epoch
    # without args.resume would fail here. Confirm intended usage.
    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=20,
                                          last_epoch=args.start_epoch - 1)

    for epoch in range(args.start_epoch, args.epochs):
        trainer.train(model, loss_fn, optimizer, train_loader, epoch,
                      device=device)
        scheduler.step()

        completed = epoch + 1
        if completed % args.save_every == 0:
            state = {
                'epoch': completed,
                'batch_size': train_loader.batch_size,
                'model': model.state_dict(),
                'optimizer': optimizer.state_dict(),
            }
            trainer.save_checkpoint(state,
                                    filename="checkpoint_{0}.pth".format(completed),
                                    save_path=weights_dir)
Example #4
Source File: resnet_trainer.py From luna16 with BSD 2-Clause "Simplified" License | 4 votes |
def train(self, generator_train, X_train, generator_val, X_val):
    """Run the training loop with class-balanced epochs.

    Each split is partitioned into positives (items containing "True") and
    negatives (items containing "False"). Every epoch pairs all positives
    with an equally sized, freshly shuffled subset of negatives and shuffles
    the resulting training set. It then makes one pass over the training and
    validation sets via ``ParallelBatchIterator`` and ``self.do_batches``.
    At epochs listed in ``LR_SCHEDULE``, the learning rate is updated first.

    Args:
        generator_train: Producer passed to ``ParallelBatchIterator`` as its
            first argument for the training pass.
        X_train: Training sample identifiers. Only the "True"/"False"
            substring is inspected here (presumably filenames; verify).
        generator_val: Producer for the validation pass.
        X_val: Validation sample identifiers, split the same way.
    """
    # Build real lists: the code below needs len(), slicing, in-place
    # shuffling and list concatenation. A Python 3 filter() iterator supports
    # none of these, and a Python 2 filter() over a tuple returns an
    # immutable tuple that np.random.shuffle cannot shuffle.
    train_true = [x for x in X_train if "True" in x]
    train_false = [x for x in X_train if "False" in x]
    # Single-argument print calls give the same output on Python 2 and 3.
    print("N train true/false {} {}".format(len(train_true), len(train_false)))
    print(X_train[:2])

    val_true = [x for x in X_val if "True" in x]
    val_false = [x for x in X_val if "False" in x]

    n_train_true = len(train_true)
    n_val_true = len(val_true)

    logging.info("Starting training...")
    for epoch in range(P.N_EPOCHS):
        self.pre_epoch()

        if epoch in LR_SCHEDULE:
            logging.info("Setting learning rate to {}".format(LR_SCHEDULE[epoch]))
            self.l_r.set_value(LR_SCHEDULE[epoch])

        # Balance the classes: keep every positive and take the same number
        # of negatives, drawn from a fresh shuffle each epoch.
        np.random.shuffle(train_false)
        np.random.shuffle(val_false)
        train_epoch_data = train_true + train_false[:n_train_true]
        val_epoch_data = val_true + val_false[:n_val_true]
        np.random.shuffle(train_epoch_data)

        # Full pass over the training data.
        train_gen = ParallelBatchIterator(generator_train, train_epoch_data,
                                          ordered=False,
                                          batch_size=P.BATCH_SIZE_TRAIN // 3,
                                          multiprocess=P.MULTIPROCESS_LOAD_AUGMENTATION,
                                          n_producers=P.N_WORKERS_LOAD_AUGMENTATION)
        self.do_batches(self.train_fn, train_gen, self.train_metrics)

        # Full pass over the validation data (not shuffled beyond the
        # negative subsampling above).
        val_gen = ParallelBatchIterator(generator_val, val_epoch_data,
                                        ordered=False,
                                        batch_size=P.BATCH_SIZE_VALIDATION // 3,
                                        multiprocess=P.MULTIPROCESS_LOAD_AUGMENTATION,
                                        n_producers=P.N_WORKERS_LOAD_AUGMENTATION)
        self.do_batches(self.val_fn, val_gen, self.val_metrics)

        self.post_epoch()