Python args.get_args() Examples
The following are 25 code examples of args.get_args().
You can vote up the ones you like or vote down the ones you don't like,
and go to the original project or source file by following the links above each example.
You may also want to check out all available functions/classes of the module args, or try the search function.
Example #1
Source File: train_mgpu.py From nnabla-examples with Apache License 2.0 | 5 votes |
def main():
    """Fetch the arguments with get_args(), then pass them to save_args() and train()."""
    parsed = get_args()
    save_args(parsed)
    train(parsed)
Example #2
Source File: darts_train.py From nnabla-examples with Apache License 2.0 | 5 votes |
def main():
    """Start architecture evaluation (retraining from scratch).

    Sets the default nnabla context from the parsed arguments, loads the
    architecture description from the JSON file at ``args.model_arch_name``,
    shows the normal and reduction cells, and hands everything to CNN_run().
    """
    args = get_args()
    print(args)

    # Context setup.
    context = get_extension_context(
        args.context, device_id=args.device_id, type_config=args.type_config)
    nn.set_default_context(context)
    nn.ext_utils.import_extension_module(args.context)

    assert os.path.exists(
        args.model_arch_name), "architecture's params seem to be missing!"

    # Candidate operations, keyed by their integer index.
    ops = dict(enumerate((
        dil_conv_3x3,
        dil_conv_5x5,
        sep_conv_3x3,
        sep_conv_5x5,
        max_pool_3x3,
        avg_pool_3x3,
        identity,
        zero,
    )))

    with open(args.model_arch_name, 'r') as arch_file:
        arch_dict = json.load(arch_file)

    print("Train the model whose architecture is:")
    for key, label in (("arch_normal", "normal"),
                       ("arch_reduction", "reduction")):
        show_derived_cell(args, ops, arch_dict[key], label)

    CNN_run(args, ops, arch_dict)
Example #3
Source File: datasets.py From nnabla-examples with Apache License 2.0 | 5 votes |
def main():
    """Run create_pcd_dataset_from_mesh() when the parsed command is "create"."""
    cli = get_args()
    if cli.command != "create":
        return
    create_pcd_dataset_from_mesh(cli.mesh_data_path)
Example #4
Source File: show_leaned_filters.py From nnabla-examples with Apache License 2.0 | 5 votes |
def show():
    """Plot a heatmap for every parameter whose name contains "conv/W".

    Parameters are loaded from ``args.model_load_path``. Each matching
    4-D parameter of shape (n, m, k0, k1) is flattened to (n, m * k0 * k1)
    and drawn with ``pcolor``; the function then waits for the user to press
    Enter before closing the figure and moving on to the next parameter.
    """
    args = get_args()

    # `raw_input` only exists on Python 2; fall back to `input` on Python 3
    # instead of failing with NameError after the first figure is drawn.
    try:
        wait_for_key = raw_input
    except NameError:
        wait_for_key = input

    # Load model
    nn.load_parameters(args.model_load_path)
    params = nn.get_parameters()

    # Show heatmap
    for name, param in params.items():
        # SSL only on convolution weights
        if "conv/W" not in name:
            continue
        print(name)
        n, m, k0, k1 = param.d.shape
        w_matrix = param.d.reshape((n, m * k0 * k1))

        # Filter x Channel heatmap
        fig, ax = plt.subplots()
        ax.set_title("{} with shape {} \n Filter x (Channel x Heigh x Width)".format(
            name, (n, m, k0, k1)))
        heatmap = ax.pcolor(w_matrix)
        fig.colorbar(heatmap)
        plt.pause(0.5)
        wait_for_key("Press Key")
        plt.close()
Example #5
Source File: convert_tf_nnabla.py From nnabla-examples with Apache License 2.0 | 5 votes |
def main():
    """Convert the input file (.ckpt) to the output file (.h5).

    Usage:
        python convert_tf_nnabla.py --input-ckpt-file=/path to ckpt file --output-nnabla-file=/output .h5 file
    """
    # Parse the arguments
    cli = get_args()
    source_ckpt = cli.input_ckpt_file
    target_h5 = cli.output_nnabla_file

    convert(source_ckpt, target_h5)
Example #6
Source File: eval.py From nnabla-examples with Apache License 2.0 | 5 votes |
def main():
    """Set the default nnabla context from the arguments and call validate()."""
    args = get_args()
    rng = np.random.RandomState(1223)

    # Get context
    from nnabla.ext_utils import get_extension_context
    logger.info("Running in %s" % args.context)
    nn.set_default_context(
        get_extension_context(
            args.context,
            device_id=args.device_id,
            type_config=args.type_config,
        )
    )

    miou = validate(args)
Example #7
Source File: prepare_lfw_data.py From nnabla-examples with Apache License 2.0 | 5 votes |
def main():
    """Call generate_path_files() with the parsed command-line options.

    Arguments:
        train-file: txt file containing randomly selected image filenames
            to be taken as training set.
        val-file: txt file containing randomly selected image filenames
            to be taken as validation set.
        data-dir: dataset directory.

    Usage:
        python dataset_utils.py --train-file="" --val-file="" --data_dir=""
    """
    options = get_args()
    generate_path_files(
        options.data_dir, options.train_file, options.val_file)
Example #8
Source File: plot_accuracy.py From nnabla-examples with Apache License 2.0 | 5 votes |
def main():
    """Evaluate every file in the model directory and plot iterations vs. mIOU.

    For each file found in ``args.model_load_path`` the path is written back
    into ``args.model_load_path`` and passed to ``eval.validate``. The part
    of each filename before the first ``.`` (with ``param_`` removed) is
    parsed as the iteration number. The results are shown as a scatter plot
    and also written to ``iterations.txt`` and ``miou.txt``.
    """
    args = get_args()
    rng = np.random.RandomState(1223)

    # Get context
    from nnabla.ext_utils import get_extension_context
    logger.info("Running in %s" % args.context)
    ctx = get_extension_context(
        args.context, device_id=args.device_id, type_config=args.type_config)
    nn.set_default_context(ctx)

    iterations = []
    mean_iou = []
    model_dir = args.model_load_path
    for filename in os.listdir(model_dir):
        # os.path.join works whether or not model_dir ends with a separator;
        # plain string concatenation silently built a wrong path without it.
        args.model_load_path = os.path.join(model_dir, filename)
        miou = eval.validate(args)
        iterations.append(filename.split('.')[0])
        mean_iou.append(miou)

    iterations = [name.replace('param_', '') for name in iterations]
    itr = list(map(int, iterations))

    # Plot Iterations Vs mIOU
    if itr:
        # plt.axis sets [xmin, xmax, ymin, ymax]; plt.axes would instead
        # create a new Axes from a [left, bottom, width, height] rectangle.
        plt.axis([0, max(itr), 0.0, 1.0])
    plt.xlabel('Iterations')
    plt.ylabel('Accuracy - mIOU')
    plt.scatter(itr, mean_iou)
    plt.show()

    print(iterations)
    print(mean_iou)

    with open('iterations.txt', 'w') as f:
        for item in iterations:
            f.write('%s\n' % item)

    with open('miou.txt', 'w') as f2:
        for item in mean_iou:
            f2.write('%s\n' % item)
Example #9
Source File: train.py From nnabla-examples with Apache License 2.0 | 5 votes |
def main():
    """Entry point: get the arguments, call save_args() on them, then train()."""
    run_args = get_args()
    save_args(run_args)
    train(run_args)
Example #10
Source File: train.py From nnabla-examples with Apache License 2.0 | 5 votes |
def main():
    """Set the default nnabla context from the arguments, then call train()."""
    args = get_args()

    # Context
    nn.set_default_context(
        get_extension_context(
            args.context,
            device_id=args.device_id,
            type_config=args.type_config,
        )
    )

    train(args)
Example #11
Source File: interpolate.py From nnabla-examples with Apache License 2.0 | 5 votes |
def main():
    """Get the arguments, call save_args() with the "generate" tag, then interpolate()."""
    options = get_args()
    save_args(options, "generate")
    interpolate(options)
Example #12
Source File: prepare_datasets.py From nnabla-examples with Apache License 2.0 | 5 votes |
def main():
    """Call prepare_pix2pix_dataset() for the requested dataset split.

    The ``train`` flag passed on is False only when ``args.data_type`` is
    "val"; any other value (including "train") leaves it True.
    """
    cli = get_args()
    is_train = cli.data_type != "val"
    prepare_pix2pix_dataset(cli.dataset, is_train)
Example #13
Source File: generate.py From nnabla-examples with Apache License 2.0 | 5 votes |
def main():
    """Get the arguments, call save_args() with the "generate" tag, then generate()."""
    options = get_args()
    save_args(options, "generate")
    generate(options)
Example #14
Source File: generate.py From nnabla-examples with Apache License 2.0 | 5 votes |
def main():
    """Script entry point: save the arguments under "generate" and call generate()."""
    parsed = get_args()
    save_args(parsed, "generate")
    generate(parsed)
Example #15
Source File: morph.py From nnabla-examples with Apache License 2.0 | 5 votes |
def main():
    """Get the arguments, call save_args() with the "morph" tag, then morph()."""
    options = get_args()
    save_args(options, "morph")
    morph(options)
Example #16
Source File: train_with_mgpu.py From nnabla-examples with Apache License 2.0 | 5 votes |
def main():
    """Get the arguments, call save_args() with the "train" tag, then train()."""
    options = get_args()
    save_args(options, "train")
    train(options)
Example #17
Source File: match.py From nnabla-examples with Apache License 2.0 | 5 votes |
def main():
    """Get the arguments, call save_args() with the "match" tag, then match()."""
    options = get_args()
    save_args(options, "match")
    match(options)
Example #18
Source File: generate.py From nnabla-examples with Apache License 2.0 | 5 votes |
def main():
    """Pass the arguments to save_args() (tagged "generate") and then to generate()."""
    cli = get_args()
    save_args(cli, "generate")
    generate(cli)
Example #19
Source File: train.py From nnabla-examples with Apache License 2.0 | 5 votes |
def main():
    """Build the extension context from the arguments, make it the default, and train()."""
    args = get_args()
    context_kwargs = dict(
        device_id=args.device_id,
        type_config=args.type_config,
    )
    nn.set_default_context(
        get_extension_context(args.context, **context_kwargs))
    train(args)
Example #20
Source File: train.py From PointFlow with MIT License | 5 votes |
def main():
    """Prepare the checkpoint directory and launch one or more workers.

    Steps, in order:
      * create ``checkpoints/<log_name>`` and its ``images`` sub-directory,
      * record the launching command line in ``command.sh``,
      * pick a random seed when none was given and apply it,
      * resolve ``world_size`` from the ``WORLD_SIZE`` environment variable
        when ``dist_url`` is "env://" and ``world_size`` is -1,
      * spawn one ``main_worker`` per visible GPU when ``args.distributed``
        is set, otherwise call ``main_worker`` directly.
    """
    # command line args
    args = get_args()
    save_dir = os.path.join("checkpoints", args.log_name)
    # exist_ok avoids the check-then-create race and also creates `images`
    # when `save_dir` already exists but the sub-directory does not.
    os.makedirs(os.path.join(save_dir, 'images'), exist_ok=True)

    with open(os.path.join(save_dir, 'command.sh'), 'w') as f:
        f.write('python -X faulthandler ' + ' '.join(sys.argv))
        f.write('\n')

    if args.seed is None:
        args.seed = random.randint(0, 1000000)
    set_random_seed(args.seed)

    if args.gpu is not None:
        warnings.warn('You have chosen a specific GPU. This will completely '
                      'disable data parallelism.')

    if args.dist_url == "env://" and args.world_size == -1:
        args.world_size = int(os.environ["WORLD_SIZE"])

    if args.sync_bn:
        assert args.distributed

    print("Arguments:")
    print(args)

    ngpus_per_node = torch.cuda.device_count()
    if args.distributed:
        # world_size is multiplied by the GPU count before one worker per
        # GPU is spawned.
        args.world_size = ngpus_per_node * args.world_size
        mp.spawn(main_worker, nprocs=ngpus_per_node,
                 args=(save_dir, ngpus_per_node, args))
    else:
        main_worker(args.gpu, save_dir, ngpus_per_node, args)
Example #21
Source File: create_initialized_model.py From nnabla with Apache License 2.0 | 4 votes |
def main():
    """Build the MNIST training and runtime graphs and save them as an NNP file.

    The local nnabla-examples checkout is located through the
    ``NNABLA_EXAMPLES_ROOT`` environment variable (with a relative default),
    and its ``mnist-collection`` directory is appended to ``sys.path`` so
    that ``args``, ``mnist_data`` and ``classification`` can be imported.
    The output file is named ``<net>_initialized.nnp``.
    """
    # Read envvar `NNABLA_EXAMPLES_ROOT` to identify the path to your local
    # nnabla-examples directory.
    here = os.path.dirname(__file__)
    default_root = os.path.join(here, '../../../../nnabla-examples')
    examples_root = os.environ.get('NNABLA_EXAMPLES_ROOT', default_root)
    mnist_root = os.path.realpath(
        os.path.join(examples_root, 'mnist-collection'))
    sys.path.append(mnist_root)
    examples_git_url = 'https://github.com/sony/nnabla-examples'

    # Check if nnabla-examples found.
    try:
        from args import get_args
    except ImportError:
        message = (
            'An envvar `NNABLA_EXAMPLES_ROOT`'
            ' which locates the local path to '
            '[nnabla-examples]({})'
            ' repository must be set correctly.'
        ).format(examples_git_url)
        print(message, file=sys.stderr)
        raise

    # Import MNIST data
    from mnist_data import data_iterator_mnist
    from classification import mnist_lenet_prediction, mnist_resnet_prediction

    args = get_args(description=__doc__)

    if args.net == 'resnet':
        predict = mnist_resnet_prediction
    else:
        predict = mnist_lenet_prediction

    # Create a computation graph to be saved.
    batch = args.batch_size
    x = nn.Variable([batch, 1, 28, 28])
    h = predict(x, test=False, aug=False)
    t = nn.Variable([batch, 1])
    loss = F.mean(F.softmax_cross_entropy(h, t))
    y = predict(x, test=True, aug=False)

    # Save NNP file (used in C++ inference later.).
    nnp_file = '{}_initialized.nnp'.format(args.net)
    training_net = {
        'name': 'training',
        'batch_size': batch,
        'outputs': {'loss': loss},
        'names': {'x': x, 't': t},
    }
    runtime_net = {
        'name': 'runtime',
        'batch_size': batch,
        'outputs': {'y': y},
        'names': {'x': x},
    }
    runtime_executor = {
        'name': 'runtime',
        'network': 'runtime',
        'data': ['x'],
        'output': ['y'],
    }
    runtime_contents = {
        'networks': [training_net, runtime_net],
        'executors': [runtime_executor],
    }
    nn.utils.save.save(nnp_file, runtime_contents)
Example #22
Source File: pix2pix.py From nnabla-examples with Apache License 2.0 | 4 votes |
def main():
    """Train and/or generate with the U-Net generator, as selected by the arguments.

    When ``args.train`` is set, the value returned by train() replaces the
    model path given in ``args.model``. When ``args.generate`` is set, that
    path is passed to generate(); if it is None an error is logged instead.
    """
    # argparse
    args = get_args()

    # Context Setting
    # Get context.
    from nnabla.ext_utils import get_extension_context
    logger.info("Running in %s" % args.context)
    nn.set_default_context(
        get_extension_context(args.context, device_id=args.device_id))

    model_path = args.model

    if args.train:
        # Data Loading
        logger.info("Initialing DataSource.")
        train_iterator = facade.facade_data_iterator(
            args.traindir,
            args.batchsize,
            shuffle=True,
            with_memory_cache=False)
        val_iterator = facade.facade_data_iterator(
            args.valdir,
            args.batchsize,
            random_crop=False,
            shuffle=False,
            with_memory_cache=False)

        monitor = nm.Monitor(args.logdir)
        adam_options = dict(alpha=args.lrate, beta1=args.beta1)
        solver_gen = S.Adam(**adam_options)
        solver_dis = S.Adam(**adam_options)

        model_path = train(unet.generator, unet.discriminator, args.patch_gan,
                           solver_gen, solver_dis,
                           args.weight_l1, train_iterator, val_iterator,
                           args.epoch, monitor, args.monitor_interval)

    if not args.generate:
        return

    if model_path is None:
        logger.error("Trained model was NOT given.")
        return

    # Data Loading
    logger.info("Generating from DataSource.")
    test_iterator = facade.facade_data_iterator(
        args.testdir,
        args.batchsize,
        shuffle=False,
        with_memory_cache=False)
    generate(unet.generator, model_path, test_iterator, args.logdir)
Example #23
Source File: train.py From nnabla-examples with Apache License 2.0 | 4 votes |
def main():
    """Build the data iterator, models, solvers and monitors, then run Trainer.train().

    The resolution and channel schedules are fixed lists inside this
    function; everything else is taken from the parsed arguments.
    """
    # Args
    args = get_args()
    save_args(args)

    # Context
    nn.set_default_context(get_extension_context(
        args.context, device_id=args.device_id, type_config=args.type_config))
    nn.set_auto_forward(True)

    # Data Iterator
    image_size = (args.imsize, args.imsize)
    di = data_iterator(args.img_path, args.batch_size,
                       imsize=image_size,
                       num_samples=args.train_samples,
                       dataset_name=args.dataset_name)

    # Model
    init_options = dict(use_wscale=args.not_use_wscale,
                        use_he_backward=args.use_he_backward)
    generator = Generator(use_bn=args.use_bn, last_act=args.last_act,
                          **init_options)
    discriminator = Discriminator(use_ln=args.use_ln,
                                  alpha=args.leaky_alpha,
                                  **init_options)

    # Solver
    adam_options = dict(alpha=args.learning_rate,
                        beta1=args.beta1, beta2=args.beta2)
    solver_gen = S.Adam(**adam_options)
    solver_dis = S.Adam(**adam_options)

    # Monitor
    monitor = Monitor(args.monitor_path)
    monitor_loss_gen = MonitorSeries("Generator Loss", monitor, interval=10)
    monitor_loss_dis = MonitorSeries(
        "Discriminator Loss", monitor, interval=10)
    monitor_p_fake = MonitorSeries("Fake Probability", monitor, interval=10)
    monitor_p_real = MonitorSeries("Real Probability", monitor, interval=10)
    monitor_time = MonitorTimeElapsed(
        "Training Time per Resolution", monitor, interval=1)
    monitor_image_tile = MonitorImageTileWithName(
        "Image Tile", monitor,
        num_images=4,
        normalize_method=lambda x: (x + 1.) / 2.)

    # TODO: use argument
    resolution_list = [4, 8, 16, 32, 64, 128]
    channel_list = [512, 512, 256, 128, 64, 32]

    trainer = Trainer(di,
                      generator, discriminator,
                      solver_gen, solver_dis,
                      args.monitor_path,
                      monitor_loss_gen, monitor_loss_dis,
                      monitor_p_fake, monitor_p_real,
                      monitor_time,
                      monitor_image_tile,
                      resolution_list, channel_list,
                      n_latent=args.latent, n_critic=args.critic,
                      save_image_interval=args.save_image_interval,
                      hyper_sphere=args.hyper_sphere,
                      l2_fake_weight=args.l2_fake_weight)

    # TODO: use images per resolution?
    trainer.train(args.epoch_per_resolution)
Example #24
Source File: validate.py From nnabla-examples with Apache License 2.0 | 4 votes |
def main():
    """Compute the selected validation metric for a loaded generator and log it.

    ``args.validation_metric`` selects the metric: "ms-ssim" or "swd". For
    either one the score is added to a MonitorSeries and logged. Any other
    value only logs a hint about the accepted choices and returns; previously
    that path fell through to ``logger.info(score)`` with ``score`` never
    assigned and raised UnboundLocalError.
    """
    # Args
    args = get_args()

    # Context
    ctx = get_extension_context(
        args.context, device_id=args.device_id, type_config=args.type_config)
    logger.info(ctx)
    nn.set_default_context(ctx)
    nn.set_auto_forward(True)

    # Monitor
    monitor = Monitor(args.monitor_path)

    # Validation
    logger.info("Start validation")

    num_images = args.valid_samples
    num_batches = num_images // args.batch_size

    # DataIterator
    di = data_iterator(args.img_path, args.batch_size,
                       imsize=(args.imsize, args.imsize),
                       num_samples=args.valid_samples,
                       dataset_name=args.dataset_name)

    # generator
    gen = load_gen(args.model_load_path, use_bn=args.use_bn,
                   last_act=args.last_act,
                   use_wscale=args.not_use_wscale,
                   use_he_backward=args.use_he_backward)

    # compute metric
    if args.validation_metric == "ms-ssim":
        logger.info("Multi Scale SSIM")
        monitor_time = MonitorTimeElapsed(
            "MS-SSIM-ValidationTime", monitor, interval=1)
        monitor_metric = MonitorSeries("MS-SSIM", monitor, interval=1)
        from ms_ssim import compute_metric
        score = compute_metric(gen, args.batch_size,
                               num_images, args.latent, args.hyper_sphere)
        monitor_time.add(0)
        monitor_metric.add(0, score)
    elif args.validation_metric == "swd":
        logger.info("Sliced Wasserstein Distance")
        monitor_time = MonitorTimeElapsed(
            "SWD-ValidationTime", monitor, interval=1)
        monitor_metric = MonitorSeries("SWD", monitor, interval=1)
        nhoods_per_image = 128
        nhood_size = 7
        level_list = [128, 64, 32, 16]  # TODO: use argument
        dir_repeats = 4
        dirs_per_repeat = 128
        from sliced_wasserstein import compute_metric
        score = compute_metric(di, gen, args.latent, num_batches,
                               nhoods_per_image, nhood_size, level_list,
                               dir_repeats, dirs_per_repeat,
                               args.hyper_sphere)
        monitor_time.add(0)
        monitor_metric.add(0, score)  # averaged in the log
    else:
        # No metric was computed, so there is no score to report.
        logger.info("Set `validation-metric` as either `ms-ssim` or `swd`.")
        return

    logger.info(score)
    logger.info("End validation")
Example #25
Source File: generate.py From nnabla-examples with Apache License 2.0 | 4 votes |
def main():
    """Generate two image tiles and add them to an image-tile monitor.

    ``side`` batches of ``side`` images each are produced first with
    generate_images() and then with generate_interpolated_images(); each set
    is concatenated along axis 0 and added to the monitor under
    "GeneratedImage" and "GeneratedInterpolatedImage" respectively.
    """
    # Args
    args = get_args()

    # Context
    nn.set_default_context(get_extension_context(
        args.context, device_id=args.device_id, type_config=args.type_config))
    nn.set_auto_forward(True)

    # Config
    resolution_list = [4, 8, 16, 32, 64, 128]
    channel_list = [512, 512, 256, 128, 64, 32]
    side = 8

    # Monitor
    monitor = Monitor(args.monitor_path)
    monitor_image_tile = MonitorImageTileWithName("Image Tile", monitor,
                                                  num_images=side**2)

    # Keyword arguments shared by both generation calls.
    generation_options = dict(
        batch_size=side,
        use_bn=args.use_bn,
        n_latent=args.latent,
        hyper_sphere=args.hyper_sphere,
        last_act=args.last_act,
        use_wscale=args.not_use_wscale,
        use_he_backward=args.use_he_backward,
        resolution_list=resolution_list,
        channel_list=channel_list,
    )

    # Generate
    # generate tile images
    tiles = [generate_images(args.model_load_path, **generation_options)
             for _ in range(side)]
    monitor_image_tile.add("GeneratedImage", np.concatenate(tiles, axis=0))

    # generate interpolated tile images
    tiles = [generate_interpolated_images(args.model_load_path,
                                          **generation_options)
             for _ in range(side)]
    monitor_image_tile.add("GeneratedInterpolatedImage",
                           np.concatenate(tiles, axis=0))