Python utils.utils.get_args() Examples

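The following are 4 code examples of utils.utils.get_args(). You can vote up the ones you like or vote down the ones you don't like, and go to the original project or source file by following the links above each example. You may also want to check out all available functions/classes of the module utils.utils, or try the search function. Note that each project defines its own `utils.utils.get_args()`, so its return value differs: Examples #1–#3 read `args.config` as an attribute, while Example #4 merges `args` into a dictionary.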
Example #1
Source File: main.py — from Keras-Project-Template (Apache License 2.0), 6 votes
def main():
    """Train the simple MNIST Keras model described by a JSON config file.

    The config path is taken from the run arguments returned by
    ``get_args()``; the processed config then drives directory creation,
    data loading, model construction and training.

    Exits with status 1 (after printing a message) if the arguments or the
    config cannot be processed.
    """
    import sys

    # capture the config path from the run arguments
    # then process the json configuration file
    try:
        args = get_args()
        config = process_config(args.config)
    except Exception:
        # Catch ordinary errors only. A bare ``except:`` would also swallow
        # SystemExit and KeyboardInterrupt -- e.g. if get_args() is
        # argparse-based, ``--help`` would be misreported as bad arguments.
        print("missing or invalid arguments")
        # Non-zero status so callers (shell, job schedulers) see the failure.
        sys.exit(1)

    # create the experiments dirs
    create_dirs([config.callbacks.tensorboard_log_dir, config.callbacks.checkpoint_dir])

    print('Create the data generator.')
    data_loader = SimpleMnistDataLoader(config)

    print('Create the model.')
    model = SimpleMnistModel(config)

    print('Create the trainer')
    trainer = SimpleMnistModelTrainer(model.model, data_loader.get_train_data(), config)

    print('Start training the model.')
    trainer.train()
Example #2
Source File: main.py — from self-supervised-da (MIT License), 5 votes
def main():
    """Build an ``AuxModel`` for the configured method and train or test it.

    Reads the config path from ``get_args()``, creates the experiment
    directories, sets up file/stdout logging, seeds Python's ``random``
    module, builds the source/validation/test (and optional target) data
    loaders, then dispatches on ``config.mode``.

    Raises:
        ValueError: if ``config.method`` is not one of 'src', 'jigsaw',
            'rotate', or if ``config.mode`` is neither 'train' nor 'test'.
    """
    args = get_args()
    config = process_config(args.config)

    # create the experiments dirs
    create_dirs([config.cache_dir, config.model_dir,
        config.log_dir, config.img_dir])

    # logging to the file and stdout
    logger = get_logger(config.log_dir, config.exp_name)

    # fix random seed to reproduce results
    # NOTE(review): only Python's ``random`` module is seeded here; any other
    # RNGs used by the model/loaders (e.g. numpy, torch) are not -- confirm.
    random.seed(config.random_seed)
    logger.info('Random seed: {:d}'.format(config.random_seed))

    if config.method in ['src', 'jigsaw', 'rotate']:
        model = AuxModel(config, logger)
    else:
        raise ValueError("Unknown method: %s" % config.method)

    src_loader, val_loader = get_train_val_dataloader(config.datasets.src)
    test_loader = get_test_dataloader(config.datasets.test)

    # The target-domain dataset is optional; without it no target loader is built.
    tar_loader = None
    if config.datasets.get('tar', None):
        tar_loader = get_target_dataloader(config.datasets.tar)

    if config.mode == 'train':
        model.train(src_loader, tar_loader, val_loader, test_loader)
    elif config.mode == 'test':
        model.test(test_loader)
    else:
        # Mirror the method check above: an unrecognised mode used to make the
        # run finish silently without training or testing anything.
        raise ValueError("Unknown mode: %s" % config.mode)
Example #3
Source File: example.py — from Tensorflow-Project-Template (Apache License 2.0), 5 votes
def main():
    """Build and train ``ExampleModel`` in a TensorFlow session.

    The config path is taken from the run arguments returned by
    ``get_args()``; the processed config drives directory creation, the data
    generator, the model, the tensorboard logger and the trainer. Any
    existing checkpoint is loaded (via ``model.load``) before training.

    Exits with status 1 (after printing a message) if the arguments or the
    config cannot be processed. The session is always closed on return.
    """
    import sys

    # capture the config path from the run arguments
    # then process the json configuration file
    try:
        args = get_args()
        config = process_config(args.config)
    except Exception:
        # Catch ordinary errors only; a bare ``except:`` would also swallow
        # SystemExit (e.g. from an argparse ``--help``) and KeyboardInterrupt.
        print("missing or invalid arguments")
        # Non-zero status so callers (shell, job schedulers) see the failure.
        sys.exit(1)

    # create the experiments dirs
    create_dirs([config.summary_dir, config.checkpoint_dir])
    # create tensorflow session
    sess = tf.Session()
    try:
        # create your data generator
        data = DataGenerator(config)

        # create an instance of the model you want
        model = ExampleModel(config)
        # create tensorboard logger
        logger = Logger(sess, config)
        # create trainer and pass all the previous components to it
        trainer = ExampleTrainer(sess, model, data, config, logger)
        # load model if exists
        model.load(sess)
        # here you train your model
        trainer.train()
    finally:
        # Release the session's resources even if setup or training fails.
        sess.close()
Example #4
Source File: task.py — from Distributed-Tensorflow-Template (MIT License), 5 votes
def init() -> None:
    """Initialise every component needed for training, then start the run.

    Command-line arguments are merged over the static configuration (CLI
    values win when both define the same key). The merged config is used to
    build the model, one data loader per mode (train, val, test) and the
    trainer, and finally ``trainer.run()`` is called.
    """
    # CLI arguments are read first, then the static config information.
    cli_args = get_args()
    static_config = process_config()
    # Later unpacking wins, so CLI arguments override static config entries.
    config = {**static_config, **cli_args}

    model = RawModel(config)

    # One data loader per mode, constructed in train -> val -> test order.
    loaders = {
        mode: TFRecordDataLoader(config, mode=mode)
        for mode in ("train", "val", "test")
    }

    # NOTE(review): the original comment calls this "the estimator"; whether
    # RawTrainer wraps a tf.estimator is not visible here -- confirm.
    trainer = RawTrainer(
        config, model, loaders["train"], loaders["val"], loaders["test"]
    )
    trainer.run()