Python utils.utils.get_args() Examples
The following are 4 code examples of `utils.utils.get_args()`.
You can vote up the ones you like or vote down the ones you don't like,
and go to the original project or source file by following the links above each example.
You may also want to check out all available functions/classes of the module `utils.utils`, or try the search function.
Example #1
Source File: main.py From Keras-Project-Template with Apache License 2.0 | 6 votes |
def main():
    """Load the JSON config named on the command line, then build and train the model.

    Creates the TensorBoard-log and checkpoint directories listed in the config,
    builds the data loader, model and trainer, and runs ``trainer.train()``.

    If reading the arguments or processing the config raises an ``Exception``,
    prints "missing or invalid arguments" to stderr and exits with status 1.
    """
    import sys

    # capture the config path from the run arguments
    # then process the json configuration file
    try:
        args = get_args()
        config = process_config(args.config)
    except Exception:
        # A bare ``except:`` would also swallow SystemExit/KeyboardInterrupt,
        # and exiting with status 0 would report success to the caller/shell
        # even though nothing ran.
        print("missing or invalid arguments", file=sys.stderr)
        sys.exit(1)

    # create the experiments dirs
    create_dirs([config.callbacks.tensorboard_log_dir, config.callbacks.checkpoint_dir])

    print('Create the data generator.')
    data_loader = SimpleMnistDataLoader(config)

    print('Create the model.')
    model = SimpleMnistModel(config)

    print('Create the trainer')
    trainer = SimpleMnistModelTrainer(model.model, data_loader.get_train_data(), config)

    print('Start training the model.')
    trainer.train()
Example #2
Source File: main.py From self-supervised-da with MIT License | 5 votes |
def main():
    """Run a training or test experiment described by a config file.

    The config path comes from ``get_args()``. ``config.method`` selects the
    model (only 'src', 'jigsaw' and 'rotate' are accepted). ``config.mode``
    selects the action: 'train' calls ``model.train(...)`` and 'test' calls
    ``model.test(...)``.

    Raises:
        ValueError: if ``config.method`` or ``config.mode`` is not recognised.
            Both are checked before any data loader is built, so a typo in
            the config fails fast instead of silently doing nothing.
    """
    args = get_args()
    config = process_config(args.config)

    # create the experiments dirs
    create_dirs([config.cache_dir, config.model_dir, config.log_dir, config.img_dir])

    # logging to the file and stdout
    logger = get_logger(config.log_dir, config.exp_name)

    # fix random seed to reproduce results
    # NOTE(review): only Python's ``random`` is seeded here; if numpy/torch
    # RNGs are used elsewhere they are not covered by this call.
    random.seed(config.random_seed)
    logger.info('Random seed: {:d}'.format(config.random_seed))

    # Validate the config before doing any expensive work.
    if config.method not in ('src', 'jigsaw', 'rotate'):
        raise ValueError("Unknown method: %s" % config.method)
    if config.mode not in ('train', 'test'):
        raise ValueError("Unknown mode: %s" % config.mode)

    model = AuxModel(config, logger)

    src_loader, val_loader = get_train_val_dataloader(config.datasets.src)
    test_loader = get_test_dataloader(config.datasets.test)

    # the target-domain loader is optional
    tar_loader = None
    if config.datasets.get('tar', None):
        tar_loader = get_target_dataloader(config.datasets.tar)

    if config.mode == 'train':
        model.train(src_loader, tar_loader, val_loader, test_loader)
    else:  # 'test' -- the only other value allowed by the check above
        model.test(test_loader)
Example #3
Source File: example.py From Tensorflow-Project-Template with Apache License 2.0 | 5 votes |
def main():
    """Load the JSON config named on the command line, then build and train the model.

    Creates the summary and checkpoint directories listed in the config and
    opens a TensorFlow session. It then builds the data generator, model,
    logger and trainer, calls ``model.load(sess)`` and runs ``trainer.train()``.

    If reading the arguments or processing the config raises an ``Exception``,
    prints "missing or invalid arguments" to stderr and exits with status 1.
    """
    import sys

    # capture the config path from the run arguments
    # then process the json configuration file
    try:
        args = get_args()
        config = process_config(args.config)
    except Exception:
        # A bare ``except:`` would also swallow SystemExit/KeyboardInterrupt,
        # and exiting with status 0 would report success on failure.
        print("missing or invalid arguments", file=sys.stderr)
        sys.exit(1)

    # create the experiments dirs
    create_dirs([config.summary_dir, config.checkpoint_dir])

    # create tensorflow session; the context manager guarantees the session
    # (and the resources it holds) is closed even if training raises
    with tf.Session() as sess:
        # create your data generator
        data = DataGenerator(config)

        # create an instance of the model you want
        model = ExampleModel(config)

        # create tensorboard logger
        logger = Logger(sess, config)

        # create trainer and pass all the previous components to it
        trainer = ExampleTrainer(sess, model, data, config, logger)

        # load model if exists
        model.load(sess)

        # here you train your model
        trainer.train()
Example #4
Source File: task.py From Distributed-Tensorflow-Template with MIT License | 5 votes |
def init() -> None:
    """Initialise every component required for training and start the run.

    Command-line arguments (``get_args``) and the static configuration
    (``process_config``) are merged into one dict. On key collisions the
    command-line value wins. That dict configures the model, one data loader
    per mode ("train", "val", "test") and the trainer. Finally ``trainer.run()``
    starts training.
    """
    cli_args = get_args()
    static_config = process_config()
    # later unpacking overrides earlier keys, so CLI values take precedence
    config = {**static_config, **cli_args}

    model = RawModel(config)

    # build the loaders in train -> val -> test order
    loaders = {
        mode: TFRecordDataLoader(config, mode=mode)
        for mode in ("train", "val", "test")
    }

    trainer = RawTrainer(
        config,
        model,
        loaders["train"],
        loaders["val"],
        loaders["test"],
    )
    trainer.run()