Python models.create_model() Examples
The following are 18 code examples of models.create_model().
You can vote up the ones you like or vote down the ones you don't like,
and go to the original project or source file by following the links above each example.
You may also want to check out all available functions/classes of the module
models, or try the search function.
Example #1
Source File: train.py From tf-imagenet with Apache License 2.0 | 7 votes |
def main(extra_flags):
    """Train a CNN using parameters built from the parsed command-line flags.

    Args:
        extra_flags: argv entries left over after flag parsing. Exactly one
            entry is expected; any additional entries are reported as unknown
            flags.

    Raises:
        ValueError: if more than one leftover entry was passed.
    """
    # Only the first leftover entry is allowed; everything after it is unknown.
    assert len(extra_flags) >= 1
    unknown_flags = extra_flags[1:]
    if unknown_flags:
        raise ValueError('Received unknown flags: %s' % unknown_flags)

    # Build parameters from FLAGS, prepare the environment and persist the
    # parameters next to the training outputs.
    params = parameters.make_params_from_flags()
    deploy.setup_env(params)
    parameters.save_params(params, params.train_dir)

    tf_version = deploy.tensorflow_version_tuple()
    deploy.log_fn('TensorFlow: %i.%i' % (tf_version[0], tf_version[1]))

    dataset = datasets.create_dataset(
        params.data_dir, params.data_name, params.data_subset)
    model = models.create_model(params.model, dataset)
    set_model_params(model, params)

    trainer = deploy.TrainerCNN(dataset, model, params)
    trainer.print_info()
    trainer.run()
Example #2
Source File: test_basic.py From dnn-quant-ocs with Apache License 2.0 | 6 votes |
def test_utils():
    """Exercise distiller's parameter/module lookup helpers on resnet20_cifar."""
    net = models.create_model(False, 'cifar10', 'resnet20_cifar', parallel=False)
    assert net is not None

    # An empty name must not match any parameter.
    assert distiller.model_find_param(net, "") is None
    # Search for a parameter by its "non-parallel" name.
    assert distiller.model_find_param(net, "layer1.0.conv1.weight") is not None

    # Find the module object by name, then check the reverse lookup
    # maps it back to the same name.
    target = next(
        (mod for mod_name, mod in net.named_modules() if mod_name == "layer1.0.conv1"),
        None,
    )
    assert target is not None
    assert distiller.model_find_module_name(net, target) == "layer1.0.conv1"
Example #3
Source File: test.py From MeshCNN with MIT License | 5 votes |
def run_test(epoch=-1):
    """Run the model over the test set and return the accumulated accuracy.

    Args:
        epoch: epoch number passed to the writer when printing accuracy.

    Returns:
        ``writer.acc`` after every test batch has been counted.
    """
    print('Running Test')
    opt = TestOptions().parse()
    opt.serial_batches = True  # no shuffle
    loader = DataLoader(opt)
    net = create_model(opt)
    writer = Writer(opt)

    writer.reset_counter()
    for batch in loader:
        net.set_input(batch)
        correct, seen = net.test()
        writer.update_counter(correct, seen)
    writer.print_acc(epoch, writer.acc)
    return writer.acc
Example #4
Source File: main.py From actor-observer with GNU General Public License v3.0 | 5 votes |
def main():
    """Build the model and either evaluate it once or train it epoch by epoch.

    Updates the module-level ``args`` and ``best_top1``. In training mode a
    checkpoint is saved after every epoch, flagged as best when the tracked
    metric (``args.metric``) beats ``best_top1``. Unless ``args.nopdb`` is set,
    a pdb breakpoint is entered after training finishes.
    """
    global args, best_top1
    args = parse()
    if not args.no_logger:
        tee.Tee(args.cache + '/log.txt')
    print(vars(args))
    seed(args.manual_seed)

    model, criterion, optimizer = create_model(args)
    if args.resume:
        # Resuming also restores the best metric value seen so far.
        best_top1 = checkpoints.load(args, model, optimizer)
    print(model)

    trainer = train.Trainer()
    loaders = get_dataset(args)
    train_loader = loaders[0]

    if args.evaluate:
        eval_scores = validate(trainer, loaders, model, criterion, args)
        checkpoints.score_file(eval_scores, "{}/model_000.txt".format(args.cache))
        return

    for epoch in range(args.start_epoch, args.epochs):
        if args.distributed:
            trainer.train_sampler.set_epoch(epoch)
        # Merge training and validation scores; validation wins on key clashes.
        epoch_scores = dict(trainer.train(train_loader, model, criterion, optimizer, epoch, args))
        epoch_scores.update(validate(trainer, loaders, model, criterion, args, epoch))

        tracked = epoch_scores[args.metric]
        is_best = tracked > best_top1
        best_top1 = max(tracked, best_top1)
        checkpoints.save(epoch, args, model, optimizer, is_best, epoch_scores, args.metric)

    if not args.nopdb:
        pdb.set_trace()
Example #5
Source File: synthesizer.py From tacotron with MIT License | 5 votes |
def load(self, checkpoint_path, model_name='tacotron'):
    """Construct the synthesis graph and restore its weights from a checkpoint.

    Sets ``self.model``, ``self.wav_output`` and ``self.session``.

    Args:
        checkpoint_path: path passed to ``tf.train.Saver.restore``.
        model_name: name forwarded to ``create_model``.
    """
    print('Constructing model: %s' % model_name)
    # Batch size is fixed at 1; the sequence length is left dynamic.
    inputs = tf.placeholder(tf.int32, [1, None], 'inputs')
    input_lengths = tf.placeholder(tf.int32, [1], 'input_lengths')
    with tf.variable_scope('model'):
        self.model = create_model(model_name, hparams)
        self.model.initialize(inputs, input_lengths)
        # Waveform op for the single batch item, inverted from its linear spectrogram.
        self.wav_output = audio.inv_spectrogram_tensorflow(self.model.linear_outputs[0])

    print('Loading checkpoint: %s' % checkpoint_path)
    self.session = session = tf.Session()
    session.run(tf.global_variables_initializer())
    tf.train.Saver().restore(session, checkpoint_path)
Example #6
Source File: test_summarygraph.py From dnn-quant-ocs with Apache License 2.0 | 5 votes |
def name_test(dataset, arch):
    """Check module-name normalization between a plain and a parallel model.

    Every module name of the parallel-wrapped model (past its leading entries)
    must normalize to the corresponding plain-model name, and that plain name
    must denormalize back to the parallel name.
    """
    plain = create_model(False, dataset, arch, parallel=False)
    wrapped = create_model(False, dataset, arch, parallel=True)
    assert plain is not None and wrapped is not None

    plain_names = [mod_name for mod_name, _ in plain.named_modules()]
    wrapped_names = [mod_name for mod_name, _ in wrapped.named_modules()]
    assert plain_names is not None and wrapped_names is not None
    # The parallel model lists exactly one extra module.
    assert len(plain_names) + 1 == len(wrapped_names)

    # Pair plain names from index 1 with parallel names from index 2, skipping
    # the leading entries (presumably the root modules / the wrapper).
    for name, wrapped_name in zip(plain_names[1:], wrapped_names[2:]):
        normalized = normalize_module_name(wrapped_name)
        assert name == normalized
        logging.debug("{} {} {}".format(wrapped_name, name, normalize_module_name(wrapped_name)))
        assert wrapped_name == denormalize_module_name(wrapped, name)
Example #7
Source File: common.py From dnn-quant-ocs with Apache License 2.0 | 5 votes |
def setup_test(arch, dataset, parallel):
    """Create a model plus one ``distiller.ParameterMasker`` per named parameter.

    Returns:
        ``(model, zeros_mask_dict)`` where ``zeros_mask_dict`` maps each
        parameter name to its masker.
    """
    model = create_model(False, dataset, arch, parallel=parallel)
    assert model is not None
    # Create the masks, keyed by parameter name.
    zeros_mask_dict = {
        param_name: distiller.ParameterMasker(param_name)
        for param_name, _ in model.named_parameters()
    }
    return model, zeros_mask_dict
Example #8
Source File: test_infra.py From dnn-quant-ocs with Apache License 2.0 | 5 votes |
def test_load_negative():
    """Loading a checkpoint from a nonexistent path must raise FileNotFoundError."""
    # Build the model outside the `raises` block so that only the checkpoint
    # load is expected to fail; a FileNotFoundError from model construction
    # must not make this test pass.
    model = create_model(False, 'cifar10', 'resnet20_cifar')
    with pytest.raises(FileNotFoundError):
        load_checkpoint(model, 'THIS_IS_AN_ERROR/checkpoint_trained_dense.pth.tar')
Example #9
Source File: test_infra.py From dnn-quant-ocs with Apache License 2.0 | 5 votes |
def test_load():
    """Restore the SSL dense checkpoint into a fresh resnet20_cifar model.

    Checks that a compression scheduler is returned and that the saved start
    epoch is 180.
    """
    logging.getLogger('simple_example').setLevel(logging.INFO)
    net = create_model(False, 'cifar10', 'resnet20_cifar')
    # NOTE(review): this path is relative to the current working directory,
    # so the test presumably has to be run from the tests directory — confirm.
    checkpoint = '../examples/ssl/checkpoints/checkpoint_trained_dense.pth.tar'
    net, scheduler, first_epoch = load_checkpoint(net, checkpoint)
    assert scheduler is not None
    assert first_epoch == 180
Example #10
Source File: thinning.py From dnn-quant-ocs with Apache License 2.0 | 5 votes |
def create_graph(dataset, arch):
    """Build a SummaryGraph of ``arch``, traced with a random dummy input.

    Args:
        dataset: 'imagenet' (224x224 inputs) or 'cifar10' (32x32 inputs).
        arch: architecture name forwarded to ``create_model``.

    Returns:
        ``SummaryGraph(model, dummy_input.cuda())``.

    Raises:
        ValueError: if ``dataset`` is not supported.
    """
    # Dummy input shape (batch, channels, height, width) per supported dataset.
    input_shapes = {
        'imagenet': (1, 3, 224, 224),
        'cifar10': (1, 3, 32, 32),
    }
    # Validate up front: with an if/elif chain an unknown dataset would leave
    # the dummy input unbound and crash before any helpful message is shown.
    if dataset not in input_shapes:
        raise ValueError(
            "Unsupported dataset ({}) - aborting draw operation".format(dataset))
    dummy_input = torch.randn(input_shapes[dataset], requires_grad=False)
    model = create_model(False, dataset, arch, parallel=False)
    assert model is not None
    return SummaryGraph(model, dummy_input.cuda())
Example #11
Source File: eval.py From tf-imagenet with Apache License 2.0 | 5 votes |
def main(extra_flags):
    """Evaluate a trained CNN with parameters restored from the training run.

    Args:
        extra_flags: argv entries left over after flag parsing. Exactly one
            entry is expected; any additional entries are reported as unknown
            flags.

    Raises:
        ValueError: if more than one leftover entry was passed.
    """
    # Only the first leftover entry is allowed; everything after it is unknown.
    assert len(extra_flags) >= 1
    unknown_flags = extra_flags[1:]
    if unknown_flags:
        raise ValueError('Received unknown flags: %s' % unknown_flags)

    params = parameters.make_params_from_flags()
    deploy.setup_env(params)
    # Training parameters, updated from the saved json file.
    params = replace_with_train_params(params)

    tf_version = deploy.tensorflow_version_tuple()
    deploy.log_fn('TensorFlow: %i.%i' % (tf_version[0], tf_version[1]))

    dataset = datasets.create_dataset(
        params.data_dir, params.data_name, params.data_subset)
    model = models.create_model(params.model, dataset)
    train.set_model_params(model, params)

    # One pass over the eval set: examples divided by (batch_size * num_gpus),
    # truncated to an int.
    examples_per_step = params.batch_size * params.num_gpus
    params = params._replace(
        num_batches=int(dataset.num_examples_per_epoch() / examples_per_step))

    trainer = deploy.TrainerCNN(dataset, model, params)
    trainer.print_info()
    trainer.run()
Example #12
Source File: synthesizer.py From arabic-tacotron-tts with MIT License | 5 votes |
def load(self, checkpoint_path, model_name='tacotron'):
    """Build the synthesis graph for ``model_name`` and restore its weights.

    Populates ``self.model``, ``self.wav_output`` and ``self.session``.

    Args:
        checkpoint_path: checkpoint passed to ``tf.train.Saver.restore``.
        model_name: name forwarded to ``create_model``.
    """
    print('Constructing model: %s' % model_name)
    # Single-utterance batch with a dynamic sequence length.
    text_ids = tf.placeholder(tf.int32, [1, None], 'inputs')
    text_lengths = tf.placeholder(tf.int32, [1], 'input_lengths')
    with tf.variable_scope('model'):
        self.model = create_model(model_name, hparams)
        self.model.initialize(text_ids, text_lengths)
        first_linear = self.model.linear_outputs[0]
        self.wav_output = audio.inv_spectrogram_tensorflow(first_linear)

    print('Loading checkpoint: %s' % checkpoint_path)
    self.session = tf.Session()
    self.session.run(tf.global_variables_initializer())
    restorer = tf.train.Saver()
    restorer.restore(self.session, checkpoint_path)
Example #13
Source File: train.py From DMIT with MIT License | 5 votes |
def main():
    """Train the DMIT model, visualizing, logging and checkpointing as it goes.

    Per epoch: update the learning rate, run one pass over the data, then save
    a checkpoint and the 'latest' generator. Every ``opt.save_epoch_freq``
    epochs the generator is also saved under the epoch number.
    """
    opt = TrainOptions().parse()
    data_loader = CreateDataLoader(opt)
    # Examples per epoch; assumes len(data_loader) counts batches — TODO confirm.
    dataset_size = len(data_loader) * opt.batch_size
    visualizer = Visualizer(opt)
    model = create_model(opt)
    start_epoch = model.start_epoch
    total_steps = start_epoch * dataset_size

    for epoch in range(start_epoch + 1, opt.niter + opt.niter_decay + 1):
        epoch_start_time = time.time()
        model.update_lr()
        # Always display (and save to HTML) the first iteration of each epoch.
        save_result = True
        for data in data_loader:
            iter_start_time = time.time()
            total_steps += opt.batch_size
            # Example count within the current epoch.
            epoch_iter = total_steps - dataset_size * (epoch - 1)
            model.prepare_data(data)
            model.update_model()

            if save_result or total_steps % opt.display_freq == 0:
                save_result = save_result or total_steps % opt.update_html_freq == 0
                visualizer.display_current_results(
                    model.get_current_visuals(), epoch, ncols=1, save_result=save_result)
                save_result = False

            if total_steps % opt.print_freq == 0:
                errors = model.get_current_errors()
                t = (time.time() - iter_start_time) / opt.batch_size
                visualizer.print_current_errors(epoch, epoch_iter, errors, t)
                if opt.display_id > 0:
                    visualizer.plot_current_errors(
                        epoch, float(epoch_iter) / dataset_size, opt, errors)

        print('epoch {} cost time {}'.format(epoch, time.time() - epoch_start_time))
        model.save_ckpt(epoch)
        model.save_generator('latest')
        if epoch % opt.save_epoch_freq == 0:
            print('saving the generator at the end of epoch {}, iters {}'.format(epoch, total_steps))
            model.save_generator(epoch)
Example #14
Source File: test.py From DMIT with MIT License | 5 votes |
def main():
    """Translate up to ``opt.how_many`` test inputs and save them to an HTML page.

    Results are written under ``<opt.results_dir>/test``, one entry per input,
    named ``image000``, ``image001``, ...
    """
    opt = TestOptions().parse()
    opt.is_flip = False
    # Presumably one input per step, so each saved result gets its own name.
    # NOTE(review): DMIT's training loop reads `opt.batch_size`, so setting only
    # `batchSize` may be ignored by the loader; both spellings are set here —
    # confirm against the options module.
    opt.batchSize = 1
    opt.batch_size = 1
    data_loader = CreateDataLoader(opt)
    model = create_model(opt)
    web_dir = os.path.join(opt.results_dir, 'test')
    webpage = html.HTML(web_dir, 'task {}'.format(opt.exp_name))
    for i, data in enumerate(islice(data_loader, opt.how_many)):
        print('process input image %3.3d/%3.3d' % (i, opt.how_many))
        results = model.translation(data)
        img_path = 'image%3.3i' % i
        save_images(webpage, results, img_path, None, width=opt.fine_size)
    webpage.save()
Example #15
Source File: synthesizer.py From vae_tacotron with MIT License | 5 votes |
def load(self, checkpoint_path, model_name='tacotron'):
    """Build the VAE-Tacotron synthesis graph and restore its weights.

    Populates ``self.model``, ``self.wav_output`` and ``self.session``.

    Args:
        checkpoint_path: checkpoint passed to ``tf.train.Saver.restore``.
        model_name: name forwarded to ``create_model``.
    """
    print('Constructing model: %s' % model_name)
    inputs = tf.placeholder(tf.int32, [1, None], 'inputs')
    # Reference mel spectrogram: (1, frames, mel channels). The channel count
    # follows hparams.num_mels when it is defined instead of being hard-coded,
    # so a checkpoint trained with a different mel size still gets a matching
    # placeholder; 80 remains the fallback.
    num_mels = getattr(hparams, 'num_mels', 80)
    reference_mel = tf.placeholder(tf.float32, [1, None, num_mels], 'reference_mel')
    input_lengths = tf.placeholder(tf.int32, [1], 'input_lengths')
    with tf.variable_scope('model'):
        self.model = create_model(model_name, hparams)
        self.model.initialize(inputs, input_lengths, reference_mel=reference_mel)
        self.wav_output = audio.inv_spectrogram_tensorflow(self.model.linear_outputs[0])

    print('Loading checkpoint: %s' % checkpoint_path)
    self.session = tf.Session()
    self.session.run(tf.global_variables_initializer())
    saver = tf.train.Saver()
    saver.restore(self.session, checkpoint_path)
Example #16
Source File: freeze.py From TF_SpeechRecoChallenge with Apache License 2.0 | 4 votes |
def create_inference_graph(wanted_words, sample_rate, clip_duration_ms,
                           clip_stride_ms, window_size_ms, window_stride_ms,
                           dct_coefficient_count, resnet_size, model_architecture):
    """Creates an audio model with the nodes needed for inference.

    Uses the supplied arguments to create a model, and inserts the input and
    output nodes that are needed to use the graph for inference: a 'wav_data'
    string placeholder in and a 'labels_softmax' node out. Nodes are added to
    the current default graph; nothing is returned.

    Args:
        wanted_words: Comma-separated list of the words we're trying to recognize.
        sample_rate: How many samples per second are in the input audio files.
        clip_duration_ms: Duration of each audio clip to analyze, in milliseconds.
        clip_stride_ms: How often to run recognition. Useful for models with cache.
        window_size_ms: Time slice duration to estimate frequencies from.
        window_stride_ms: How far apart time slices should be.
        dct_coefficient_count: Number of frequency bands to analyze.
        resnet_size: Forwarded to ``models.prepare_model_settings``; presumably
            selects the ResNet depth — confirm in models.py.
        model_architecture: Name of the kind of model to generate.
    """
    words_list = input_data.prepare_words_list(wanted_words.split(','))
    model_settings = models.prepare_model_settings(
        len(words_list), sample_rate, clip_duration_ms, window_size_ms,
        window_stride_ms, dct_coefficient_count, resnet_size)
    runtime_settings = {'clip_stride_ms': clip_stride_ms}

    # Input: serialized WAV bytes, decoded to one channel of a fixed sample count.
    wav_data_placeholder = tf.placeholder(tf.string, [], name='wav_data')
    decoded_sample_data = contrib_audio.decode_wav(
        wav_data_placeholder,
        desired_channels=1,
        desired_samples=model_settings['desired_samples'],
        name='decoded_sample_data')
    # Features: spectrogram -> MFCC fingerprint.
    spectrogram = contrib_audio.audio_spectrogram(
        decoded_sample_data.audio,
        window_size=model_settings['window_size_samples'],
        stride=model_settings['window_stride_samples'],
        magnitude_squared=True)
    fingerprint_input = contrib_audio.mfcc(
        spectrogram,
        decoded_sample_data.sample_rate,
        dct_coefficient_count=dct_coefficient_count)
    # Flatten each fingerprint into a single (time * frequency) vector.
    fingerprint_frequency_size = model_settings['dct_coefficient_count']
    fingerprint_time_size = model_settings['spectrogram_length']
    reshaped_input = tf.reshape(
        fingerprint_input, [-1, fingerprint_time_size * fingerprint_frequency_size])

    logits = models.create_model(
        reshaped_input, model_settings, model_architecture, is_training=False,
        runtime_settings=runtime_settings)

    # Create an output to use for inference.
    tf.nn.softmax(logits, name='labels_softmax')
Example #17
Source File: freeze.py From adversarial_audio with MIT License | 4 votes |
def create_inference_graph(wanted_words, sample_rate, clip_duration_ms,
                           clip_stride_ms, window_size_ms, window_stride_ms,
                           dct_coefficient_count, model_architecture):
    """Build the inference graph: serialized WAV in, label probabilities out.

    Adds a 'wav_data' string placeholder, a 'decoded_sample_data' decoder, the
    MFCC feature pipeline, the model, and a 'labels_softmax' output node to the
    current default graph. Nothing is returned.

    Args:
        wanted_words: Comma-separated list of the words to recognize.
        sample_rate: Samples per second in the input audio files.
        clip_duration_ms: Duration of each audio clip to analyze, in milliseconds.
        clip_stride_ms: How often to run recognition; useful for models with cache.
        window_size_ms: Time slice duration to estimate frequencies from.
        window_stride_ms: How far apart time slices should be.
        dct_coefficient_count: Number of frequency bands to analyze.
        model_architecture: Name of the kind of model to generate.
    """
    labels = input_data.prepare_words_list(wanted_words.split(','))
    settings = models.prepare_model_settings(
        len(labels), sample_rate, clip_duration_ms, window_size_ms,
        window_stride_ms, dct_coefficient_count)

    wav_bytes = tf.placeholder(tf.string, [], name='wav_data')
    decoded = contrib_audio.decode_wav(
        wav_bytes,
        desired_channels=1,
        desired_samples=settings['desired_samples'],
        name='decoded_sample_data')
    spectrogram = contrib_audio.audio_spectrogram(
        decoded.audio,
        window_size=settings['window_size_samples'],
        stride=settings['window_stride_samples'],
        magnitude_squared=True)
    mfcc_features = contrib_audio.mfcc(
        spectrogram,
        decoded.sample_rate,
        dct_coefficient_count=dct_coefficient_count)

    # One flat (time * frequency) feature vector per clip.
    flat_size = settings['spectrogram_length'] * settings['dct_coefficient_count']
    model_input = tf.reshape(mfcc_features, [-1, flat_size])

    logits = models.create_model(
        model_input, settings, model_architecture, is_training=False,
        runtime_settings={'clip_stride_ms': clip_stride_ms})
    # Named output node used by inference clients.
    tf.nn.softmax(logits, name='labels_softmax')
Example #18
Source File: freeze.py From honk with MIT License | 4 votes |
def create_inference_graph(wanted_words, sample_rate, clip_duration_ms,
                           clip_stride_ms, window_size_ms, window_stride_ms,
                           dct_coefficient_count, model_architecture):
    """Assemble the nodes required to run the keyword model for inference.

    The graph takes serialized WAV data through the 'wav_data' placeholder and
    exposes class probabilities through the 'labels_softmax' node. Everything
    is added to the current default graph; the function returns None.

    Args:
        wanted_words: Comma-separated list of the words to recognize.
        sample_rate: Samples per second in the input audio files.
        clip_duration_ms: Duration of each audio clip to analyze, in milliseconds.
        clip_stride_ms: How often to run recognition; useful for models with cache.
        window_size_ms: Time slice duration to estimate frequencies from.
        window_stride_ms: How far apart time slices should be.
        dct_coefficient_count: Number of frequency bands to analyze.
        model_architecture: Name of the kind of model to generate.
    """
    vocabulary = input_data.prepare_words_list(wanted_words.split(','))
    cfg = models.prepare_model_settings(
        len(vocabulary), sample_rate, clip_duration_ms, window_size_ms,
        window_stride_ms, dct_coefficient_count)
    run_cfg = {'clip_stride_ms': clip_stride_ms}

    # Decode the WAV bytes to mono audio with a fixed number of samples.
    wav_in = tf.placeholder(tf.string, [], name='wav_data')
    pcm = contrib_audio.decode_wav(
        wav_in,
        desired_channels=1,
        desired_samples=cfg['desired_samples'],
        name='decoded_sample_data')

    # Spectrogram, then MFCCs over it.
    spec = contrib_audio.audio_spectrogram(
        pcm.audio,
        window_size=cfg['window_size_samples'],
        stride=cfg['window_stride_samples'],
        magnitude_squared=True)
    mfccs = contrib_audio.mfcc(
        spec, pcm.sample_rate, dct_coefficient_count=dct_coefficient_count)

    time_bins = cfg['spectrogram_length']
    freq_bins = cfg['dct_coefficient_count']
    flat_mfccs = tf.reshape(mfccs, [-1, time_bins * freq_bins])

    logits = models.create_model(
        flat_mfccs, cfg, model_architecture, is_training=False,
        runtime_settings=run_cfg)
    # The inference output is located by this node name.
    tf.nn.softmax(logits, name='labels_softmax')