Python args.get_args() Examples
The following are 25 code examples of `args.get_args()`.
Example #1
Source File: From nnabla-examples with Apache License 2.0 | 5 votes |
def main():
    """Entry point: parse CLI options, persist them via ``save_args``, then train."""
    options = get_args()
    save_args(options)
    train(options)
Example #2
Source File: From nnabla-examples with Apache License 2.0 | 5 votes |
def main():
    """Start architecture evaluation (retraining from scratch).

    Reads the searched architecture from the JSON file at
    ``args.model_arch_name``, prints the cells stored under
    ``"arch_normal"`` and ``"arch_reduction"``, and passes the whole
    architecture dict to ``CNN_run``.

    Raises:
        AssertionError: if ``args.model_arch_name`` does not exist.
    """
    args = get_args()
    print(args)

    nn.set_default_context(get_extension_context(
        args.context, device_id=args.device_id, type_config=args.type_config))
    # The returned module is not used below; the call is kept as-is
    # (presumably for the side effect of loading the extension).
    _extension = nn.ext_utils.import_extension_module(args.context)

    assert os.path.exists(
        args.model_arch_name), "architecture's params seem to be missing!"

    # Candidate operations keyed by integer id 0..7.
    candidate_ops = [dil_conv_3x3, dil_conv_5x5, sep_conv_3x3, sep_conv_5x5,
                     max_pool_3x3, avg_pool_3x3, identity, zero]
    ops = dict(enumerate(candidate_ops))

    with open(args.model_arch_name, 'r') as f:
        arch_dict = json.load(f)

    print("Train the model whose architecture is:")
    for arch_key, cell_type in (("arch_normal", "normal"),
                                ("arch_reduction", "reduction")):
        show_derived_cell(args, ops, arch_dict[arch_key], cell_type)
    CNN_run(args, ops, arch_dict)
Example #3
Source File: From nnabla-examples with Apache License 2.0 | 5 votes |
def main():
    """Dispatch the command given on the command line.

    Only ``"create"`` is handled: it calls
    ``create_pcd_dataset_from_mesh(args.mesh_data_path)``. Any other
    command does nothing.
    """
    args = get_args()
    if args.command != "create":
        return
    create_pcd_dataset_from_mesh(args.mesh_data_path)
Example #4
Source File: From nnabla-examples with Apache License 2.0 | 5 votes |
def show():
    """Show a heatmap of every convolution weight in a saved model.

    Loads parameters from ``args.model_load_path``. For each parameter
    whose name contains ``"conv/W"``, reshapes its (n, m, k0, k1) weight
    into an ``n x (m*k0*k1)`` matrix and plots it with ``pcolor``. It then
    waits for the user to press a key before closing the figure and moving on.
    """
    args = get_args()

    # ``raw_input`` exists only on Python 2; on Python 3 the original call
    # raised NameError, so fall back to the builtin ``input``.
    try:
        wait_for_key = raw_input  # noqa: F821
    except NameError:
        wait_for_key = input

    # Load model
    nn.load_parameters(args.model_load_path)
    params = nn.get_parameters()

    # Show heatmap
    for name, param in params.items():
        # SSL only on convolution weights
        if "conv/W" not in name:
            continue
        print(name)
        # Assumes conv weights are 4-D (filters, channels, kh, kw) — TODO
        # confirm; the unpacking raises ValueError for any other rank.
        n, m, k0, k1 = param.d.shape
        w_matrix = param.d.reshape((n, m * k0 * k1))

        # Filter x Channel heatmap
        fig, ax = plt.subplots()
        ax.set_title("{} with shape {} \n Filter x (Channel x Height x Width)".format(
            name, (n, m, k0, k1)))
        heatmap = ax.pcolor(w_matrix)
        fig.colorbar(heatmap)

        plt.pause(0.5)
        wait_for_key("Press Key")
        plt.close()
Example #5
Source File: From nnabla-examples with Apache License 2.0 | 5 votes |
def main():
    """Convert a checkpoint file (.ckpt) into an nnabla .h5 parameter file.

    Usage::

        python <script> --input-ckpt-file=/path/to/file.ckpt --output-nnabla-file=/path/to/output.h5
    """
    cli = get_args()
    convert(cli.input_ckpt_file, cli.output_nnabla_file)
Example #6
Source File: From nnabla-examples with Apache License 2.0 | 5 votes |
def main():
    """Configure the nnabla extension context from the CLI and run ``validate(args)``."""
    from nnabla.ext_utils import get_extension_context
    from nnabla.logger import logger

    args = get_args()

    # Get context
    logger.info("Running in %s", args.context)
    ctx = get_extension_context(
        args.context, device_id=args.device_id, type_config=args.type_config)
    nn.set_default_context(ctx)

    validate(args)
Example #7
Source File: From nnabla-examples with Apache License 2.0 | 5 votes |
def main():
    """Call ``generate_path_files`` with the dataset directory and list files.

    Arguments:
        train-file: txt file containing randomly selected image filenames to
            be taken as the training set.
        val-file: txt file containing randomly selected image filenames to be
            taken as the validation set.
        data-dir: dataset directory.

    Usage: python <script> --train-file="" --val-file="" --data_dir=""
    """
    opts = get_args()
    generate_path_files(opts.data_dir, opts.train_file, opts.val_file)
Example #8
Source File: From nnabla-examples with Apache License 2.0 | 5 votes |
def main():
    """Validate every checkpoint in a directory and plot mIOU against iteration.

    ``args.model_load_path`` is treated as a directory. Each file in it is
    passed to ``eval.validate`` in ``os.listdir`` order. The iteration number
    is parsed from a filename of the form ``param_<iteration>.<ext>``.
    Iteration labels and mIOU scores are written one per line to
    ``iterations.txt`` and ``miou.txt`` in the working directory, and a
    scatter plot is drawn on the current matplotlib figure.
    """
    from nnabla.ext_utils import get_extension_context
    from nnabla.logger import logger

    args = get_args()

    # Get context
    logger.info("Running in %s", args.context)
    ctx = get_extension_context(
        args.context, device_id=args.device_id, type_config=args.type_config)
    nn.set_default_context(ctx)

    iterations = []
    mean_iou = []
    model_dir = args.model_load_path
    for filename in os.listdir(model_dir):
        # os.path.join works whether or not model_dir ends with a separator;
        # plain string concatenation produced a wrong path when it did not.
        args.model_load_path = os.path.join(model_dir, filename)
        # ``eval`` here is assumed to be a project module shadowing the
        # builtin — TODO confirm against the file's imports.
        miou = eval.validate(args)
        iterations.append(filename.split('.')[0])
        mean_iou.append(miou)

    if not iterations:
        logger.warning("No checkpoint files found in %s", model_dir)
        return

    iterations = [label.replace('param_', '') for label in iterations]
    itr = [int(label) for label in iterations]

    # Plot Iterations Vs mIOU.  plt.axis sets [xmin, xmax, ymin, ymax];
    # plt.axes (used previously) instead adds a new axes at figure coords.
    plt.axis([0, max(itr), 0.0, 1.0])
    plt.xlabel('Iterations')
    plt.ylabel('Accuracy - mIOU')
    plt.scatter(itr, mean_iou)
    # NOTE(review): the figure is never shown or saved — add plt.savefig()
    # if the plot is meant to be kept.
    print(iterations)
    print(mean_iou)

    with open('iterations.txt', 'w') as f:
        f.writelines('%s\n' % item for item in iterations)
    with open('miou.txt', 'w') as f2:
        f2.writelines('%s\n' % item for item in mean_iou)
Example #9
Source File: From nnabla-examples with Apache License 2.0 | 5 votes |
def main():
    """Script entry point: read CLI arguments, store them with ``save_args``, and train."""
    parsed_args = get_args()
    save_args(parsed_args)
    train(parsed_args)
Example #10
Source File: From nnabla-examples with Apache License 2.0 | 5 votes |
def main():
    """Set the default nnabla context from the CLI options, then call ``train``."""
    args = get_args()
    # Context
    ctx = get_extension_context(args.context,
                                device_id=args.device_id,
                                type_config=args.type_config)
    nn.set_default_context(ctx)
    train(args)
Example #11
Source File: From nnabla-examples with Apache License 2.0 | 5 votes |
def main():
    """Parse arguments, save them via ``save_args(..., "generate")``, and run ``interpolate``."""
    cfg = get_args()
    save_args(cfg, "generate")
    interpolate(cfg)
Example #12
Source File: From nnabla-examples with Apache License 2.0 | 5 votes |
def main():
    """Call ``prepare_pix2pix_dataset`` for the data type chosen on the CLI.

    The ``train`` flag passed on is False only when ``args.data_type`` is
    ``"val"``. Every other value, including ``"train"``, yields True.
    """
    args = get_args()
    is_train = args.data_type != "val"
    prepare_pix2pix_dataset(args.dataset, is_train)
Example #13
Source File: From nnabla-examples with Apache License 2.0 | 5 votes |
def main():
    """Generation entry point: save the parsed arguments, then call ``generate``."""
    arguments = get_args()
    save_args(arguments, "generate")
    generate(arguments)
Example #14
Source File: From nnabla-examples with Apache License 2.0 | 5 votes |
def main():
    """Run ``generate`` after recording the CLI arguments with ``save_args``."""
    cli_args = get_args()
    save_args(cli_args, "generate")
    generate(cli_args)
Example #15
Source File: From nnabla-examples with Apache License 2.0 | 5 votes |
def main():
    """Morph entry point: save the parsed arguments, then call ``morph``."""
    opts = get_args()
    save_args(opts, "morph")
    morph(opts)
Example #16
Source File: From nnabla-examples with Apache License 2.0 | 5 votes |
def main():
    """Training entry point: save the parsed arguments, then call ``train``."""
    run_args = get_args()
    save_args(run_args, "train")
    train(run_args)
Example #17
Source File: From nnabla-examples with Apache License 2.0 | 5 votes |
def main():
    """Match entry point: save the parsed arguments, then call ``match``."""
    settings = get_args()
    save_args(settings, "match")
    match(settings)
Example #18
Source File: From nnabla-examples with Apache License 2.0 | 5 votes |
def main():
    """Parse the CLI, persist it via ``save_args(..., "generate")``, and generate."""
    ns = get_args()
    save_args(ns, "generate")
    generate(ns)
Example #19
Source File: From nnabla-examples with Apache License 2.0 | 5 votes |
def main():
    """Entry point: install the nnabla extension context as default, then train."""
    args = get_args()
    nn.set_default_context(
        get_extension_context(args.context,
                              device_id=args.device_id,
                              type_config=args.type_config))
    train(args)
Example #20
Source File: From PointFlow with MIT License | 5 votes |
def main():
    """Prepare the experiment directory and launch (optionally distributed) training.

    Steps:
      1. Creates ``checkpoints/<log_name>/images``.
      2. Writes the invocation to ``checkpoints/<log_name>/command.sh``.
      3. Calls ``set_random_seed``, drawing a random seed if none was given.
      4. In distributed mode, spawns one ``main_worker`` process per CUDA
         device; otherwise calls ``main_worker`` in this process.

    Raises:
        ValueError: if ``args.sync_bn`` is set without ``args.distributed``.
    """
    # command line args
    args = get_args()
    save_dir = os.path.join("checkpoints", args.log_name)
    # exist_ok=True also creates images/ when save_dir already exists; the
    # previous exists()-guard skipped it in that case.
    os.makedirs(os.path.join(save_dir, 'images'), exist_ok=True)

    # The file name was missing, so open() targeted the directory itself.
    with open(os.path.join(save_dir, 'command.sh'), 'w') as f:
        f.write('python -X faulthandler ' + ' '.join(sys.argv))
        f.write('\n')

    if args.seed is None:
        args.seed = random.randint(0, 1000000)
    set_random_seed(args.seed)

    if args.gpu is not None:
        warnings.warn('You have chosen a specific GPU. This will completely '
                      'disable data parallelism.')

    if args.dist_url == "env://" and args.world_size == -1:
        args.world_size = int(os.environ["WORLD_SIZE"])

    # Explicit check instead of ``assert``, which is stripped under ``-O``.
    if args.sync_bn and not args.distributed:
        raise ValueError("sync_bn requires distributed training "
                         "(args.distributed must be set)")

    print("Arguments:")
    print(args)

    ngpus_per_node = torch.cuda.device_count()
    if args.distributed:
        args.world_size = ngpus_per_node * args.world_size
        mp.spawn(main_worker, nprocs=ngpus_per_node,
                 args=(save_dir, ngpus_per_node, args))
    else:
        main_worker(args.gpu, save_dir, ngpus_per_node, args)
Example #21
Source File: From nnabla with Apache License 2.0 | 4 votes |
def main():
    """Build the MNIST training/runtime graphs and save them as an initialized NNP file.

    The ``args`` and ``classification`` modules are imported from
    ``<NNABLA_EXAMPLES_ROOT>/mnist-collection``, which is appended to
    ``sys.path``. ``NNABLA_EXAMPLES_ROOT`` defaults to
    ``../../../../nnabla-examples`` relative to this file. The result is
    written to ``<args.net>_initialized.nnp`` in the working directory.

    Raises:
        ImportError: if the nnabla-examples repository cannot be located.
    """
    from nnabla.utils.save import save as save_nnp

    # Read envvar `NNABLA_EXAMPLES_ROOT` to identify the path to your local
    # nnabla-examples directory.
    HERE = os.path.dirname(__file__)
    nnabla_examples_root = os.environ.get('NNABLA_EXAMPLES_ROOT', os.path.join(
        HERE, '../../../../nnabla-examples'))
    mnist_examples_root = os.path.realpath(
        os.path.join(nnabla_examples_root, 'mnist-collection'))
    sys.path.append(mnist_examples_root)
    nnabla_examples_git_url = 'https://github.com/sony/nnabla-examples'

    # Check if nnabla-examples found.
    try:
        from args import get_args
    except ImportError:
        print(
            'An envvar `NNABLA_EXAMPLES_ROOT`'
            ' which locates the local path to '
            '[nnabla-examples]({})'
            ' repository must be set correctly.'.format(
                nnabla_examples_git_url),
            file=sys.stderr)
        raise

    # Import MNIST data
    from mnist_data import data_iterator_mnist  # noqa: F401
    from classification import mnist_lenet_prediction, mnist_resnet_prediction

    args = get_args(description=__doc__)

    mnist_cnn_prediction = mnist_lenet_prediction
    if args.net == 'resnet':
        mnist_cnn_prediction = mnist_resnet_prediction

    # Create a computation graph to be saved.
    x = nn.Variable([args.batch_size, 1, 28, 28])
    h = mnist_cnn_prediction(x, test=False, aug=False)
    t = nn.Variable([args.batch_size, 1])
    loss = F.mean(F.softmax_cross_entropy(h, t))
    y = mnist_cnn_prediction(x, test=True, aug=False)

    # Save NNP file (used in C++ inference later.).
    nnp_file = '{}_initialized.nnp'.format(args.net)
    runtime_contents = {
        'networks': [
            {'name': 'training',
             'batch_size': args.batch_size,
             'outputs': {'loss': loss},
             'names': {'x': x, 't': t}},
            {'name': 'runtime',
             'batch_size': args.batch_size,
             'outputs': {'y': y},
             'names': {'x': x}}],
        'executors': [
            {'name': 'runtime',
             'network': 'runtime',
             'data': ['x'],
             'output': ['y']}]}
    save_nnp(nnp_file, runtime_contents)
Example #22
Source File: From nnabla-examples with Apache License 2.0 | 4 votes |
def main():
    """Train a U-Net generator/discriminator on the facade data and/or generate.

    With ``args.train``, trains using ``unet.generator`` and
    ``unet.discriminator``. The value ``train`` returns replaces
    ``args.model`` as the model path. With ``args.generate``, runs
    ``generate`` from that model path, or logs an error if there is none.
    """
    from nnabla.ext_utils import get_extension_context
    from nnabla.logger import logger

    # argparse
    args = get_args()

    # Context Setting
    logger.info("Running in %s", args.context)
    ctx = get_extension_context(
        args.context, device_id=args.device_id)
    nn.set_default_context(ctx)

    model_path = args.model

    if args.train:
        # Data Loading
        logger.info("Initializing DataSource.")
        train_iterator = facade.facade_data_iterator(
            args.traindir, args.batchsize,
            shuffle=True, with_memory_cache=False)
        val_iterator = facade.facade_data_iterator(
            args.valdir, args.batchsize,
            random_crop=False, shuffle=False, with_memory_cache=False)

        monitor = nm.Monitor(args.logdir)
        solver_gen = S.Adam(alpha=args.lrate, beta1=args.beta1)
        solver_dis = S.Adam(alpha=args.lrate, beta1=args.beta1)

        generator = unet.generator
        discriminator = unet.discriminator

        model_path = train(generator, discriminator, args.patch_gan,
                           solver_gen, solver_dis,
                           args.weight_l1, train_iterator, val_iterator,
                           args.epoch, monitor, args.monitor_interval)

    if args.generate:
        if model_path is not None:
            # Data Loading
            logger.info("Generating from DataSource.")
            test_iterator = facade.facade_data_iterator(
                args.testdir, args.batchsize,
                shuffle=False, with_memory_cache=False)
            generator = unet.generator
            generate(generator, model_path, test_iterator, args.logdir)
        else:
            logger.error("Trained model was NOT given.")
Example #23
Source File: From nnabla-examples with Apache License 2.0 | 4 votes |
def main():
    """Train the progressive generator/discriminator pair configured on the CLI.

    Saves the parsed arguments and switches nnabla to auto-forward mode.
    Builds the data iterator, models, Adam solvers and monitors, then runs
    ``Trainer.train(args.epoch_per_resolution)`` over the fixed resolution
    and channel schedules below.
    """
    args = get_args()
    save_args(args)

    # Context
    nn.set_default_context(get_extension_context(
        args.context, device_id=args.device_id, type_config=args.type_config))
    nn.set_auto_forward(True)

    # Data Iterator
    di = data_iterator(args.img_path, args.batch_size,
                       imsize=(args.imsize, args.imsize),
                       num_samples=args.train_samples,
                       dataset_name=args.dataset_name)

    # Model.  NOTE(review): ``use_wscale`` receives ``args.not_use_wscale``;
    # presumably that flag is a store_false option — confirm in get_args().
    wscale_cfg = dict(use_wscale=args.not_use_wscale,
                      use_he_backward=args.use_he_backward)
    generator = Generator(use_bn=args.use_bn, last_act=args.last_act,
                          **wscale_cfg)
    discriminator = Discriminator(use_ln=args.use_ln, alpha=args.leaky_alpha,
                                  **wscale_cfg)

    # Solver: both networks share the same Adam hyper-parameters.
    adam_cfg = dict(alpha=args.learning_rate,
                    beta1=args.beta1, beta2=args.beta2)
    solver_gen = S.Adam(**adam_cfg)
    solver_dis = S.Adam(**adam_cfg)

    # Monitor
    monitor = Monitor(args.monitor_path)
    (monitor_loss_gen, monitor_loss_dis,
     monitor_p_fake, monitor_p_real) = (
        MonitorSeries(title, monitor, interval=10)
        for title in ("Generator Loss", "Discriminator Loss",
                      "Fake Probability", "Real Probability"))
    monitor_time = MonitorTimeElapsed(
        "Training Time per Resolution", monitor, interval=1)

    def to_unit_range(x):
        # Map generator output from [-1, 1] to [0, 1] for the image tiles.
        return (x + 1.) / 2.

    monitor_image_tile = MonitorImageTileWithName(
        "Image Tile", monitor, num_images=4, normalize_method=to_unit_range)

    # TODO: use argument
    resolution_list = [4, 8, 16, 32, 64, 128]
    channel_list = [512, 512, 256, 128, 64, 32]

    trainer = Trainer(
        di, generator, discriminator, solver_gen, solver_dis,
        args.monitor_path,
        monitor_loss_gen, monitor_loss_dis, monitor_p_fake, monitor_p_real,
        monitor_time, monitor_image_tile,
        resolution_list, channel_list,
        n_latent=args.latent,
        n_critic=args.critic,
        save_image_interval=args.save_image_interval,
        hyper_sphere=args.hyper_sphere,
        l2_fake_weight=args.l2_fake_weight)
    # TODO: use images per resolution?
    trainer.train(args.epoch_per_resolution)
Example #24
Source File: From nnabla-examples with Apache License 2.0 | 4 votes |
def main():
    """Validate a trained generator with MS-SSIM or Sliced Wasserstein Distance.

    ``args.validation_metric`` selects the metric: ``"ms-ssim"`` or
    ``"swd"``. The score is recorded through nnabla monitors under
    ``args.monitor_path``. Any other metric value is only reported in the log.
    """
    from nnabla.logger import logger

    # Args
    args = get_args()

    # Context
    ctx = get_extension_context(
        args.context, device_id=args.device_id, type_config=args.type_config)
    nn.set_default_context(ctx)
    nn.set_auto_forward(True)

    # Monitor
    monitor = Monitor(args.monitor_path)

    # Validation
    logger.info("Start validation")

    num_images = args.valid_samples
    num_batches = num_images // args.batch_size

    # DataIterator
    di = data_iterator(args.img_path, args.batch_size,
                       imsize=(args.imsize, args.imsize),
                       num_samples=args.valid_samples,
                       dataset_name=args.dataset_name)
    # generator
    gen = load_gen(args.model_load_path, use_bn=args.use_bn, last_act=args.last_act,
                   use_wscale=args.not_use_wscale, use_he_backward=args.use_he_backward)

    # compute metric
    if args.validation_metric == "ms-ssim":
        logger.info("Multi Scale SSIM")
        monitor_time = MonitorTimeElapsed(
            "MS-SSIM-ValidationTime", monitor, interval=1)
        monitor_metric = MonitorSeries("MS-SSIM", monitor, interval=1)
        from ms_ssim import compute_metric
        score = compute_metric(gen, args.batch_size,
                               num_images, args.latent, args.hyper_sphere)
        monitor_time.add(0)
        monitor_metric.add(0, score)
    elif args.validation_metric == "swd":
        logger.info("Sliced Wasserstein Distance")
        monitor_time = MonitorTimeElapsed(
            "SWD-ValidationTime", monitor, interval=1)
        monitor_metric = MonitorSeries("SWD", monitor, interval=1)
        nhoods_per_image = 128
        nhood_size = 7
        level_list = [128, 64, 32, 16]  # TODO: use argument
        dir_repeats = 4
        dirs_per_repeat = 128
        from sliced_wasserstein import compute_metric
        score = compute_metric(di, gen, args.latent, num_batches, nhoods_per_image,
                               nhood_size, level_list, dir_repeats, dirs_per_repeat,
                               args.hyper_sphere)
        monitor_time.add(0)
        monitor_metric.add(0, score)  # averaged in the log
    else:
        logger.info("Set `validation-metric` as either `ms-ssim` or `swd`.")
    logger.info("End validation")
Example #25
Source File: From nnabla-examples with Apache License 2.0 | 4 votes |
def main():
    """Generate random and interpolated image tiles from a trained generator.

    Builds two grids of ``side * side`` images (``side = 8``). Each grid
    calls the generator function ``side`` times with ``batch_size=side`` and
    concatenates the batches along axis 0. The grids are added to the
    ``"Image Tile"`` monitor as ``"GeneratedImage"`` and
    ``"GeneratedInterpolatedImage"``.
    """
    # Args
    args = get_args()

    # Context
    ctx = get_extension_context(
        args.context, device_id=args.device_id, type_config=args.type_config)
    nn.set_default_context(ctx)
    nn.set_auto_forward(True)

    # Config — presumably must match the schedule the model was trained
    # with; TODO confirm.
    resolution_list = [4, 8, 16, 32, 64, 128]
    channel_list = [512, 512, 256, 128, 64, 32]
    side = 8

    # Monitor
    monitor = Monitor(args.monitor_path)
    monitor_image_tile = MonitorImageTileWithName("Image Tile", monitor,
                                                  num_images=side**2)

    # Keyword arguments shared by both generation functions.
    gen_kwargs = dict(batch_size=side, use_bn=args.use_bn,
                      n_latent=args.latent, hyper_sphere=args.hyper_sphere,
                      last_act=args.last_act,
                      use_wscale=args.not_use_wscale,
                      use_he_backward=args.use_he_backward,
                      resolution_list=resolution_list,
                      channel_list=channel_list)

    def _tile(generate_fn):
        # Stack ``side`` batches of ``side`` images into one array.
        batches = [generate_fn(args.model_load_path, **gen_kwargs)
                   for _ in range(side)]
        return np.concatenate(batches, axis=0)

    # generate tile images
    monitor_image_tile.add("GeneratedImage", _tile(generate_images))
    # generate interpolated tile images
    monitor_image_tile.add("GeneratedInterpolatedImage",
                           _tile(generate_interpolated_images))