Python trainer.Trainer() Examples
The following are 30 code examples of trainer.Trainer().
You can vote up the ones you like or vote down the ones you don't like,
and go to the original project or source file by following the links above each example.
You may also want to check out all available functions/classes of the module trainer, or try the search function.
Example #1
Source File: classifier.py From ModelFeast with MIT License | 6 votes |
def train(self):
    """Validate the configured components, then build and run a ``Trainer``.

    The model, the loss and every metric must be callable, the config must
    contain a "trainer" entry and the data loader must be iterable. When the
    data loader has a ``classes`` attribute, its length must equal
    ``config['arch']['args']['n_class']``. If the config has no "name" entry,
    one is built as "<arch type>_<data_loader type>".

    Raises:
        AssertionError: if any of the checks above fails.
    """
    assert callable(self.model), "model is not callable!!"
    assert callable(self.loss), "loss is not callable!!"
    assert all(callable(met) for met in self.metrics), "metrics is not callable!!"
    assert "trainer" in self.config, "trainer hasn't been configured!!"
    assert isinstance(self.data_loader, Iterable), "data_loader is not iterable!!"

    # The number of classes in the dataset must be the same as the model's
    # output size. (The assertion message reports both numbers.)
    if hasattr(self.data_loader, 'classes'):
        true_classes = len(self.data_loader.classes)
        model_output = self.config['arch']['args']['n_class']
        assert true_classes==model_output, "model分类数为{},可是实际上有{}个类".format(
            model_output, true_classes)

    if "name" not in self.config:
        # str.join accepts exactly one iterable; the previous two-argument
        # call raised TypeError, so the parts are wrapped in a list.
        self.config["name"] = "_".join([
            self.config["arch"]["type"],
            self.config["data_loader"]["type"],
        ])

    self.trainer = Trainer(
        self.model,
        self.loss,
        self.metrics,
        self.optimizer,
        resume=self.resume,
        config=self.config,
        data_loader=self.data_loader,
        valid_data_loader=self.valid_data_loader,
        lr_scheduler=self.lr_scheduler,
        train_logger=self.train_logger,
    )
    self.trainer.train()
Example #2
Source File: main.py From JointBERT with Apache License 2.0 | 6 votes |
def main(args):
    """Load the train/dev/test datasets and run the phases selected in ``args``.

    ``args.do_train`` triggers training; ``args.do_eval`` triggers
    ``load_model()`` followed by evaluation on the "test" split.
    """
    init_logger()
    set_seed(args)
    tokenizer = load_tokenizer(args)

    # Splits are loaded in the order train, dev, test and handed to the
    # Trainer positionally in that same order.
    splits = [
        load_and_cache_examples(args, tokenizer, mode=split_name)
        for split_name in ("train", "dev", "test")
    ]
    runner = Trainer(args, *splits)

    if args.do_train:
        runner.train()

    if args.do_eval:
        runner.load_model()
        runner.evaluate("test")
Example #3
Source File: main.py From EDSR-PyTorch with MIT License | 6 votes |
def main():
    """Run the video tester or the train/test loop, depending on ``args``.

    When ``args.data_test`` is exactly ``['video']`` a ``VideoTester`` is run.
    Otherwise, if ``checkpoint.ok`` is true, a ``Trainer`` alternates
    ``train()`` and ``test()`` until ``terminate()`` returns true, after
    which ``checkpoint.done()`` is called.
    """
    global model

    if args.data_test == ['video']:
        from videotester import VideoTester
        # This rebinds the module-level name ``model`` to a Model instance.
        model = model.Model(args, checkpoint)
        VideoTester(args, model, checkpoint).test()
        return

    if not checkpoint.ok:
        return

    loader = data.Data(args)
    _model = model.Model(args, checkpoint)
    # No loss object is built in test-only mode.
    if args.test_only:
        _loss = None
    else:
        _loss = loss.Loss(args, checkpoint)

    runner = Trainer(args, loader, _model, _loss, checkpoint)
    while not runner.terminate():
        runner.train()
        runner.test()

    checkpoint.done()
Example #4
Source File: main.py From KoBERT-NER with Apache License 2.0 | 6 votes |
def main(args):
    """Load only the datasets the requested phases need, then train/evaluate.

    The test split is loaded when either ``args.do_train`` or ``args.do_eval``
    is set; the train split only when ``args.do_train`` is set. The dev
    dataset is always passed as ``None``.
    """
    init_logger()
    set_seed(args)
    tokenizer = load_tokenizer(args)

    # The test split is loaded before the train split.
    test_dataset = (
        load_and_cache_examples(args, tokenizer, mode="test")
        if args.do_train or args.do_eval
        else None
    )
    train_dataset = (
        load_and_cache_examples(args, tokenizer, mode="train")
        if args.do_train
        else None
    )
    dev_dataset = None

    runner = Trainer(args, train_dataset, dev_dataset, test_dataset)

    if args.do_train:
        runner.train()

    if args.do_eval:
        runner.load_model()
        runner.evaluate("test", "eval")
Example #5
Source File: main.py From pointer-network-tensorflow with MIT License | 6 votes |
def main(_):
    """Prepare the run from the module-level ``config`` and train or test.

    Raises:
        Exception: if ``config.task`` does not start with "tsp"
            (case-insensitive), or if testing is requested without
            ``config.load_path``.
    """
    prepare_dirs_and_logger(config)

    task_name = config.task.lower()
    if not task_name.startswith('tsp'):
        raise Exception("[!] Task should starts with TSP")

    # Unset encoder/decoder lengths fall back to the maximum data length.
    for length_attr in ('max_enc_length', 'max_dec_length'):
        if getattr(config, length_attr) is None:
            setattr(config, length_attr, config.max_data_length)

    rng = np.random.RandomState(config.random_seed)
    tf.set_random_seed(config.random_seed)

    runner = Trainer(config, rng)
    save_config(config.model_dir, config)

    if config.is_train:
        runner.train()
    elif not config.load_path:
        raise Exception("[!] You should specify `load_path` to load a pretrained model")
    else:
        runner.test()

    tf.logging.info("Run finished.")
Example #6
Source File: main.py From neural-combinatorial-rl-tensorflow with MIT License | 6 votes |
def main(_):
    """Set up directories, seeds and a ``Trainer``, then train or test.

    Reads everything from the module-level ``config``.

    Raises:
        Exception: if ``config.task`` does not start with "tsp"
            (case-insensitive), or if ``config.is_train`` is false and
            ``config.load_path`` is empty.
    """
    prepare_dirs_and_logger(config)

    is_tsp_task = config.task.lower().startswith('tsp')
    if not is_tsp_task:
        raise Exception("[!] Task should starts with TSP")

    # Default both sequence-length limits to the data-length limit.
    fallback_length = config.max_data_length
    if config.max_enc_length is None:
        config.max_enc_length = fallback_length
    if config.max_dec_length is None:
        config.max_dec_length = fallback_length

    random_state = np.random.RandomState(config.random_seed)
    tf.set_random_seed(config.random_seed)

    model_trainer = Trainer(config, random_state)
    save_config(config.model_dir, config)

    if not config.is_train:
        if not config.load_path:
            raise Exception("[!] You should specify `load_path` to load a pretrained model")
        model_trainer.test()
    else:
        model_trainer.train()

    tf.logging.info("Run finished.")
Example #7
Source File: train.py From pytorch_segmentation with MIT License | 6 votes |
def main(config, resume):
    """Build loaders, model and loss from ``config`` and run training.

    Args:
        config: mapping read for the 'loss' and 'ignore_index' entries and
            passed to ``get_instance`` and ``Trainer``.
        resume: forwarded unchanged to ``Trainer``.
    """
    train_logger = Logger()

    # Data loaders
    train_loader = get_instance(dataloaders, 'train_loader', config)
    val_loader = get_instance(dataloaders, 'val_loader', config)

    # Model, sized by the training dataset's class count
    num_classes = train_loader.dataset.num_classes
    model = get_instance(models, 'arch', config, num_classes)
    print(f'\n{model}\n')

    # Loss, looked up by name in the ``losses`` module
    loss_factory = getattr(losses, config['loss'])
    loss = loss_factory(ignore_index=config['ignore_index'])

    # Training
    Trainer(
        model=model,
        loss=loss,
        resume=resume,
        config=config,
        train_loader=train_loader,
        val_loader=val_loader,
        train_logger=train_logger,
    ).train()
Example #8
Source File: main.py From EverybodyDanceNow_reproduce_pytorch with MIT License | 6 votes |
def main(is_debug):
    """Train the face generator/discriminator on the hard-coded face dataset.

    Args:
        is_debug: when true, the ``Trainer`` is built with ``log_every=1``
            and ``save_every=1``.
    """
    import os

    # configs
    dataset_dir = '../data/face'
    pose_name = '../data/target/pose.npy'
    ckpt_dir = '../checkpoints/face'
    log_dir = '../checkpoints/face/logs'
    batch_num = 10
    batch_size = 10

    cache_path = os.path.join(dataset_dir, 'local.db')
    image_folder = dataset.ImageFolderDataset(dataset_dir, cache=cache_path)
    # crop_size: 48 for 512-frame, 96 for HD frame
    face_dataset = dataset.FaceCropDataset(
        image_folder, pose_name, image_transforms, crop_size=48)
    data_loader = DataLoader(
        face_dataset,
        batch_size=batch_size,
        drop_last=True,
        num_workers=4,
        shuffle=True,
    )

    # load_models returns its own batch_num, replacing the initial value.
    generator, discriminator, batch_num = load_models(ckpt_dir, batch_num)

    debug_options = {'log_every': 1, 'save_every': 1} if is_debug else {}
    runner = Trainer(ckpt_dir, log_dir, face_dataset, data_loader, **debug_options)
    runner.train(generator, discriminator, batch_num)
Example #9
Source File: main.py From everybody_dance_now_pytorch with GNU Affero General Public License v3.0 | 6 votes |
def main(is_debug):
    """Train the generator/discriminator on the hard-coded dance dataset.

    Args:
        is_debug: when true, the ``Trainer`` is built with ``log_every=1``
            and ``save_every=1``.
    """
    # configs
    dataset_dir = '../datasets/cardio_dance_512'
    pose_name = '../datasets/cardio_dance_512/poses.npy'
    ckpt_dir = './checkpoints/dance_test_new_down2_res6'
    log_dir = './logs/dance_test_new_down2_res6'
    batch_num = 0
    batch_size = 64

    local_cache = os.path.join(dataset_dir, 'local.db')
    image_folder = dataset.ImageFolderDataset(dataset_dir, cache=local_cache)
    # crop_size: 48 for 512-frame, 96 for HD frame
    face_dataset = dataset.FaceCropDataset(
        image_folder, pose_name, image_transforms, crop_size=48)
    data_loader = DataLoader(
        face_dataset,
        batch_size=batch_size,
        drop_last=True,
        num_workers=4,
        shuffle=True,
    )

    # The batch_num returned by load_models replaces the initial value.
    generator, discriminator, batch_num = load_models(ckpt_dir, batch_num)

    if is_debug:
        trainer_kwargs = {'log_every': 1, 'save_every': 1}
    else:
        trainer_kwargs = {}
    runner = Trainer(ckpt_dir, log_dir, face_dataset, data_loader, **trainer_kwargs)
    runner.train(generator, discriminator, batch_num)
Example #10
Source File: gcrn_main.py From gconvRNN with MIT License | 6 votes |
def main(_):
    """Prepare directories and seeds, then train or test a ``Trainer``.

    Reads everything from the module-level ``config``.

    Raises:
        Exception: if ``config.is_train`` is false and ``config.load_path``
            is empty.
    """
    # Create the directories used for saving
    prepare_dirs(config)

    # Random seed settings
    random_state = np.random.RandomState(config.random_seed)
    tf.set_random_seed(config.random_seed)

    # Model training
    runner = Trainer(config, random_state)
    save_config(config.model_dir, config)

    if config.is_train:
        runner.train()
        return

    if not config.load_path:
        raise Exception("[!] You should specify `load_path` to load a pretrained model")
    runner.test()
Example #11
Source File: main.py From conditional-motion-propagation with MIT License | 6 votes |
def main(args):
    """Merge the YAML config into ``args``, initialise distributed mode, train.

    Every top-level key of the YAML file at ``args.config`` becomes an
    attribute of ``args``. If ``args`` has no ``exp_path`` afterwards, it is
    set to the directory of the config file.
    """
    with open(args.config) as f:
        # The previous check, version.parse(yaml.version >= "5.1"), parsed the
        # result of a comparison instead of comparing two versions (misplaced
        # parenthesis). Detecting FullLoader directly expresses the same
        # intent without any version arithmetic.
        full_loader = getattr(yaml, "FullLoader", None)
        if full_loader is not None:
            config = yaml.load(f, Loader=full_loader)
        else:
            # NOTE(review): without a Loader argument, yaml.load can construct
            # arbitrary objects; only use trusted config files here.
            config = yaml.load(f)

    for k, v in config.items():
        setattr(args, k, v)

    # exp path
    if not hasattr(args, 'exp_path'):
        args.exp_path = os.path.dirname(args.config)

    # dist init: force the 'spawn' start method unless it is already set
    if mp.get_start_method(allow_none=True) != 'spawn':
        mp.set_start_method('spawn', force=True)
    dist_init(args.launcher, backend='nccl')

    # train
    trainer = Trainer(args)
    trainer.run()
Example #12
Source File: train.py From DDRNet with MIT License | 6 votes |
def main(args):
    """Configure logging, validate the index file and fit a trainer.

    Logging goes to ``<args.logdir>/logging.txt`` and to a console handler.
    The first comma-separated field of the index file's first line must be
    an existing path.

    Raises:
        ValueError: if ``args.index_file`` is not a file, or if that first
            field does not exist as a path.
    """
    log_file = os.path.join(args.logdir, 'logging.txt')
    logging.basicConfig(
        level=logging.DEBUG,
        format='%(asctime)-15s %(name)-5s %(levelname)-8s %(message)s',
        filename=log_file)

    # Mirror log records to the console with a shorter format.
    console = logging.StreamHandler()
    console.setLevel(logging.DEBUG)
    console.setFormatter(
        logging.Formatter('%(asctime)s %(levelname)-8s: %(message)s'))
    logging.getLogger().addHandler(console)

    filename = os.path.realpath(args.index_file)
    if not os.path.isfile(filename):
        raise ValueError('No such index_file: {}'.format(filename))

    print("Reading csv file: {}".format(filename))
    with open(filename, "r") as index_fh:
        first_line = index_fh.readline().strip()

    # Only the first column of the first row is checked.
    input_path = first_line.split(',')[0]
    if not os.path.exists(input_path):
        raise ValueError('Input path in csv not exist: {}'.format(input_path))

    trainer.Trainer(filename, args).fit()
Example #13
Source File: main.py From Graph-U-Nets with GNU General Public License v3.0 | 5 votes |
def app_run(args, G_data, fold_idx):
    """Select fold ``fold_idx`` on ``G_data``, build a ``GNet`` and train it."""
    G_data.use_fold_data(fold_idx)
    network = GNet(G_data.feat_dim, G_data.num_class, args)
    Trainer(args, network, G_data).train()
Example #14
Source File: train.py From Real-time-Text-Detection with Apache License 2.0 | 5 votes |
def main(config):
    """Build the data loader, loss and model from ``config`` and train."""
    loader_cfg = config['data_loader']
    train_loader = get_dataloader(loader_cfg['type'], loader_cfg['args'])

    # The loss is moved to CUDA before the model is built.
    criterion = get_loss(config).cuda()
    model = get_model(config)

    Trainer(
        config=config,
        model=model,
        criterion=criterion,
        train_loader=train_loader,
    ).train()
Example #15
Source File: main.py From neural-question-generation with MIT License | 5 votes |
def main(args):
    """Train when ``args.train`` is set, otherwise decode with beam search."""
    if not args.train:
        BeamSearcher(args.model_path, args.output_dir).decode()
        return

    Trainer(args).train()
Example #16
Source File: main.py From densenet with MIT License | 5 votes |
def main(config):
    """Seed torch, build the data loader for the mode, then train or test.

    With ``config.num_gpu > 0`` the CUDA seed is set and the loaders get
    ``num_workers=1, pin_memory=True``; otherwise the CPU seed is set and no
    extra loader arguments are passed.
    """
    # ensure directories are setup
    prepare_dirs(config)

    on_gpu = config.num_gpu > 0
    seed_fn = torch.cuda.manual_seed if on_gpu else torch.manual_seed
    seed_fn(config.random_seed)
    kwargs = {'num_workers': 1, 'pin_memory': True} if on_gpu else {}

    # instantiate data loaders
    if config.is_train:
        data_loader = get_train_valid_loader(
            config.data_dir,
            config.dataset,
            config.batch_size,
            config.augment,
            config.random_seed,
            config.valid_size,
            config.shuffle,
            config.show_sample,
            **kwargs)
    else:
        data_loader = get_test_loader(
            config.data_dir,
            config.dataset,
            config.batch_size,
            config.shuffle,
            **kwargs)

    # instantiate trainer
    runner = Trainer(config, data_loader)

    if not config.is_train:
        # load a pretrained model and test
        runner.test()
        return

    # train
    save_config(config)
    runner.train()
Example #17
Source File: main.py From Deep-Mutual-Learning with MIT License | 5 votes |
def main(config):
    """Build the data loader(s) for the mode, then train or test.

    In training mode the ``Trainer`` receives a ``(train_loader, test_loader)``
    tuple; otherwise it receives only the test loader.
    """
    # ensure directories are setup
    prepare_dirs(config)

    # ensure reproducibility
    #torch.manual_seed(config.random_seed)
    if config.use_gpu:
        #torch.cuda.manual_seed_all(config.random_seed)
        kwargs = {
            'num_workers': config.num_workers,
            'pin_memory': config.pin_memory,
        }
        #torch.backends.cudnn.deterministic = True
    else:
        kwargs = {}

    # instantiate data loaders; the test loader is always built first
    test_data_loader = get_test_loader(config.data_dir, config.batch_size, **kwargs)

    data_loader = test_data_loader
    if config.is_train:
        train_data_loader = get_train_loader(
            config.data_dir,
            config.batch_size,
            config.random_seed,
            config.shuffle,
            **kwargs)
        data_loader = (train_data_loader, test_data_loader)

    # instantiate trainer
    runner = Trainer(config, data_loader)

    if not config.is_train:
        # load a pretrained model and test
        runner.test()
        return

    # train
    save_config(config)
    runner.train()
Example #18
Source File: main.py From ENAS-pytorch with Apache License 2.0 | 5 votes |
def main(args):  # pylint:disable=redefined-outer-name
    """Entry point: seed torch, pick a dataset and dispatch on ``args.mode``.

    Supported modes are 'train', 'derive', 'test' and 'single'.

    Raises:
        NotImplementedError: if ``args.network_type`` is not 'rnn' and
            ``args.dataset`` is not 'cifar'.
        AssertionError: in 'derive' mode when ``args.load_path`` is "".
        Exception: in 'test' mode without ``args.load_path``, in 'single'
            mode without ``args.dag_path``, or for an unknown mode.
    """
    utils.prepare_dirs(args)

    seed = args.random_seed
    torch.manual_seed(seed)
    if args.num_gpu > 0:
        torch.cuda.manual_seed(seed)

    if args.network_type == 'rnn':
        dataset = data.text.Corpus(args.data_path)
    elif args.dataset == 'cifar':
        dataset = data.image.Image(args.data_path)
    else:
        raise NotImplementedError(f"{args.dataset} is not supported")

    runner = trainer.Trainer(args, dataset)
    mode = args.mode

    if mode == 'train':
        utils.save_args(args)
        runner.train()
    elif mode == 'derive':
        assert args.load_path != "", "`--load_path` should be given in `derive` mode"
        runner.derive()
    elif mode == 'test':
        if not args.load_path:
            raise Exception("[!] You should specify `load_path` to load a pretrained model")
        runner.test()
    elif mode == 'single':
        if not args.dag_path:
            raise Exception("[!] You should specify `dag_path` to load a dag")
        utils.save_args(args)
        runner.train(single=True)
    else:
        raise Exception(f"[!] Mode not found: {args.mode}")
Example #19
Source File: train.py From DBNet.pytorch with Apache License 2.0 | 5 votes |
def main(config):
    """Set up (optionally distributed) training from ``config`` and run it.

    With more than one CUDA device, the process group is initialised with the
    NCCL backend and ``config['distributed']`` is set to True; otherwise it
    is set to False. ``args`` is read from the enclosing module scope.
    """
    import torch
    from models import build_model, build_loss
    from data_loader import get_dataloader
    from trainer import Trainer
    from post_processing import get_post_processing
    from utils import get_metric

    multi_gpu = torch.cuda.device_count() > 1
    if multi_gpu:
        torch.cuda.set_device(args.local_rank)
        torch.distributed.init_process_group(
            backend="nccl",
            init_method="env://",
            world_size=torch.cuda.device_count(),
            rank=args.local_rank)
    config['distributed'] = multi_gpu
    config['local_rank'] = args.local_rank

    train_loader = get_dataloader(config['dataset']['train'], config['distributed'])
    assert train_loader is not None

    # The validation loader is optional and is requested with False as the
    # second argument.
    validate_loader = (
        get_dataloader(config['dataset']['validate'], False)
        if 'validate' in config['dataset']
        else None
    )

    criterion = build_loss(config['loss']).cuda()

    # One input channel for 'GRAY' images, three otherwise.
    img_mode = config['dataset']['train']['dataset']['args']['img_mode']
    config['arch']['backbone']['in_channels'] = 1 if img_mode == 'GRAY' else 3
    model = build_model(config['arch'])

    post_p = get_post_processing(config['post_processing'])
    metric = get_metric(config['metric'])

    Trainer(
        config=config,
        model=model,
        criterion=criterion,
        train_loader=train_loader,
        post_process=post_p,
        metric_cls=metric,
        validate_loader=validate_loader,
    ).train()
Example #20
Source File: main.py From BEGAN-tensorflow with Apache License 2.0 | 5 votes |
def main(config):
    """Build the data loader for the selected mode, then train or test.

    In test mode ``config.batch_size`` is overwritten with 64 and the data
    path falls back to ``config.data_path`` when ``config.test_data_path`` is
    None.

    Raises:
        Exception: in test mode when ``config.load_path`` is empty.
    """
    prepare_dirs_and_logger(config)
    tf.set_random_seed(config.random_seed)

    # The locals `rng`, `batch_size` and `do_shuffle` were assigned but never
    # read, so they have been removed.
    # NOTE(review): get_loader has always received config.batch_size, not
    # config.sample_per_image, in test mode; confirm that this is intended.
    if config.is_train:
        data_path = config.data_path
    else:
        config.batch_size = 64
        if config.test_data_path is None:
            data_path = config.data_path
        else:
            data_path = config.test_data_path

    data_loader = get_loader(
        data_path,
        config.batch_size,
        config.input_scale_size,
        config.data_format,
        config.split)

    trainer = Trainer(config, data_loader)

    if config.is_train:
        save_config(config)
        trainer.train()
    else:
        if not config.load_path:
            raise Exception("[!] You should specify `load_path` to load a pretrained model")
        trainer.test()
Example #21
Source File: main.py From recurrent-visual-attention with MIT License | 5 votes |
def main(config):
    """Seed torch, build the loader for the selected mode, then train or test.

    When ``config.use_gpu`` is set, the CUDA seed is also set and the loaders
    get ``num_workers=1, pin_memory=True``.
    """
    utils.prepare_dirs(config)

    # ensure reproducibility
    torch.manual_seed(config.random_seed)
    if config.use_gpu:
        torch.cuda.manual_seed(config.random_seed)
        kwargs = {"num_workers": 1, "pin_memory": True}
    else:
        kwargs = {}

    # instantiate data loaders
    if config.is_train:
        dloader = data_loader.get_train_valid_loader(
            config.data_dir,
            config.batch_size,
            config.random_seed,
            config.valid_size,
            config.shuffle,
            config.show_sample,
            **kwargs,
        )
    else:
        dloader = data_loader.get_test_loader(
            config.data_dir, config.batch_size, **kwargs)

    runner = Trainer(config, dloader)

    if not config.is_train:
        # load a pretrained model and test
        runner.test()
        return

    # train
    utils.save_config(config)
    runner.train()
Example #22
Source File: main.py From simulated-unsupervised-tensorflow with Apache License 2.0 | 5 votes |
def main(_):
    """Prepare directories and seeds, then train or test a ``Trainer``.

    Reads everything from the module-level ``config``.

    Raises:
        Exception: if ``config.is_train`` is false and ``config.load_path``
            is empty.
    """
    prepare_dirs(config)

    random_state = np.random.RandomState(config.random_seed)
    tf.set_random_seed(config.random_seed)

    runner = Trainer(config, random_state)
    save_config(config.model_dir, config)

    if config.is_train:
        runner.train()
    elif not config.load_path:
        raise Exception("[!] You should specify `load_path` to load a pretrained model")
    else:
        runner.test()
Example #23
Source File: main.py From video_captioning_rl with MIT License | 5 votes |
def main(args):
    """Seed torch, build the MSR-VTT batchers and train or test.

    Only ``args.network_type == 'seq2seq'`` with ``args.dataset == 'msrvtt'``
    is supported. In any mode other than 'train', ``args.mode`` is passed to
    ``trainer.test``.

    Raises:
        NotImplementedError: if ``args.network_type`` is not 'seq2seq'.
        Exception: if the dataset is unknown, or if a non-train mode is
            requested without ``args.load_path``.
    """
    prepare_dirs(args)

    torch.manual_seed(args.random_seed)
    if args.num_gpu > 0:
        torch.cuda.manual_seed(args.random_seed)

    if args.network_type == 'seq2seq':
        vocab = data.common_loader.Vocab(args.vocab_file, args.max_vocab_size)
        dataset = {}
        if args.dataset == 'msrvtt':
            for split in ('train', 'val', 'test'):
                dataset[split] = data.common_loader.MSRVTTBatcher(args, split, vocab)
        else:
            raise Exception(f"Unknown dataset: {args.dataset} for the corresponding network type: {args.network_type}")
    else:
        # `NotImplemented` is a non-callable constant, so the previous
        # `raise NotImplemented(...)` failed with TypeError. The message now
        # names the network type, which is the value rejected here.
        raise NotImplementedError(f"{args.network_type} is not supported")

    trainer = Trainer(args, dataset)

    if args.mode == 'train':
        save_args(args)
        trainer.train()
    else:
        if not args.load_path:
            raise Exception("[!] You should specify `load_path` to load a pretrained model")
        trainer.test(args.mode)
Example #24
Source File: main.py From HistoGAN with GNU General Public License v3.0 | 5 votes |
def main(config):
    """Seed torch, build the two data loaders, then train or test.

    In test mode the batch size is ``config.sample_per_image`` and the data
    path falls back to ``config.data_path`` when ``config.test_data_path`` is
    None.

    Raises:
        Exception: in test mode when ``config.load_path`` is empty.
    """
    prepare_dirs_and_logger(config)

    torch.manual_seed(config.random_seed)
    if config.num_gpu > 0:
        torch.cuda.manual_seed(config.random_seed)

    if config.is_train:
        data_path, batch_size = config.data_path, config.batch_size
    else:
        data_path = (
            config.data_path
            if config.test_data_path is None
            else config.test_data_path
        )
        batch_size = config.sample_per_image

    a_data_loader, b_data_loader = get_loader(
        data_path,
        batch_size,
        config.input_scale_size,
        config.num_worker,
        config.skip_pix2pix_processing)

    runner = Trainer(config, a_data_loader, b_data_loader)

    if config.is_train:
        save_config(config)
        runner.train()
        return

    if not config.load_path:
        raise Exception("[!] You should specify `load_path` to load a pretrained model")
    runner.test()
Example #25
Source File: main.py From R-BERT with Apache License 2.0 | 5 votes |
def main(args):
    """Load the train and test datasets and run the phases selected in ``args``.

    ``args.do_train`` triggers training; ``args.do_eval`` triggers
    ``load_model()`` followed by evaluation on 'test'.
    """
    init_logger()
    tokenizer = load_tokenizer(args)

    # The train split is loaded before the test split.
    datasets = {
        "train_dataset": load_and_cache_examples(args, tokenizer, mode="train"),
        "test_dataset": load_and_cache_examples(args, tokenizer, mode="test"),
    }
    runner = Trainer(args, **datasets)

    if args.do_train:
        runner.train()

    if args.do_eval:
        runner.load_model()
        runner.evaluate('test')
Example #26
Source File: convert.py From ZeroSpeech-TTS-without-T with MIT License | 5 votes |
def get_trainer(hps_path, model_path, g_mode, enc_mode, clf_path):
    """Build a ``Trainer`` from a hyper-parameter file and load its model.

    The global ``MIN_LEN`` is replaced by ``hps.seg_len`` when
    ``hps.enc_mode`` is 'gumbel_t'; otherwise it keeps its value. The check
    uses the mode stored in the hyper-parameters, while the ``enc_mode``
    argument is passed to ``Trainer``.

    Returns:
        The ``Trainer`` after ``load_model`` has been called on it.
    """
    global MIN_LEN

    hps = Hps(hps_path).get_tuple()
    MIN_LEN = MIN_LEN if hps.enc_mode != 'gumbel_t' else hps.seg_len

    loaded_trainer = Trainer(hps, None, g_mode, enc_mode)
    loaded_trainer.load_model(
        model_path,
        load_model_list=hps.load_model_list,
        clf_path=clf_path,
    )
    return loaded_trainer
Example #27
Source File: train.py From pytorch-template with MIT License | 5 votes |
def main(config):
    """Build every training component from ``config`` and run the trainer."""
    logger = config.get_logger('train')

    # Set up data_loader instances
    data_loader = config.init_obj('data_loader', module_data)
    valid_data_loader = data_loader.split_validation()

    # Build the model architecture and log it
    model = config.init_obj('arch', module_arch)
    logger.info(model)

    # Function handles for the loss and the metrics, looked up by name
    criterion = getattr(module_loss, config['loss'])
    metrics = [getattr(module_metric, metric_name) for metric_name in config['metrics']]

    # Build the optimizer and learning rate scheduler. Delete every line
    # containing lr_scheduler to disable the scheduler.
    trainable_params = (
        param for param in model.parameters() if param.requires_grad
    )
    optimizer = config.init_obj('optimizer', torch.optim, trainable_params)
    lr_scheduler = config.init_obj('lr_scheduler', torch.optim.lr_scheduler, optimizer)

    Trainer(
        model,
        criterion,
        metrics,
        optimizer,
        config=config,
        data_loader=data_loader,
        valid_data_loader=valid_data_loader,
        lr_scheduler=lr_scheduler,
    ).train()
Example #28
Source File: main.py From BigGAN-pytorch with Apache License 2.0 | 5 votes |
def main(config):
    """Prepare the data loader and output folders, then train or test.

    ``config.n_class`` is set to the number of sub-directories of
    ``config.image_path``. In training mode ``config.model`` selects the
    trainer: 'sagan' or 'qgan'.

    Raises:
        ValueError: in training mode when ``config.model`` is neither
            'sagan' nor 'qgan'.
    """
    # For fast training
    cudnn.benchmark = True

    config.n_class = len(glob.glob(os.path.join(config.image_path, '*/')))
    print('number class:', config.n_class)

    # Data loader
    data_loader = Data_Loader(config.train, config.dataset, config.image_path,
                              config.imsize, config.batch_size, shuf=config.train)

    # Create directories if not exist
    make_folder(config.model_save_path, config.version)
    make_folder(config.sample_path, config.version)
    make_folder(config.log_path, config.version)
    make_folder(config.attn_path, config.version)
    print('config data_loader and build logs folder')

    if config.train:
        if config.model=='sagan':
            trainer = Trainer(data_loader.loader(), config)
        elif config.model == 'qgan':
            trainer = qgan_trainer(data_loader.loader(), config)
        else:
            # Previously `trainer` stayed unbound here and the call below
            # failed with UnboundLocalError; fail with a clear message.
            raise ValueError(
                "Unknown model {!r}; expected 'sagan' or 'qgan'".format(config.model))
        trainer.train()
    else:
        tester = Tester(data_loader.loader(), config)
        tester.test()
Example #29
Source File: main.py From CausalGAN with MIT License | 5 votes |
def get_trainer(config):
    """Reset the default TF graph and return a new ``Trainer`` for ``config``.

    After the trainer is built, TensorFlow log verbosity is set to ERROR.
    """
    print('tf: resetting default graph!')
    tf.reset_default_graph()

    #tf.set_random_seed(config.random_seed)
    #np.random.seed(22)

    data_type = config.data_type
    print('Using data_type ', data_type)
    built_trainer = Trainer(config, data_type)
    print('built trainer successfully')

    tf.logging.set_verbosity(tf.logging.ERROR)
    return built_trainer
Example #30
Source File: train.py From crnn.gluon with Apache License 2.0 | 4 votes |
def main(config):
    """Build the model, loss and data loaders from ``config`` and train.

    If ``config['dataset']['alphabet']`` names an existing file, it is
    replaced by the joined content returned by ``load``. The scheduler step
    in ``config`` is multiplied by the number of training batches.

    Raises:
        NotImplementedError: if the prediction type is not 'CTC'.
    """
    from mxnet import nd
    from mxnet.gluon.loss import CTCLoss
    from models import get_model
    from data_loader import get_dataloader
    from trainer import Trainer
    from utils import get_ctx, load

    dataset_cfg = config['dataset']
    if os.path.isfile(dataset_cfg['alphabet']):
        dataset_cfg['alphabet'] = ''.join(load(dataset_cfg['alphabet']))

    prediction_type = config['arch']['args']['prediction']['type']
    num_class = len(dataset_cfg['alphabet'])

    # loss setup
    if prediction_type != 'CTC':
        raise NotImplementedError
    criterion = CTCLoss()

    ctx = get_ctx(config['trainer']['gpus'])
    model = get_model(num_class, ctx, config['arch']['args'])
    model.hybridize()
    model.initialize(ctx=ctx)

    # Default input size, overridden by the first "Resize" pre-process.
    img_h, img_w = 32, 100
    train_args = dataset_cfg['train']['dataset']['args']
    for step in train_args['pre_processes']:
        if step['type'] == "Resize":
            img_h = step['args']['img_h']
            img_w = step['args']['img_w']
            break

    # One channel for 'GRAY' images, three otherwise.
    img_channel = 1 if train_args['img_mode'] == 'GRAY' else 3
    sample_input = nd.zeros((2, img_channel, img_h, img_w), ctx[0])
    num_label = model.get_batch_max_length(sample_input)

    train_loader = get_dataloader(dataset_cfg['train'], num_label, dataset_cfg['alphabet'])
    assert train_loader is not None

    # The validation loader is optional.
    validate_loader = (
        get_dataloader(dataset_cfg['validate'], num_label, dataset_cfg['alphabet'])
        if 'validate' in dataset_cfg
        else None
    )

    config['lr_scheduler']['args']['step'] *= len(train_loader)

    Trainer(
        config=config,
        model=model,
        criterion=criterion,
        train_loader=train_loader,
        validate_loader=validate_loader,
        sample_input=sample_input,
        ctx=ctx,
    ).train()