Python tensorflow.flags() Examples
The following are 30 code examples of tensorflow.flags().
You can vote up the ones you like or vote down the ones you don't like,
and go to the original project or source file by following the links above each example.
You may also want to check out all available functions/classes of the module tensorflow, or try the search function.
Example #1
Source File: sample.py From IMPLEMENTATION_Variational-Auto-Encoder with MIT License | 7 votes |
def main():
    """Sample a batch of images from a pretrained anime VAE generator.

    Registers the sampling flags on ``tf.flags``, restores the generator
    weights from ``weights/vae_anime/generator``, decodes a batch of
    standard-normal latent vectors and shows them as an image grid.
    """
    flags = tf.flags
    flags.DEFINE_integer("latent_dim", 64, "Dimension of latent space.")
    flags.DEFINE_integer("obs_dim", 12288, "Dimension of observation space.")
    flags.DEFINE_integer("batch_size", 60, "Batch size.")
    flags.DEFINE_integer("epochs", 500, "As it said")
    flags.DEFINE_integer("updates_per_epoch", 100,
                         "Really just can set to 1 if you don't like mini-batch.")
    FLAGS = flags.FLAGS

    kwargs = {
        'latent_dim': FLAGS.latent_dim,
        'observation_dim': FLAGS.obs_dim,
        'generator': conv_anime_decoder,
        'obs_distrib': 'Gaussian'
    }
    g = GENERATOR(**kwargs)
    g.load_pretrained("weights/vae_anime/generator")

    # One latent vector per sample, drawn from N(0, I).
    z = np.random.normal(size=[FLAGS.batch_size, FLAGS.latent_dim])
    samples = g.e2x(z)  # presumably decodes latents to observations — confirm
    # Function-call form works on both Python 2 and 3; the former
    # `print samples.shape` statement was a SyntaxError on Python 3.
    print(samples.shape)
    # NOTE(review): 4 x 15 grid presumably matches batch_size=60 — confirm.
    show_samples(samples, 4, 15, [64, 64, 3], name='small_samples', shift=True)
Example #2
Source File: config.py From Question_Answering_Models with MIT License | 6 votes |
def main(_):
    """Run the pipeline stage selected by ``--mode``.

    Supported modes: ``train``, ``prepro``, ``debug`` and ``test``.

    Args:
      _: Unused (presumably the argv list passed by ``tf.app.run``).
    """
    import sys

    config = flags.FLAGS
    if config.mode == "train":
        train(config)
    elif config.mode == "prepro":
        prepro(config)
    elif config.mode == "debug":
        # Debug: call train() with tiny step/batch/period settings.
        config.num_steps = 2
        config.val_num_batches = 1
        config.checkpoint = 1
        config.period = 1
        train(config)
    elif config.mode == "test":
        test(config)
    else:
        print("Unknown mode, you must choose mode from [train/prepro/debug/test]")
        # An unknown mode is an error; exit non-zero instead of reporting success.
        sys.exit(1)
Example #3
Source File: train.py From text-gan-tensorflow with MIT License | 6 votes |
def get_supervisor(model, save_summaries_secs=100, save_model_secs=1000):
    """Build a ``tf.train.Supervisor`` that logs into ``FLAGS.model_dir``.

    Args:
      model: Object exposing a ``global_step`` tensor.
      save_summaries_secs: Seconds between summary writes (default 100,
        the previously hard-coded value).
      save_model_secs: Seconds between checkpoint saves (default 1000,
        the previously hard-coded value).

    Returns:
      A chief ``tf.train.Supervisor`` using a fresh ``Saver``, the merged
      summary op and a ``FileWriter`` on ``FLAGS.model_dir``.
    """
    saver = tf.train.Saver()
    summary_writer = tf.summary.FileWriter(FLAGS.model_dir)
    return tf.train.Supervisor(
        logdir=FLAGS.model_dir,
        is_chief=True,
        saver=saver,
        init_op=set_initial_ops(),
        summary_op=tf.summary.merge_all(),
        summary_writer=summary_writer,
        save_summaries_secs=save_summaries_secs,
        save_model_secs=save_model_secs,
        global_step=model.global_step,
    )
Example #4
Source File: config.py From QANet with MIT License | 6 votes |
def main(_):
    """Run the pipeline stage selected by ``--mode``.

    Supported modes: ``train``, ``prepro``, ``debug``, ``test`` and ``demo``.

    Args:
      _: Unused (presumably the argv list passed by ``tf.app.run``).
    """
    import sys

    config = flags.FLAGS
    if config.mode == "train":
        train(config)
    elif config.mode == "prepro":
        prepro(config)
    elif config.mode == "debug":
        # Debug: call train() with tiny step/batch/period settings.
        config.num_steps = 2
        config.val_num_batches = 1
        config.checkpoint = 1
        config.period = 1
        train(config)
    elif config.mode == "test":
        test(config)
    elif config.mode == "demo":
        demo(config)
    else:
        print("Unknown mode")
        # An unknown mode is an error; exit non-zero instead of reporting success.
        sys.exit(1)
Example #5
Source File: config.py From AIchallenger2018_MachineReadingComprehension with MIT License | 6 votes |
def main(_):
    """Run the pipeline stage selected by ``--mode``.

    Supported modes: ``train``, ``prepro``, ``debug``, ``test``, ``examine``,
    ``save_dev`` and ``save_test``.

    Args:
      _: Unused (presumably the argv list passed by ``tf.app.run``).
    """
    import sys

    config = flags.FLAGS
    # Restrict the process to the GPU(s) named by --gpu.
    os.environ["CUDA_VISIBLE_DEVICES"] = config.gpu
    if config.mode == "train":
        train(config)
    elif config.mode == "prepro":
        data_process.prepro(config)
    elif config.mode == "debug":
        # Debug: call train() with tiny step/batch/period settings.
        config.num_steps = 2
        config.val_num_batches = 1
        config.checkpoint = 1
        config.period = 1
        train(config)
    elif config.mode == "test":
        test(config)
    elif config.mode == "examine":
        examine_dev(config)
    elif config.mode == "save_dev":
        save_dev(config)
    elif config.mode == "save_test":
        save_test(config)
    else:
        print("Unknown mode")
        # An unknown mode is an error; exit non-zero instead of reporting success.
        sys.exit(1)
Example #6
Source File: config.py From AIchallenger2018_MachineReadingComprehension with MIT License | 6 votes |
def main(_):
    """Run the pipeline stage selected by ``--mode``.

    Supported modes: ``train``, ``prepro``, ``test``, ``examine``,
    ``save_dev`` and ``save_test``.

    Args:
      _: Unused (presumably the argv list passed by ``tf.app.run``).
    """
    import sys

    config = flags.FLAGS
    # Restrict the process to the GPU(s) named by --gpu.
    os.environ["CUDA_VISIBLE_DEVICES"] = config.gpu
    if config.mode == "train":
        train(config)
    elif config.mode == "prepro":
        data_process_addAnswer.prepro(config)
    elif config.mode == "test":
        test(config)
    elif config.mode == "examine":
        examine_dev(config)
    elif config.mode == "save_dev":
        save_dev(config)
    elif config.mode == "save_test":
        save_test(config)
    else:
        print("Unknown mode")
        # An unknown mode is an error; exit non-zero instead of reporting success.
        sys.exit(1)
Example #7
Source File: config.py From Question_Answering_Models with MIT License | 6 votes |
def main(_):
    """Run the pipeline stage selected by ``--mode``.

    Supported modes: ``train``, ``prepro``, ``debug`` and ``test``.

    Args:
      _: Unused (presumably the argv list passed by ``tf.app.run``).
    """
    import sys

    config = flags.FLAGS
    if config.mode == "train":
        train(config)
    elif config.mode == "prepro":
        prepro(config)
    elif config.mode == "debug":
        # Debug: call train() with tiny step/batch/period settings.
        config.num_steps = 2
        config.val_num_batches = 1
        config.checkpoint = 1
        config.period = 1
        train(config)
    elif config.mode == "test":
        test(config)
    else:
        print("Unknown mode, you must choose mode from [train/prepro/debug/test]")
        # An unknown mode is an error; exit non-zero instead of reporting success.
        sys.exit(1)
Example #8
Source File: config.py From Question_Answering_Models with MIT License | 6 votes |
def main(_):
    """Run the pipeline stage selected by ``--mode``.

    Supported modes: ``train``, ``prepro``, ``debug`` and ``test``.

    Args:
      _: Unused (presumably the argv list passed by ``tf.app.run``).
    """
    import sys

    config = flags.FLAGS
    if config.mode == "train":
        train(config)
    elif config.mode == "prepro":
        prepro(config)
    elif config.mode == "debug":
        # Debug: call train() with tiny step/batch/period settings.
        config.num_steps = 2
        config.val_num_batches = 1
        config.checkpoint = 1
        config.period = 1
        train(config)
    elif config.mode == "test":
        test(config)
    else:
        print("Unknown mode, you must choose mode from [train/prepro/debug/test]")
        # An unknown mode is an error; exit non-zero instead of reporting success.
        sys.exit(1)
Example #9
Source File: config.py From AmusingPythonCodes with MIT License | 6 votes |
def main(_):
    """Run the pipeline stage selected by ``--mode``.

    Supported modes: ``train``, ``prepro``, ``debug`` and ``test``.

    Args:
      _: Unused (presumably the argv list passed by ``tf.app.run``).
    """
    import sys

    config = flags.FLAGS
    if config.mode == "train":
        train(config)
    elif config.mode == "prepro":
        prepro(config)
    elif config.mode == "debug":
        # Debug: call train() with tiny step/batch/period settings.
        config.num_steps = 2
        config.val_num_batches = 1
        config.checkpoint = 1
        config.period = 1
        train(config)
    elif config.mode == "test":
        if config.use_cudnn:
            print("Warning: Due to a known bug in Tensorlfow, the parameters of CudnnGRU may not be properly restored.")
        test(config)
    else:
        print("Unknown mode")
        # An unknown mode is an error; exit non-zero instead of reporting success.
        sys.exit(1)
Example #10
Source File: config.py From R-Net with MIT License | 6 votes |
def main(_):
    """Run the pipeline stage selected by ``--mode``.

    Supported modes: ``train``, ``prepro``, ``debug`` and ``test``.

    Args:
      _: Unused (presumably the argv list passed by ``tf.app.run``).
    """
    import sys

    config = flags.FLAGS
    if config.mode == "train":
        train(config)
    elif config.mode == "prepro":
        prepro(config)
    elif config.mode == "debug":
        # Debug: call train() with tiny step/batch/period settings.
        config.num_steps = 2
        config.val_num_batches = 1
        config.checkpoint = 1
        config.period = 1
        train(config)
    elif config.mode == "test":
        test(config)
    else:
        print("Unknown mode")
        # An unknown mode is an error; exit non-zero instead of reporting success.
        sys.exit(1)
Example #11
Source File: config.py From QGforQA with MIT License | 6 votes |
def main(_):
    """Run the pipeline stage selected by ``--mode``.

    Supported modes: ``get_vocab``, ``prepare``, ``train``, ``train_rl``,
    ``train_qpp``, ``train_qap``, ``train_qqp_qap`` and ``test``.

    Args:
      _: Unused (presumably the argv list passed by ``tf.app.run``).
    """
    import sys

    config = flags.FLAGS
    if config.mode == "get_vocab":
        get_vocab(config)
    elif config.mode == "prepare":
        prepare(config)
    elif config.mode == "train":
        train(config)
    elif config.mode == "train_rl":
        train_rl(config)
    elif config.mode == "train_qpp":
        train_qpp(config)
    elif config.mode == "train_qap":
        train_qap(config)
    elif config.mode == "train_qqp_qap":
        train_qqp_qap(config)
    elif config.mode == "test":
        test(config)
    else:
        print("Unknown mode")
        # An unknown mode is an error; exit non-zero instead of reporting success.
        sys.exit(1)
Example #12
Source File: vae_train_anime.py From IMPLEMENTATION_Variational-Auto-Encoder with MIT License | 5 votes |
def main():
    """Train a convolutional VAE on the anime dataset and show samples.

    Flags are registered on ``tf.flags`` inside this function. After
    ``epochs`` epochs of ``updates_per_epoch`` mini-batch updates each,
    a batch of N(0, I) latents is decoded and shown as an 8 x 8 grid.
    The generator weights are then saved to ``weights/vae_anime/generator``.
    """
    tf_flags = tf.flags
    tf_flags.DEFINE_integer("latent_dim", 64, "Dimension of latent space.")
    tf_flags.DEFINE_integer("obs_dim", 12288, "Dimension of observation space.")
    tf_flags.DEFINE_integer("batch_size", 64, "Batch size.")
    tf_flags.DEFINE_integer("epochs", 500, "As it said")
    tf_flags.DEFINE_integer("updates_per_epoch", 100,
                            "Really just can set to 1 if you don't like mini-batch.")
    opts = tf_flags.FLAGS

    vae = VAE(
        latent_dim=opts.latent_dim,
        batch_size=opts.batch_size,
        observation_dim=opts.obs_dim,
        encoder=conv_anime_encoder,
        decoder=conv_anime_decoder,
        observation_distribution='Gaussian',
    )
    dataset = Anime()

    progress = tqdm(range(opts.epochs))
    for _epoch in progress:
        # Mean loss over this epoch's mini-batch updates.
        epoch_loss = 0.
        for _step in range(opts.updates_per_epoch):
            batch = dataset.next_batch(opts.batch_size)
            epoch_loss += vae.update(batch)
        epoch_loss /= opts.updates_per_epoch
        progress.set_description("Loss: {:.4f}".format(epoch_loss))

    latents = np.random.normal(size=[opts.batch_size, opts.latent_dim])
    decoded = vae.z2x(latents)[0]
    show_samples(decoded, 8, 8, [64, 64, 3], name='samples')
    vae.save_generator('weights/vae_anime/generator')
Example #13
Source File: pythonLanguageModel.py From pycodesuggest with MIT License | 5 votes |
def print_flags(flags):
    """Print each registered flag as ``name=value``, then two blank lines.

    Boolean flags whose value is False are omitted. All other values,
    including falsy non-booleans such as 0 or "", are printed.

    Args:
      flags: A flags object exposing its registry as the ``__flags``
        attribute and each flag's value as an attribute of the same name.
    """
    for name in flags.__flags:
        value = getattr(flags, name)
        # Skip disabled switches; print everything else.
        if isinstance(value, bool) and not value:
            continue
        print("%s=%s" % (name, value))
    print()
    print()
Example #14
Source File: vae_train.py From IMPLEMENTATION_Variational-Auto-Encoder with MIT License | 5 votes |
def main():
    """Train a fully-connected VAE on MNIST and show samples and the latent space.

    Flags are registered on ``tf.flags`` inside this function. MNIST is read
    from ``--data_dir``. After training, a batch of N(0, I) latents is
    decoded into a 10 x 10 grid and a latent scatter plot is shown. The
    generator weights are then saved to ``weights/vae_mnist/generator``.
    """
    tf_flags = tf.flags
    tf_flags.DEFINE_integer("latent_dim", 2, "Dimension of latent space.")
    tf_flags.DEFINE_integer("batch_size", 128, "Batch size.")
    tf_flags.DEFINE_integer("epochs", 500, "As it said")
    tf_flags.DEFINE_integer("updates_per_epoch", 100,
                            "Really just can set to 1 if you don't like mini-batch.")
    tf_flags.DEFINE_string("data_dir", 'mnist',
                           "Tensorflow demo data download position.")
    opts = tf_flags.FLAGS

    vae = VAE(
        latent_dim=opts.latent_dim,
        batch_size=opts.batch_size,
        encoder=fc_mnist_encoder,
        decoder=fc_mnist_decoder,
    )
    mnist = input_data.read_data_sets(train_dir=opts.data_dir)

    progress = tqdm(range(opts.epochs))
    for _epoch in progress:
        # Mean loss over this epoch's mini-batch updates (labels are unused).
        epoch_loss = 0.
        for _step in range(opts.updates_per_epoch):
            images, _labels = mnist.train.next_batch(opts.batch_size)
            epoch_loss += vae.update(images)
        epoch_loss /= opts.updates_per_epoch
        progress.set_description("Loss: {:.4f}".format(epoch_loss))

    latents = np.random.normal(size=[opts.batch_size, opts.latent_dim])
    decoded = vae.z2x(latents)[0]
    show_samples(decoded, 10, 10, [28, 28], name='samples')
    show_latent_scatter(vae, mnist, name='latent')
    vae.save_generator('weights/vae_mnist/generator')
Example #15
Source File: config.py From QGforQA with MIT License | 5 votes |
def main(_):
    """Run the stage selected by ``--mode``.

    Supported modes: ``train_for_qg`` and ``test_qa_for_qg``.

    Args:
      _: Unused (presumably the argv list passed by ``tf.app.run``).
    """
    import sys

    config = flags.FLAGS
    if config.mode == "train_for_qg":
        train_for_qg(config)
    elif config.mode == "test_qa_for_qg":
        test_qa_for_qg(config)
    else:
        print("Unknown mode")
        # An unknown mode is an error; exit non-zero instead of reporting success.
        sys.exit(1)
Example #16
Source File: config.py From QGforQA with MIT License | 5 votes |
def main(_):
    """Run the stage selected by ``--mode``.

    Supported modes: ``prepare``, ``test`` and ``train``.

    Args:
      _: Unused (presumably the argv list passed by ``tf.app.run``).
    """
    import sys

    config = flags.FLAGS
    if config.mode == "prepare":
        prepare(config)
    elif config.mode == "test":
        test(config)
    elif config.mode == "train":
        train(config)
    else:
        print("Unknown mode")
        # An unknown mode is an error; exit non-zero instead of reporting success.
        sys.exit(1)
Example #17
Source File: train.py From HMEAE with MIT License | 5 votes |
def main(_):
    """Train the trigger-stage DMCNN, then the argument-stage DMCNN.

    Selects GPUs from ``--gpu``, runs the extractor, loads trigger and
    argument data, and trains the trigger model. The data returned by
    trigger training is fed to the argument model, which is configured with
    ``--mode`` as its stage and ``--classify``.

    Args:
      _: Unused (presumably the argv list passed by ``tf.app.run``).
    """
    cfg = flags.FLAGS
    os.environ['CUDA_VISIBLE_DEVICES'] = cfg.gpu

    extractor = utils.Extractor()
    extractor.Extract()

    data_loader = utils.Loader()
    trigger_data = data_loader.load_trigger()
    argument_data = data_loader.load_argument()
    shared = (data_loader.maxlen, data_loader.max_argument_len, data_loader.wordemb)

    trigger_model = DMCNN(trigger_data, argument_data, *shared)
    processed_arguments = trigger_model.train_trigger()

    argument_model = DMCNN(trigger_data, processed_arguments, *shared,
                           stage=cfg.mode, classify=cfg.classify)
    argument_model.train_argument()
Example #18
Source File: inference_demo.py From Gun-Detector with Apache License 2.0 | 5 votes |
def _validate_flags():
    """Register absl validators for the inference demo's flags.

    ``checkpoint_path`` must be non-empty. ``generated_x_dir`` is required
    whenever ``image_set_y_glob`` is set, and ``generated_y_dir`` is
    required whenever ``image_set_x_glob`` is set.
    """
    flags.register_validator('checkpoint_path', bool,
                             'Must provide `checkpoint_path`.')

    def _x_dir_given_if_needed(generated_x_dir):
        # Valid unless image_set_y_glob is set and the output dir is empty.
        return not (FLAGS.image_set_y_glob and not generated_x_dir)

    def _y_dir_given_if_needed(generated_y_dir):
        # Valid unless image_set_x_glob is set and the output dir is empty.
        return not (FLAGS.image_set_x_glob and not generated_y_dir)

    flags.register_validator('generated_x_dir', _x_dir_given_if_needed,
                             'Must provide `generated_x_dir`.')
    flags.register_validator('generated_y_dir', _y_dir_given_if_needed,
                             'Must provide `generated_y_dir`.')
Example #19
Source File: train.py From text-gan-tensorflow with MIT License | 5 votes |
def get_sess_config(inter_op_parallelism_threads=8):
    """Build the ``tf.ConfigProto`` used for sessions.

    Args:
      inter_op_parallelism_threads: Size of the inter-op thread pool
        (default 8, the previously hard-coded value).

    Returns:
      A ``tf.ConfigProto`` with only ``inter_op_parallelism_threads`` set.
    """
    # NOTE: GPU memory options (per_process_gpu_memory_fraction, allow_growth),
    # log_device_placement and allow_soft_placement were tried before and left
    # disabled; the original notes said the GPU options "seem to be not working".
    return tf.ConfigProto(
        inter_op_parallelism_threads=inter_op_parallelism_threads)
Example #20
Source File: config.py From QGforQA with MIT License | 5 votes |
def main(_):
    """Run the stage selected by ``--mode``.

    Supported modes: ``prepare``, ``train`` and ``test``.

    Args:
      _: Unused (presumably the argv list passed by ``tf.app.run``).
    """
    import sys

    config = flags.FLAGS
    if config.mode == "prepare":
        prepare(config)
    elif config.mode == "train":
        train(config)
    elif config.mode == "test":
        test(config)
    else:
        print("Unknown mode")
        # An unknown mode is an error; exit non-zero instead of reporting success.
        sys.exit(1)
Example #21
Source File: tf_t2t.py From sgnmt with Apache License 2.0 | 5 votes |
def vocab_size(self):
    """Return the vocabulary size stored on the instance as ``_vocab_size``."""
    return self._vocab_size

# Define flags from the t2t binaries
Example #22
Source File: query.py From training_results_v0.5 with Apache License 2.0 | 5 votes |
def validate_flags():
    """Validates flags are set to acceptable values.

    Exactly one serving target must be configured:

    * with ``--cloud_mlengine_model_name`` set, ``--server`` and
      ``--servable_name`` must both be unset;
    * otherwise ``--server`` and ``--servable_name`` must both be set.

    Raises:
      ValueError: If the flag combination is invalid. Explicit raises are
        used instead of ``assert`` so the check still runs under ``python -O``.
    """
    if FLAGS.cloud_mlengine_model_name:
        if FLAGS.server:
            raise ValueError(
                "--server must not be set when --cloud_mlengine_model_name is set.")
        if FLAGS.servable_name:
            raise ValueError(
                "--servable_name must not be set when "
                "--cloud_mlengine_model_name is set.")
    else:
        if not FLAGS.server:
            raise ValueError(
                "--server is required when --cloud_mlengine_model_name is not set.")
        if not FLAGS.servable_name:
            raise ValueError(
                "--servable_name is required when "
                "--cloud_mlengine_model_name is not set.")
Example #23
Source File: t2t_trainer.py From fine-lm with MIT License | 5 votes |
def save_metadata(hparams):
    """Saves FLAGS and hparams to output_dir.

    Writes ``flags.txt`` (all flags) and, when the flags object supports
    per-module listing, ``flags_t2t.txt`` (only
    ``tensor2tensor.utils.flags``). Then writes ``hparams.json``. The
    output directory is created if it does not exist.

    Args:
      hparams: Hyper-parameter object exposing ``to_json``.
    """
    out_dir = os.path.expanduser(FLAGS.output_dir)
    if not tf.gfile.Exists(out_dir):
        tf.gfile.MakeDirs(out_dir)

    def _dump(file_name, text):
        # Write `text` to `file_name` inside the output directory.
        with tf.gfile.Open(os.path.join(out_dir, file_name), "w") as handle:
            handle.write(text)

    # Serialize flags. Newer flag objects (absl) provide flags_into_string
    # and per-module listings; the fallback reads the legacy name -> value
    # mapping stored under __flags.
    if hasattr(FLAGS, "flags_into_string"):
        all_flags_text = FLAGS.flags_into_string()
        t2t_flag_objs = FLAGS.flags_by_module_dict()["tensor2tensor.utils.flags"]
        t2t_flags_text = "\n".join(
            "--%s=%s" % (flag.name, flag.value) for flag in t2t_flag_objs)
    else:
        legacy_flags = FLAGS.__dict__["__flags"]
        all_flags_text = "\n".join(
            "--%s=%s" % (name, str(value))
            for name, value in legacy_flags.items())
        t2t_flags_text = None

    _dump("flags.txt", all_flags_text)
    if t2t_flags_text:
        _dump("flags_t2t.txt", t2t_flags_text)

    _dump("hparams.json", hparams.to_json(indent=0, sort_keys=True))
Example #24
Source File: transformer_model.py From fine-lm with MIT License | 5 votes |
def __init__(self, processor_configuration):
    """Creates the Transformer estimator.

    Args:
      processor_configuration: A ProcessorConfiguration protobuffer with the
        transformer fields populated; its ``"transformer"`` entry supplies
        ``model_dir``, ``data_dir``, ``hparams_set``, ``hparams``,
        ``problem`` and ``model``.
    """
    cfg = processor_configuration["transformer"]

    # Pre-setup tensor2tensor requires: point the global output_dir flag at
    # the model and import user code before building hparams.
    FLAGS.output_dir = cfg["model_dir"]
    usr_dir.import_usr_dir(FLAGS.t2t_usr_dir)
    expanded_data_dir = os.path.expanduser(cfg["data_dir"])

    self.hparams = trainer_lib.create_hparams(
        cfg["hparams_set"],
        cfg["hparams"],
        data_dir=expanded_data_dir,
        problem_name=cfg["problem"])

    decode_hparams = decoding.decode_hparams()
    for hp_name, hp_value in (("shards", 1), ("shard_id", 0)):
        decode_hparams.add_hparam(hp_name, hp_value)

    run_config = t2t_trainer.create_run_config(self.hparams)
    self.estimator = trainer_lib.create_estimator(
        cfg["model"],
        self.hparams,
        run_config,
        decode_hparams=decode_hparams,
        use_tpu=False)

    # Vocabularies and helper sizes used during decoding.
    vocabularies = self.hparams.problem_hparams.vocabulary
    self.source_vocab = vocabularies["inputs"]
    self.targets_vocab = vocabularies["targets"]
    self.const_array_size = 10000

    # Clear previous runs out of the debug dump directory.
    dump_pattern = os.path.join("/tmp/t2t_server_dump", "run_*")
    for stale_run in sorted(glob.glob(dump_pattern)):
        shutil.rmtree(stale_run)
Example #25
Source File: query.py From fine-lm with MIT License | 5 votes |
def validate_flags():
    """Validates flags are set to acceptable values.

    Exactly one serving target must be configured:

    * with ``--cloud_mlengine_model_name`` set, ``--server`` and
      ``--servable_name`` must both be unset;
    * otherwise ``--server`` and ``--servable_name`` must both be set.

    Raises:
      ValueError: If the flag combination is invalid. Explicit raises are
        used instead of ``assert`` so the check still runs under ``python -O``.
    """
    if FLAGS.cloud_mlengine_model_name:
        if FLAGS.server:
            raise ValueError(
                "--server must not be set when --cloud_mlengine_model_name is set.")
        if FLAGS.servable_name:
            raise ValueError(
                "--servable_name must not be set when "
                "--cloud_mlengine_model_name is set.")
    else:
        if not FLAGS.server:
            raise ValueError(
                "--server is required when --cloud_mlengine_model_name is not set.")
        if not FLAGS.servable_name:
            raise ValueError(
                "--servable_name is required when "
                "--cloud_mlengine_model_name is not set.")
Example #26
Source File: t2t_translate_all.py From BERT with Apache License 2.0 | 5 votes |
def main(_): tf.logging.set_verbosity(tf.logging.INFO) # pylint: disable=unused-variable model_dir = os.path.expanduser(FLAGS.model_dir) translations_dir = os.path.expanduser(FLAGS.translations_dir) source = os.path.expanduser(FLAGS.source) tf.gfile.MakeDirs(translations_dir) translated_base_file = os.path.join(translations_dir, FLAGS.problem) # Copy flags.txt with the original time, so t2t-bleu can report correct # relative time. flags_path = os.path.join(translations_dir, FLAGS.problem + "-flags.txt") if not os.path.exists(flags_path): shutil.copy2(os.path.join(model_dir, "flags.txt"), flags_path) locals_and_flags = {"FLAGS": FLAGS} for model in bleu_hook.stepfiles_iterator(model_dir, FLAGS.wait_minutes, FLAGS.min_steps): tf.logging.info("Translating " + model.filename) out_file = translated_base_file + "-" + str(model.steps) locals_and_flags.update(locals()) if os.path.exists(out_file): tf.logging.info(out_file + " already exists, so skipping it.") else: tf.logging.info("Translating " + out_file) params = ( "--t2t_usr_dir={FLAGS.t2t_usr_dir} --output_dir={model_dir} " "--data_dir={FLAGS.data_dir} --problem={FLAGS.problem} " "--decode_hparams=beam_size={FLAGS.beam_size},alpha={FLAGS.alpha} " "--model={FLAGS.model} --hparams_set={FLAGS.hparams_set} " "--checkpoint_path={model.filename} --decode_from_file={source} " "--decode_to_file={out_file} --keep_timestamp" ).format(**locals_and_flags) command = FLAGS.decoder_command.format(**locals()) tf.logging.info("Running:\n" + command) os.system(command) # pylint: enable=unused-variable
Example #27
Source File: t2t_trainer.py From BERT with Apache License 2.0 | 5 votes |
def save_metadata(hparams):
    """Saves FLAGS and hparams to output_dir.

    Writes ``flags.txt`` (all flags) and, when the flags object supports
    per-module listing, ``flags_t2t.txt`` (only
    ``tensor2tensor.utils.flags``). Then writes ``hparams.json`` from a copy
    of ``hparams`` with the ``modality`` entry removed. The output directory
    is created if it does not exist.

    Args:
      hparams: Hyper-parameter object; it is copied, not modified.
    """
    out_dir = os.path.expanduser(FLAGS.output_dir)
    if not tf.gfile.Exists(out_dir):
        tf.gfile.MakeDirs(out_dir)

    def _dump(file_name, text):
        # Write `text` to `file_name` inside the output directory.
        with tf.gfile.Open(os.path.join(out_dir, file_name), "w") as handle:
            handle.write(text)

    # Serialize flags. Newer flag objects (absl) provide flags_into_string
    # and per-module listings; the fallback reads the legacy name -> value
    # mapping stored under __flags.
    if hasattr(FLAGS, "flags_into_string"):
        all_flags_text = FLAGS.flags_into_string()
        t2t_flag_objs = FLAGS.flags_by_module_dict()["tensor2tensor.utils.flags"]
        t2t_flags_text = "\n".join(
            "--%s=%s" % (flag.name, flag.value) for flag in t2t_flag_objs)
    else:
        legacy_flags = FLAGS.__dict__["__flags"]
        all_flags_text = "\n".join(
            "--%s=%s" % (name, str(value))
            for name, value in legacy_flags.items())
        t2t_flags_text = None

    _dump("flags.txt", all_flags_text)
    if t2t_flags_text:
        _dump("flags_t2t.txt", t2t_flags_text)

    # The modality class is not JSON serializable, so drop it from a copy.
    serializable = hparams_lib.copy_hparams(hparams)
    serializable.del_hparam("modality")
    _dump("hparams.json", serializable.to_json(indent=0, sort_keys=True))
Example #28
Source File: t2t_eval.py From BERT with Apache License 2.0 | 5 votes |
def main(_):
    """Evaluate each checkpoint of a tensor2tensor model and log the metrics.

    Builds hparams and an EVAL input function (on the ``test`` split when
    ``--eval_use_test_set`` is set), creates an estimator, and evaluates
    every checkpoint yielded by ``trainer_lib.next_checkpoint`` (presumably
    waiting up to ``--eval_timeout_mins`` for new ones — confirm).

    Args:
      _: Unused (presumably the argv list passed by ``tf.app.run``).
    """
    tf.logging.set_verbosity(tf.logging.INFO)
    trainer_lib.set_random_seed(FLAGS.random_seed)
    usr_dir.import_usr_dir(FLAGS.t2t_usr_dir)

    hp = trainer_lib.create_hparams(
        FLAGS.hparams_set,
        FLAGS.hparams,
        data_dir=FLAGS.data_dir,
        problem_name=FLAGS.problem)

    # Use the test split only when requested; otherwise pass None.
    if FLAGS.eval_use_test_set:
        split = "test"
    else:
        split = None
    input_fn = hp.problem.make_estimator_input_fn(
        tf.estimator.ModeKeys.EVAL, hp, dataset_kwargs={"dataset_split": split})

    run_config = t2t_trainer.create_run_config(hp)
    # The summary hook in tf.estimator.EstimatorSpec requires
    # hparams.model_dir to be set.
    hp.add_hparam("model_dir", run_config.model_dir)
    estimator = trainer_lib.create_estimator(
        FLAGS.model, hp, run_config, use_tpu=FLAGS.use_tpu)

    checkpoints = trainer_lib.next_checkpoint(hp.model_dir, FLAGS.eval_timeout_mins)
    for checkpoint in checkpoints:
        metrics = estimator.evaluate(
            input_fn, steps=FLAGS.eval_steps, checkpoint_path=checkpoint)
        tf.logging.info(metrics)
Example #29
Source File: transformer_model.py From BERT with Apache License 2.0 | 5 votes |
def __init__(self, processor_configuration):
    """Creates the Transformer estimator.

    Args:
      processor_configuration: A ProcessorConfiguration protobuffer with the
        transformer fields populated; its ``"transformer"`` entry supplies
        ``model_dir``, ``data_dir``, ``hparams_set``, ``hparams``,
        ``problem`` and ``model``.
    """
    cfg = processor_configuration["transformer"]

    # Pre-setup tensor2tensor requires: point the global output_dir flag at
    # the model and import user code before building hparams.
    FLAGS.output_dir = cfg["model_dir"]
    usr_dir.import_usr_dir(FLAGS.t2t_usr_dir)
    expanded_data_dir = os.path.expanduser(cfg["data_dir"])

    self.hparams = trainer_lib.create_hparams(
        cfg["hparams_set"],
        cfg["hparams"],
        data_dir=expanded_data_dir,
        problem_name=cfg["problem"])

    decode_hparams = decoding.decode_hparams()
    for hp_name, hp_value in (("shards", 1), ("shard_id", 0)):
        decode_hparams.add_hparam(hp_name, hp_value)

    run_config = t2t_trainer.create_run_config(self.hparams)
    self.estimator = trainer_lib.create_estimator(
        cfg["model"],
        self.hparams,
        run_config,
        decode_hparams=decode_hparams,
        use_tpu=False)

    # Vocabularies and helper sizes used during decoding.
    vocabularies = self.hparams.problem_hparams.vocabulary
    self.source_vocab = vocabularies["inputs"]
    self.targets_vocab = vocabularies["targets"]
    self.const_array_size = 10000

    # Clear previous runs out of the debug dump directory.
    dump_pattern = os.path.join("/tmp/t2t_server_dump", "run_*")
    for stale_run in sorted(glob.glob(dump_pattern)):
        shutil.rmtree(stale_run)
Example #30
Source File: query.py From BERT with Apache License 2.0 | 5 votes |
def validate_flags():
    """Validates flags are set to acceptable values.

    Exactly one serving target must be configured:

    * with ``--cloud_mlengine_model_name`` set, ``--server`` and
      ``--servable_name`` must both be unset;
    * otherwise ``--server`` and ``--servable_name`` must both be set.

    Raises:
      ValueError: If the flag combination is invalid. Explicit raises are
        used instead of ``assert`` so the check still runs under ``python -O``.
    """
    if FLAGS.cloud_mlengine_model_name:
        if FLAGS.server:
            raise ValueError(
                "--server must not be set when --cloud_mlengine_model_name is set.")
        if FLAGS.servable_name:
            raise ValueError(
                "--servable_name must not be set when "
                "--cloud_mlengine_model_name is set.")
    else:
        if not FLAGS.server:
            raise ValueError(
                "--server is required when --cloud_mlengine_model_name is not set.")
        if not FLAGS.servable_name:
            raise ValueError(
                "--servable_name is required when "
                "--cloud_mlengine_model_name is not set.")