Python args.get_args() Examples

The following are 25 code examples of args.get_args(), each taken from the project and source file named above it (most come from nnabla-examples, with others from nnabla and PointFlow). You may also want to check out the other functions and classes available in the args module.
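In the scripts listed below, get_args() is usually a small argparse wrapper defined in a per-directory args.py. The exact flags differ from example to example; the following is a minimal sketch that assumes only the options read throughout this page (--context, --device-id, --type-config), and is not the actual nnabla-examples implementation:

# Minimal sketch of a get_args() helper in the style of the examples below.
# The flag set is an assumption based on the attributes the examples read
# (args.context, args.device_id, args.type_config); real args.py files
# define many more, example-specific options.
import argparse


def get_args(description="Example script"):
    parser = argparse.ArgumentParser(description=description)
    parser.add_argument("--context", "-c", type=str, default="cudnn",
                        help="Extension module: 'cpu' or 'cudnn'.")
    parser.add_argument("--device-id", "-d", type=str, default="0",
                        help="Device ID of the GPU to use.")
    parser.add_argument("--type-config", type=str, default="float",
                        help="Computation type: 'float' or 'half'.")
    return parser.parse_args()

Each main() below simply calls this helper and hands the resulting namespace to the rest of the script.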
Example #1
Source File: train_mgpu.py    From nnabla-examples with Apache License 2.0
def main():
    args = get_args()
    save_args(args)
    train(args) 
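Several of the examples on this page also call save_args(...) right after parsing. As a rough, hypothetical sketch (not the actual nnabla-examples helper), such a function might simply record every parsed option to a text file so a run can be reproduced later:

# Hypothetical sketch of a save_args() helper; the real nnabla-examples
# version may differ. Writes each parsed option to <monitor_path>/<mode>_args.txt.
import os


def save_args(args, mode="train"):
    out_dir = getattr(args, "monitor_path", ".")
    if not os.path.exists(out_dir):
        os.makedirs(out_dir)
    with open(os.path.join(out_dir, "{}_args.txt".format(mode)), "w") as f:
        for name, value in sorted(vars(args).items()):
            f.write("{}={}\n".format(name, value))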
Example #2
Source File: darts_train.py    From nnabla-examples with Apache License 2.0
def main():
    """
        Start architecture evaluation (retraining from scratch).
    """
    args = get_args()
    print(args)

    ctx = get_extension_context(args.context,
                                device_id=args.device_id, type_config=args.type_config)
    nn.set_default_context(ctx)
    ext = nn.ext_utils.import_extension_module(args.context)

    assert os.path.exists(
        args.model_arch_name), "architecture's params seem to be missing!"

    ops = {0: dil_conv_3x3, 1: dil_conv_5x5, 2: sep_conv_3x3, 3: sep_conv_5x5,
           4: max_pool_3x3, 5: avg_pool_3x3, 6: identity, 7: zero}

    with open(args.model_arch_name, 'r') as f:
        arch_dict = json.load(f)

    print("Train the model whose architecture is:")
    show_derived_cell(args, ops, arch_dict["arch_normal"], "normal")
    show_derived_cell(args, ops, arch_dict["arch_reduction"], "reduction")
    CNN_run(args, ops, arch_dict)

    return 
Example #3
Source File: datasets.py    From nnabla-examples with Apache License 2.0
def main():
    args = get_args()
    if args.command == "create":
        create_pcd_dataset_from_mesh(args.mesh_data_path) 
Example #4
Source File: show_leaned_filters.py    From nnabla-examples with Apache License 2.0
def show():
    args = get_args()

    # Load model
    nn.load_parameters(args.model_load_path)
    params = nn.get_parameters()

    # Show heatmap
    for name, param in params.items():
        # SSL only on convolution weights
        if "conv/W" not in name:
            continue
        print(name)
        n, m, k0, k1 = param.d.shape
        w_matrix = param.d.reshape((n, m * k0 * k1))
        # Filter x Channel heatmap

        fig, ax = plt.subplots()
        ax.set_title("{} with shape {} \n Filter x (Channel x Heigh x Width)".format(
            name, (n, m, k0, k1)))
        heatmap = ax.pcolor(w_matrix)
        fig.colorbar(heatmap)

        plt.pause(0.5)
        raw_input("Press Key")
        plt.close() 
Example #5
Source File: convert_tf_nnabla.py    From nnabla-examples with Apache License 2.0
def main():
    ''' 
    Main

    Usage: python convert_tf_nnabla.py --input-ckpt-file=<path to .ckpt file> --output-nnabla-file=<path to output .h5 file>

    '''

    # Parse the arguments
    args = get_args()

    # convert the input file(.ckpt) to the output file(.h5)
    convert(args.input_ckpt_file, args.output_nnabla_file) 
Example #6
Source File: eval.py    From nnabla-examples with Apache License 2.0
def main():
    args = get_args()
    rng = np.random.RandomState(1223)

    # Get context
    from nnabla.ext_utils import get_extension_context
    logger.info("Running in %s" % args.context)
    ctx = get_extension_context(
        args.context, device_id=args.device_id, type_config=args.type_config)
    nn.set_default_context(ctx)

    miou = validate(args) 
Example #7
Source File: prepare_lfw_data.py    From nnabla-examples with Apache License 2.0
def main():
    '''
    Arguments:
    train-file = txt file containing randomly selected image filenames to be used as the training set.
    val-file = txt file containing randomly selected image filenames to be used as the validation set.
    data-dir = dataset directory
    Usage: python dataset_utils.py --train-file="" --val-file="" --data_dir=""
    '''

    args = get_args()
    data_dir = args.data_dir

    generate_path_files(data_dir, args.train_file, args.val_file) 
Example #8
Source File: plot_accuracy.py    From nnabla-examples with Apache License 2.0
def main():
    args = get_args()
    rng = np.random.RandomState(1223)

    # Get context
    from nnabla.ext_utils import get_extension_context
    logger.info("Running in %s" % args.context)
    ctx = get_extension_context(
        args.context, device_id=args.device_id, type_config=args.type_config)
    nn.set_default_context(ctx)
    iterations = []
    mean_iou = []
    model_dir = args.model_load_path
    for filename in os.listdir(model_dir):
        args.model_load_path = model_dir+filename
        miou = eval.validate(args)
        iterations.append(filename.split('.')[0])
        mean_iou.append(miou)

    for i in range(len(iterations)):
        iterations[i] = iterations[i].replace('param_', '')

    itr = list(map(int, iterations))

    # Plot Iterations Vs mIOU
    plt.axis([0, max(itr), 0.0, 1.0])
    plt.xlabel('Iterations')
    plt.ylabel('Accuracy - mIOU')
    plt.scatter(itr, mean_iou)
    plt.show()

    print(iterations)
    print(mean_iou)
    with open('iterations.txt', 'w') as f:
        for item in iterations:
            f.write('%s\n' % item)
    with open('miou.txt', 'w') as f2:
        for item in mean_iou:
            f2.write('%s\n' % item)

    #plt.plot(iterations, mean_iou) 
Example #9
Source File: train.py    From nnabla-examples with Apache License 2.0
def main():
    args = get_args()
    save_args(args)
    train(args) 
Example #10
Source File: train.py    From nnabla-examples with Apache License 2.0
def main():
    args = get_args()
    # Context
    extension_module = args.context
    ctx = get_extension_context(
        extension_module, device_id=args.device_id, type_config=args.type_config)
    nn.set_default_context(ctx)

    train(args) 
Example #11
Source File: interpolate.py    From nnabla-examples with Apache License 2.0
def main():
    args = get_args()
    save_args(args, "generate")
    interpolate(args) 
Example #12
Source File: prepare_datasets.py    From nnabla-examples with Apache License 2.0
def main():
    args = get_args()
    train = True
    if args.data_type == "train":
        train = True
    elif args.data_type == "val":
        train = False

    prepare_pix2pix_dataset(args.dataset, train) 
Example #13
Source File: generate.py    From nnabla-examples with Apache License 2.0
def main():
    args = get_args()
    save_args(args, "generate")
    generate(args) 
Example #14
Source File: generate.py    From nnabla-examples with Apache License 2.0
def main():
    args = get_args()
    save_args(args, "generate")

    generate(args) 
Example #15
Source File: morph.py    From nnabla-examples with Apache License 2.0
def main():
    args = get_args()
    save_args(args, "morph")

    morph(args) 
Example #16
Source File: train_with_mgpu.py    From nnabla-examples with Apache License 2.0
def main():
    args = get_args()
    save_args(args, "train")

    train(args) 
Example #17
Source File: match.py    From nnabla-examples with Apache License 2.0
def main():
    args = get_args()
    save_args(args, "match")

    match(args) 
Example #18
Source File: generate.py    From nnabla-examples with Apache License 2.0
def main():
    args = get_args()
    save_args(args, "generate")

    generate(args) 
Example #19
Source File: train.py    From nnabla-examples with Apache License 2.0
def main():
    args = get_args()
    ctx = get_extension_context(
        args.context, device_id=args.device_id, type_config=args.type_config)
    nn.set_default_context(ctx)
    train(args) 
Example #20
Source File: train.py    From PointFlow with MIT License
def main():
    # command line args
    args = get_args()
    save_dir = os.path.join("checkpoints", args.log_name)
    if not os.path.exists(save_dir):
        os.makedirs(save_dir)
        os.makedirs(os.path.join(save_dir, 'images'))

    with open(os.path.join(save_dir, 'command.sh'), 'w') as f:
        f.write('python -X faulthandler ' + ' '.join(sys.argv))
        f.write('\n')

    if args.seed is None:
        args.seed = random.randint(0, 1000000)
    set_random_seed(args.seed)

    if args.gpu is not None:
        warnings.warn('You have chosen a specific GPU. This will completely '
                      'disable data parallelism.')

    if args.dist_url == "env://" and args.world_size == -1:
        args.world_size = int(os.environ["WORLD_SIZE"])

    if args.sync_bn:
        assert args.distributed

    print("Arguments:")
    print(args)

    ngpus_per_node = torch.cuda.device_count()
    if args.distributed:
        args.world_size = ngpus_per_node * args.world_size
        mp.spawn(main_worker, nprocs=ngpus_per_node, args=(save_dir, ngpus_per_node, args))
    else:
        main_worker(args.gpu, save_dir, ngpus_per_node, args) 
Example #21
Source File: create_initialized_model.py    From nnabla with Apache License 2.0
def main():

    # Read envvar `NNABLA_EXAMPLES_ROOT` to identify the path to your local
    # nnabla-examples directory.
    HERE = os.path.dirname(__file__)
    nnabla_examples_root = os.environ.get('NNABLA_EXAMPLES_ROOT', os.path.join(
        HERE, '../../../../nnabla-examples'))
    mnist_examples_root = os.path.realpath(
        os.path.join(nnabla_examples_root, 'mnist-collection'))
    sys.path.append(mnist_examples_root)
    nnabla_examples_git_url = 'https://github.com/sony/nnabla-examples'

    # Check if nnabla-examples found.
    try:
        from args import get_args
    except ImportError:
        print(
            'An envvar `NNABLA_EXAMPLES_ROOT`'
            ' which locates the local path to '
            '[nnabla-examples]({})'
            ' repository must be set correctly.'.format(
                nnabla_examples_git_url),
            file=sys.stderr)
        raise

    # Import MNIST data
    from mnist_data import data_iterator_mnist
    from classification import mnist_lenet_prediction, mnist_resnet_prediction

    args = get_args(description=__doc__)

    mnist_cnn_prediction = mnist_lenet_prediction
    if args.net == 'resnet':
        mnist_cnn_prediction = mnist_resnet_prediction

    # Create a computation graph to be saved.
    x = nn.Variable([args.batch_size, 1, 28, 28])
    h = mnist_cnn_prediction(x, test=False, aug=False)
    t = nn.Variable([args.batch_size, 1])
    loss = F.mean(F.softmax_cross_entropy(h, t))
    y = mnist_cnn_prediction(x, test=True, aug=False)

    # Save NNP file (used in C++ inference later.).
    nnp_file = '{}_initialized.nnp'.format(args.net)
    runtime_contents = {
        'networks': [
            {'name': 'training',
             'batch_size': args.batch_size,
             'outputs': {'loss': loss},
             'names': {'x': x, 't': t}},
            {'name': 'runtime',
             'batch_size': args.batch_size,
             'outputs': {'y': y},
             'names': {'x': x}}],
        'executors': [
            {'name': 'runtime',
             'network': 'runtime',
             'data': ['x'],
             'output': ['y']}]}
    nn.utils.save.save(nnp_file, runtime_contents) 
Example #22
Source File: pix2pix.py    From nnabla-examples with Apache License 2.0
def main():
    # argparse
    args = get_args()

    # Context Setting
    # Get context.
    from nnabla.ext_utils import get_extension_context
    logger.info("Running in %s" % args.context)
    ctx = get_extension_context(
        args.context, device_id=args.device_id)
    nn.set_default_context(ctx)

    model_path = args.model

    if args.train:
        # Data Loading
        logger.info("Initialing DataSource.")
        train_iterator = facade.facade_data_iterator(
            args.traindir,
            args.batchsize,
            shuffle=True,
            with_memory_cache=False)
        val_iterator = facade.facade_data_iterator(
            args.valdir,
            args.batchsize,
            random_crop=False,
            shuffle=False,
            with_memory_cache=False)

        monitor = nm.Monitor(args.logdir)
        solver_gen = S.Adam(alpha=args.lrate, beta1=args.beta1)
        solver_dis = S.Adam(alpha=args.lrate, beta1=args.beta1)

        generator = unet.generator
        discriminator = unet.discriminator

        model_path = train(generator, discriminator, args.patch_gan,
                           solver_gen, solver_dis,
                           args.weight_l1, train_iterator, val_iterator,
                           args.epoch, monitor, args.monitor_interval)

    if args.generate:
        if model_path is not None:
            # Data Loading
            logger.info("Generating from DataSource.")
            test_iterator = facade.facade_data_iterator(
                args.testdir,
                args.batchsize,
                shuffle=False,
                with_memory_cache=False)
            generator = unet.generator
            generate(generator, model_path, test_iterator, args.logdir)
        else:
            logger.error("Trained model was NOT given.") 
Example #23
Source File: train.py    From nnabla-examples with Apache License 2.0
def main():
    # Args
    args = get_args()
    save_args(args)

    # Context
    ctx = get_extension_context(
        args.context, device_id=args.device_id, type_config=args.type_config)
    nn.set_default_context(ctx)
    nn.set_auto_forward(True)

    # Data Iterator
    di = data_iterator(args.img_path, args.batch_size,
                       imsize=(args.imsize, args.imsize),
                       num_samples=args.train_samples,
                       dataset_name=args.dataset_name)
    # Model
    generator = Generator(use_bn=args.use_bn, last_act=args.last_act,
                          use_wscale=args.not_use_wscale, use_he_backward=args.use_he_backward)
    discriminator = Discriminator(use_ln=args.use_ln, alpha=args.leaky_alpha,
                                  use_wscale=args.not_use_wscale, use_he_backward=args.use_he_backward)

    # Solver
    solver_gen = S.Adam(alpha=args.learning_rate,
                        beta1=args.beta1, beta2=args.beta2)
    solver_dis = S.Adam(alpha=args.learning_rate,
                        beta1=args.beta1, beta2=args.beta2)

    # Monitor
    monitor = Monitor(args.monitor_path)
    monitor_loss_gen = MonitorSeries("Generator Loss", monitor, interval=10)
    monitor_loss_dis = MonitorSeries(
        "Discriminator Loss", monitor, interval=10)
    monitor_p_fake = MonitorSeries("Fake Probability", monitor, interval=10)
    monitor_p_real = MonitorSeries("Real Probability", monitor, interval=10)
    monitor_time = MonitorTimeElapsed(
        "Training Time per Resolution", monitor, interval=1)
    monitor_image_tile = MonitorImageTileWithName("Image Tile", monitor,
                                                  num_images=4,
                                                  normalize_method=lambda x: (x + 1.) / 2.)

    # TODO: use argument
    resolution_list = [4, 8, 16, 32, 64, 128]
    channel_list = [512, 512, 256, 128, 64, 32]

    trainer = Trainer(di,
                      generator, discriminator,
                      solver_gen, solver_dis,
                      args.monitor_path,
                      monitor_loss_gen, monitor_loss_dis,
                      monitor_p_fake, monitor_p_real,
                      monitor_time,
                      monitor_image_tile,
                      resolution_list, channel_list,
                      n_latent=args.latent, n_critic=args.critic,
                      save_image_interval=args.save_image_interval,
                      hyper_sphere=args.hyper_sphere,
                      l2_fake_weight=args.l2_fake_weight)

    # TODO: use images per resolution?
    trainer.train(args.epoch_per_resolution) 
Example #24
Source File: validate.py    From nnabla-examples with Apache License 2.0
def main():
    # Args
    args = get_args()

    # Context
    ctx = get_extension_context(
        args.context, device_id=args.device_id, type_config=args.type_config)
    logger.info(ctx)
    nn.set_default_context(ctx)
    nn.set_auto_forward(True)

    # Monitor
    monitor = Monitor(args.monitor_path)

    # Validation
    logger.info("Start validation")

    num_images = args.valid_samples
    num_batches = num_images // args.batch_size

    # DataIterator
    di = data_iterator(args.img_path, args.batch_size,
                       imsize=(args.imsize, args.imsize),
                       num_samples=args.valid_samples,
                       dataset_name=args.dataset_name)
    # generator
    gen = load_gen(args.model_load_path, use_bn=args.use_bn, last_act=args.last_act,
                   use_wscale=args.not_use_wscale, use_he_backward=args.use_he_backward)

    # compute metric
    if args.validation_metric == "ms-ssim":
        logger.info("Multi Scale SSIM")
        monitor_time = MonitorTimeElapsed(
            "MS-SSIM-ValidationTime", monitor, interval=1)
        monitor_metric = MonitorSeries("MS-SSIM", monitor, interval=1)
        from ms_ssim import compute_metric
        score = compute_metric(gen, args.batch_size,
                               num_images, args.latent, args.hyper_sphere)
        monitor_time.add(0)
        monitor_metric.add(0, score)
    elif args.validation_metric == "swd":
        logger.info("Sliced Wasserstein Distance")
        monitor_time = MonitorTimeElapsed(
            "SWD-ValidationTime", monitor, interval=1)
        monitor_metric = MonitorSeries("SWD", monitor, interval=1)
        nhoods_per_image = 128
        nhood_size = 7
        level_list = [128, 64, 32, 16]  # TODO: use argument
        dir_repeats = 4
        dirs_per_repeat = 128
        from sliced_wasserstein import compute_metric
        score = compute_metric(di, gen, args.latent, num_batches, nhoods_per_image, nhood_size,
                               level_list, dir_repeats, dirs_per_repeat, args.hyper_sphere)
        monitor_time.add(0)
        monitor_metric.add(0, score)  # averaged in the log
    else:
        raise ValueError("Set `validation-metric` to either `ms-ssim` or `swd`.")
    logger.info(score)
    logger.info("End validation") 
Example #25
Source File: generate.py    From nnabla-examples with Apache License 2.0
def main():
    # Args
    args = get_args()

    # Context
    ctx = get_extension_context(
        args.context, device_id=args.device_id, type_config=args.type_config)
    nn.set_default_context(ctx)
    nn.set_auto_forward(True)

    # Config
    resolution_list = [4, 8, 16, 32, 64, 128]
    channel_list = [512, 512, 256, 128, 64, 32]
    side = 8

    # Monitor
    monitor = Monitor(args.monitor_path)
    monitor_image_tile = MonitorImageTileWithName("Image Tile", monitor,
                                                  num_images=side**2)

    # Generate
    # generate tile images
    imgs = []
    for _ in range(side):
        img = generate_images(args.model_load_path,
                              batch_size=side, use_bn=args.use_bn,
                              n_latent=args.latent, hyper_sphere=args.hyper_sphere,
                              last_act=args.last_act,
                              use_wscale=args.not_use_wscale,
                              use_he_backward=args.use_he_backward,
                              resolution_list=resolution_list, channel_list=channel_list)
        imgs.append(img)
    imgs = np.concatenate(imgs, axis=0)
    monitor_image_tile.add("GeneratedImage", imgs)

    # generate interpolated tile images
    imgs = []
    for _ in range(side):
        img = generate_interpolated_images(args.model_load_path,
                                           batch_size=side, use_bn=args.use_bn,
                                           n_latent=args.latent, hyper_sphere=args.hyper_sphere,
                                           last_act=args.last_act,
                                           use_wscale=args.not_use_wscale,
                                           use_he_backward=args.use_he_backward,
                                           resolution_list=resolution_list, channel_list=channel_list)
        imgs.append(img)
    imgs = np.concatenate(imgs, axis=0)
    monitor_image_tile.add("GeneratedInterpolatedImage", imgs)