Python tensor2tensor.utils.registry.problem() Examples

The following are 30 code examples of tensor2tensor.utils.registry.problem(), collected from open-source projects. The original project, source file, and license are noted above each example. You may also want to check out all available functions and classes of the module tensor2tensor.utils.registry.
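As background, a minimal sketch of how a problem gets into the registry in the first place; the class below is hypothetical, and registry.problem() resolves the snake_cased class name to a Problem instance:

from tensor2tensor.data_generators import text_problems
from tensor2tensor.utils import registry


@registry.register_problem
class MySummarizeText(text_problems.Text2TextProblem):
  """Hypothetical problem, registered as "my_summarize_text"."""


# Look up the registered problem by name and get a Problem instance back.
problem = registry.problem("my_summarize_text")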
Example #1
Source File: query.py    From fine-lm with MIT License
def main(_):
  tf.logging.set_verbosity(tf.logging.INFO)
  validate_flags()
  usr_dir.import_usr_dir(FLAGS.t2t_usr_dir)
  problem = registry.problem(FLAGS.problem)
  hparams = tf.contrib.training.HParams(
      data_dir=os.path.expanduser(FLAGS.data_dir))
  problem.get_hparams(hparams)
  request_fn = make_request_fn()
  while True:
    inputs = FLAGS.inputs_once if FLAGS.inputs_once else input(">> ")
    outputs = serving_utils.predict([inputs], problem, request_fn)
    outputs, = outputs
    output, score = outputs
    print_str = """
Input:
{inputs}

Output (Score {score:.3f}):
{output}
    """
    print(print_str.format(inputs=inputs, output=output, score=score))
    if FLAGS.inputs_once:
      break 
Example #2
Source File: t2t_decoder.py    From tensor2tensor with Apache License 2.0
def decode(estimator, hparams, decode_hp):
  """Decode from estimator. Interactive, from file, or from dataset."""
  if FLAGS.decode_interactive:
    if estimator.config.use_tpu:
      raise ValueError("TPU can only decode from dataset.")
    decoding.decode_interactively(estimator, hparams, decode_hp,
                                  checkpoint_path=FLAGS.checkpoint_path)
  elif FLAGS.decode_from_file:
    decoding.decode_from_file(estimator, FLAGS.decode_from_file, hparams,
                              decode_hp, FLAGS.decode_to_file,
                              checkpoint_path=FLAGS.checkpoint_path)
    if FLAGS.checkpoint_path and FLAGS.keep_timestamp:
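      # Copy the checkpoint's timestamp onto the decode output file so the
      # output records which checkpoint produced it.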
      ckpt_time = os.path.getmtime(FLAGS.checkpoint_path + ".index")
      os.utime(FLAGS.decode_to_file, (ckpt_time, ckpt_time))
  else:
    decoding.decode_from_dataset(
        estimator,
        FLAGS.problem,
        hparams,
        decode_hp,
        decode_to_file=FLAGS.decode_to_file,
        dataset_split="test" if FLAGS.eval_use_test_set else None,
        checkpoint_path=FLAGS.checkpoint_path) 
Example #3
Source File: data_reader_test.py    From tensor2tensor with Apache License 2.0
def testBasicExampleReading(self):
    dataset = self.problem.dataset(
        tf.estimator.ModeKeys.TRAIN,
        data_dir=self.data_dir,
        shuffle_files=False)
    examples = dataset.make_one_shot_iterator().get_next()
    with tf.train.MonitoredSession() as sess:
      # Check that there are multiple examples that have the right fields of the
      # right type (lists of int/float).
      for _ in range(10):
        ex_val = sess.run(examples)
        inputs, targets, floats = (ex_val["inputs"], ex_val["targets"],
                                   ex_val["floats"])
        self.assertEqual(np.int64, inputs.dtype)
        self.assertEqual(np.int64, targets.dtype)
        self.assertEqual(np.float32, floats.dtype)
        for field in [inputs, targets, floats]:
          self.assertGreater(len(field), 0) 
Example #4
Source File: build_vocab.py    From tensor2tensor with Apache License 2.0
def main(_):
  problem = registry.problem(FLAGS.problem)

  # We make the assumption that the problem is a subclass of Text2TextProblem.
  assert isinstance(problem, text_problems.Text2TextProblem)

  data_dir = os.path.expanduser(FLAGS.data_dir)
  tmp_dir = os.path.expanduser(FLAGS.tmp_dir)

  tf.gfile.MakeDirs(data_dir)
  tf.gfile.MakeDirs(tmp_dir)

  tf.logging.info("Saving vocabulary to data_dir: %s" % data_dir)

  problem.get_or_create_vocab(data_dir, tmp_dir)

  tf.logging.info("Saved vocabulary file: " +
                  os.path.join(data_dir, problem.vocab_filename)) 
Example #5
Source File: multimodel_test.py    From fine-lm with MIT License
def testMultiModel(self):
    x = np.random.randint(256, size=(3, 5, 5, 3))
    y = np.random.randint(10, size=(3, 5, 1, 1))
    hparams = multimodel.multimodel_tiny()
    hparams.add_hparam("data_dir", "")
    problem = registry.problem("image_cifar10")
    p_hparams = problem.get_hparams(hparams)
    hparams.problem_hparams = p_hparams
    with self.test_session() as session:
      features = {
          "inputs": tf.constant(x, dtype=tf.int32),
          "targets": tf.constant(y, dtype=tf.int32),
          "target_space_id": tf.constant(1, dtype=tf.int32),
      }
      model = multimodel.MultiModel(
          hparams, tf.estimator.ModeKeys.TRAIN, p_hparams)
      logits, _ = model(features)
      session.run(tf.global_variables_initializer())
      res = session.run(logits)
    self.assertEqual(res.shape, (3, 1, 1, 1, 10)) 
Example #6
Source File: image_transformer_2d_test.py    From fine-lm with MIT License
def _test_img2img_transformer(self, net):
    batch_size = 3
    hparams = image_transformer_2d.img2img_transformer2d_tiny()
    hparams.data_dir = ""
    p_hparams = registry.problem("image_celeba").get_hparams(hparams)
    inputs = np.random.randint(256, size=(batch_size, 4, 4, 3))
    targets = np.random.randint(256, size=(batch_size, 8, 8, 3))
    with self.test_session() as session:
      features = {
          "inputs": tf.constant(inputs, dtype=tf.int32),
          "targets": tf.constant(targets, dtype=tf.int32),
          "target_space_id": tf.constant(1, dtype=tf.int32),
      }
      model = net(hparams, tf.estimator.ModeKeys.TRAIN, p_hparams)
      logits, _ = model(features)
      session.run(tf.global_variables_initializer())
      res = session.run(logits)
    self.assertEqual(res.shape, (batch_size, 8, 8, 3, 256)) 
Example #7
Source File: hparams_lib.py    From tensor2tensor with Apache License 2.0
def create_hparams(hparams_set,
                   hparams_overrides_str="",
                   data_dir=None,
                   problem_name=None,
                   hparams_path=None):
  """Create HParams with data_dir and problem hparams, if kwargs provided."""
  hparams = registry.hparams(hparams_set)
  if hparams_path and tf.gfile.Exists(hparams_path):
    hparams = create_hparams_from_json(hparams_path, hparams)
  if data_dir:
    hparams.add_hparam("data_dir", data_dir)
  if hparams_overrides_str:
    tf.logging.info("Overriding hparams in %s with %s", hparams_set,
                    hparams_overrides_str)
    hparams = hparams.parse(hparams_overrides_str)
  if problem_name:
    add_problem_hparams(hparams, problem_name)
  return hparams 
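A hedged usage sketch for the helper above: "transformer_base" is a standard hparams-set name, while the data directory path and problem name are illustrative:

hparams = create_hparams(
    "transformer_base",
    data_dir="/tmp/t2t_data",                # illustrative path
    problem_name="translate_ende_wmt32k")    # a registered Problem name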
Example #8
Source File: slicenet_test.py    From fine-lm with MIT License
def testSliceNet(self):
    x = np.random.randint(256, size=(3, 5, 5, 3))
    y = np.random.randint(10, size=(3, 5, 1, 1))
    hparams = slicenet.slicenet_params1_tiny()
    hparams.add_hparam("data_dir", "")
    problem = registry.problem("image_cifar10")
    p_hparams = problem.get_hparams(hparams)
    hparams.problem_hparams = p_hparams
    with self.test_session() as session:
      features = {
          "inputs": tf.constant(x, dtype=tf.int32),
          "targets": tf.constant(y, dtype=tf.int32),
          "target_space_id": tf.constant(1, dtype=tf.int32),
      }
      model = slicenet.SliceNet(hparams, tf.estimator.ModeKeys.TRAIN,
                                p_hparams)
      logits, _ = model(features)
      session.run(tf.global_variables_initializer())
      res = session.run(logits)
    self.assertEqual(res.shape, (3, 1, 1, 1, 10)) 
Example #9
Source File: data_reader_test.py    From fine-lm with MIT License
def testBasicExampleReading(self):
    dataset = self.problem.dataset(
        tf.estimator.ModeKeys.TRAIN,
        data_dir=self.data_dir,
        shuffle_files=False)
    examples = dataset.make_one_shot_iterator().get_next()
    with tf.train.MonitoredSession() as sess:
      # Check that there are multiple examples that have the right fields of the
      # right type (lists of int/float).
      for _ in range(10):
        ex_val = sess.run(examples)
        inputs, targets, floats = (ex_val["inputs"], ex_val["targets"],
                                   ex_val["floats"])
        self.assertEqual(np.int64, inputs.dtype)
        self.assertEqual(np.int64, targets.dtype)
        self.assertEqual(np.float32, floats.dtype)
        for field in [inputs, targets, floats]:
          self.assertGreater(len(field), 0) 
Example #10
Source File: t2t_datagen.py    From tensor2tensor with Apache License 2.0
def generate_data_for_registered_problem(problem_name):
  """Generate data for a registered problem."""
  tf.logging.info("Generating data for %s.", problem_name)
  if FLAGS.num_shards:
    raise ValueError("--num_shards should not be set for registered Problem.")
  problem = registry.problem(problem_name)
  task_id = None if FLAGS.task_id < 0 else FLAGS.task_id
  data_dir = os.path.expanduser(FLAGS.data_dir)
  tmp_dir = os.path.expanduser(FLAGS.tmp_dir)
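  # With no explicit task id, problems that support it shard generation
  # across a pool of worker processes; otherwise generate in a single process.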
  if task_id is None and problem.multiprocess_generate:
    if FLAGS.task_id_start != -1:
      assert FLAGS.task_id_end != -1
      task_id_start = FLAGS.task_id_start
      task_id_end = FLAGS.task_id_end
    else:
      task_id_start = 0
      task_id_end = problem.num_generate_tasks
    pool = multiprocessing.Pool(processes=FLAGS.num_concurrent_processes)
    problem.prepare_to_generate(data_dir, tmp_dir)
    args = [(problem_name, data_dir, tmp_dir, task_id)
            for task_id in range(task_id_start, task_id_end)]
    pool.map(generate_data_in_process, args)
  else:
    problem.generate_data(data_dir, tmp_dir, task_id) 
Example #11
Source File: t2t_datagen.py    From tensor2tensor with Apache License 2.0
def generate_data_for_env_problem(problem_name):
  """Generate data for `EnvProblem`s."""
  assert FLAGS.env_problem_max_env_steps > 0, ("--env_problem_max_env_steps "
                                               "should be greater than zero")
  assert FLAGS.env_problem_batch_size > 0, ("--env_problem_batch_size should "
                                            "be greater than zero")
  problem = registry.env_problem(problem_name)
  task_id = None if FLAGS.task_id < 0 else FLAGS.task_id
  data_dir = os.path.expanduser(FLAGS.data_dir)
  tmp_dir = os.path.expanduser(FLAGS.tmp_dir)
  # TODO(msaffar): Handle large values for env_problem_batch_size where we
  #  cannot create that many environments within the same process.
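  # Initialize a batch of environments, collect trajectories with a random
  # policy, and then write the collected data to disk.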
  problem.initialize(batch_size=FLAGS.env_problem_batch_size)
  env_problem_utils.play_env_problem_randomly(
      problem, num_steps=FLAGS.env_problem_max_env_steps)
  problem.generate_data(data_dir=data_dir, tmp_dir=tmp_dir, task_id=task_id) 
Example #12
Source File: image_transformer_2d_test.py    From tensor2tensor with Apache License 2.0
def _test_img2img_transformer(self, net):
    batch_size = 3
    hparams = image_transformer_2d.img2img_transformer2d_tiny()
    hparams.data_dir = ""
    p_hparams = registry.problem("image_celeba").get_hparams(hparams)
    inputs = np.random.randint(256, size=(batch_size, 4, 4, 3))
    targets = np.random.randint(256, size=(batch_size, 8, 8, 3))
    with self.test_session() as session:
      features = {
          "inputs": tf.constant(inputs, dtype=tf.int32),
          "targets": tf.constant(targets, dtype=tf.int32),
          "target_space_id": tf.constant(1, dtype=tf.int32),
      }
      model = net(hparams, tf.estimator.ModeKeys.TRAIN, p_hparams)
      logits, _ = model(features)
      session.run(tf.global_variables_initializer())
      res = session.run(logits)
    self.assertEqual(res.shape, (batch_size, 8, 8, 3, 256)) 
Example #13
Source File: slicenet_test.py    From tensor2tensor with Apache License 2.0
def testSliceNet(self):
    x = np.random.randint(256, size=(3, 5, 5, 3))
    y = np.random.randint(10, size=(3, 5, 1, 1))
    hparams = slicenet.slicenet_params1_tiny()
    hparams.add_hparam("data_dir", "")
    problem = registry.problem("image_cifar10")
    p_hparams = problem.get_hparams(hparams)
    hparams.problem_hparams = p_hparams
    with self.test_session() as session:
      features = {
          "inputs": tf.constant(x, dtype=tf.int32),
          "targets": tf.constant(y, dtype=tf.int32),
          "target_space_id": tf.constant(1, dtype=tf.int32),
      }
      model = slicenet.SliceNet(hparams, tf.estimator.ModeKeys.TRAIN,
                                p_hparams)
      logits, _ = model(features)
      session.run(tf.global_variables_initializer())
      res = session.run(logits)
    self.assertEqual(res.shape, (3, 1, 1, 1, 10)) 
Example #14
Source File: t2t_trainer.py    From fine-lm with MIT License
def create_experiment_fn(**kwargs):
  return trainer_lib.create_experiment_fn(
      model_name=FLAGS.model,
      problem_name=FLAGS.problem,
      data_dir=os.path.expanduser(FLAGS.data_dir),
      train_steps=FLAGS.train_steps,
      eval_steps=FLAGS.eval_steps,
      min_eval_frequency=FLAGS.local_eval_frequency,
      schedule=FLAGS.schedule,
      eval_throttle_seconds=FLAGS.eval_throttle_seconds,
      export=FLAGS.export_saved_model,
      decode_hparams=decoding.decode_hparams(FLAGS.decode_hparams),
      use_tfdbg=FLAGS.tfdbg,
      use_dbgprofile=FLAGS.dbgprofile,
      eval_early_stopping_steps=FLAGS.eval_early_stopping_steps,
      eval_early_stopping_metric=FLAGS.eval_early_stopping_metric,
      eval_early_stopping_metric_delta=FLAGS.eval_early_stopping_metric_delta,
      eval_early_stopping_metric_minimize=FLAGS.
      eval_early_stopping_metric_minimize,
      use_tpu=FLAGS.use_tpu,
      **kwargs) 
Example #15
Source File: t2t_datagen.py    From fine-lm with MIT License
def generate_data_for_registered_problem(problem_name):
  """Generate data for a registered problem."""
  tf.logging.info("Generating data for %s.", problem_name)
  if FLAGS.num_shards:
    raise ValueError("--num_shards should not be set for registered Problem.")
  problem = registry.problem(problem_name)
  task_id = None if FLAGS.task_id < 0 else FLAGS.task_id
  data_dir = os.path.expanduser(FLAGS.data_dir)
  tmp_dir = os.path.expanduser(FLAGS.tmp_dir)
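  # With no explicit task id, problems that support it shard generation
  # across a pool of worker processes; otherwise generate in a single process.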
  if task_id is None and problem.multiprocess_generate:
    if FLAGS.task_id_start != -1:
      assert FLAGS.task_id_end != -1
      task_id_start = FLAGS.task_id_start
      task_id_end = FLAGS.task_id_end
    else:
      task_id_start = 0
      task_id_end = problem.num_generate_tasks
    pool = multiprocessing.Pool(processes=FLAGS.num_concurrent_processes)
    problem.prepare_to_generate(data_dir, tmp_dir)
    args = [(problem_name, data_dir, tmp_dir, task_id)
            for task_id in range(task_id_start, task_id_end)]
    pool.map(generate_data_in_process, args)
  else:
    problem.generate_data(data_dir, tmp_dir, task_id) 
Example #16
Source File: t2t_datagen.py    From fine-lm with MIT License
def generate_data_for_problem(problem):
  """Generate data for a problem in _SUPPORTED_PROBLEM_GENERATORS."""
  training_gen, dev_gen = _SUPPORTED_PROBLEM_GENERATORS[problem]

  num_shards = FLAGS.num_shards or 10
  tf.logging.info("Generating training data for %s.", problem)
  train_output_files = generator_utils.train_data_filenames(
      problem + generator_utils.UNSHUFFLED_SUFFIX, FLAGS.data_dir, num_shards)
  generator_utils.generate_files(training_gen(), train_output_files,
                                 FLAGS.max_cases)
  tf.logging.info("Generating development data for %s.", problem)
  dev_output_files = generator_utils.dev_data_filenames(
      problem + generator_utils.UNSHUFFLED_SUFFIX, FLAGS.data_dir, 1)
  generator_utils.generate_files(dev_gen(), dev_output_files)
  all_output_files = train_output_files + dev_output_files
  generator_utils.shuffle_dataset(all_output_files) 
Example #17
Source File: model_rl_experiment.py    From fine-lm with MIT License
def combine_training_data(problem, final_data_dir, old_data_dirs,
                          copy_last_eval_set=True):
  """Add training data from old_data_dirs into final_data_dir."""
  for i, data_dir in enumerate(old_data_dirs):
    suffix = os.path.basename(data_dir)
    # Glob train files in old data_dir
    old_train_files = tf.gfile.Glob(
        problem.filepattern(data_dir, tf.estimator.ModeKeys.TRAIN))
    if (i + 1) == len(old_data_dirs) and copy_last_eval_set:
      old_train_files += tf.gfile.Glob(
          problem.filepattern(data_dir, tf.estimator.ModeKeys.EVAL))
    for fname in old_train_files:
      # Move them to the new data_dir with a suffix
      # Since the data is read based on a prefix filepattern, adding the suffix
      # should be fine.
      new_fname = os.path.join(final_data_dir,
                               os.path.basename(fname) + "." + suffix)
      if not tf.gfile.Exists(new_fname):
        tf.gfile.Copy(fname, new_fname) 
Example #18
Source File: slicenet_test.py    From tensor2tensor with Apache License 2.0
def testSliceNetImageToText(self):
    x = np.random.randint(256, size=(3, 5, 5, 3))
    y = np.random.randint(10, size=(3, 5, 1, 1))
    hparams = slicenet.slicenet_params1_tiny()
    hparams.add_hparam("data_dir", "")
    problem = registry.problem("image_ms_coco_characters")
    p_hparams = problem.get_hparams(hparams)
    hparams.problem_hparams = p_hparams
    with self.test_session() as session:
      features = {
          "inputs": tf.constant(x, dtype=tf.int32),
          "targets": tf.constant(y, dtype=tf.int32),
          "target_space_id": tf.constant(1, dtype=tf.int32),
      }
      model = slicenet.SliceNet(hparams, tf.estimator.ModeKeys.TRAIN,
                                p_hparams)
      logits, _ = model(features)
      session.run(tf.global_variables_initializer())
      res = session.run(logits)
    self.assertEqual(res.shape, (3, 5, 1, 1, 258)) 
Example #19
Source File: model_rl_experiment.py    From fine-lm with MIT License
def generate_real_env_data(problem_name, agent_policy_path, hparams, data_dir,
                           tmp_dir, autoencoder_path=None, eval_phase=False):
  """Run the agent against the real environment and return mean reward."""
  tf.gfile.MakeDirs(data_dir)
  with temporary_flags({
      "problem": problem_name,
      "agent_policy_path": agent_policy_path,
      "autoencoder_path": autoencoder_path,
  }):
    gym_problem = registry.problem(problem_name)
    gym_problem.settable_num_steps = hparams.true_env_generator_num_steps
    gym_problem.settable_eval_phase = eval_phase
    gym_problem.generate_data(data_dir, tmp_dir)
    mean_reward = None
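    # Mean reward per episode: total reward over the number of finished
    # episodes (left as None if no episode ended).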
    if gym_problem.statistics.number_of_dones:
      mean_reward = (gym_problem.statistics.sum_of_rewards /
                     gym_problem.statistics.number_of_dones)

  return mean_reward 
Example #20
Source File: model_rl_experiment.py    From fine-lm with MIT License
def evaluate_world_model(simulated_problem_name, problem_name, hparams,
                         world_model_dir, epoch_data_dir, tmp_dir):
  """Generate simulated environment data and return reward accuracy."""
  gym_simulated_problem = registry.problem(simulated_problem_name)
  sim_steps = hparams.simulated_env_generator_num_steps
  gym_simulated_problem.settable_num_steps = sim_steps
  with temporary_flags({
      "problem": problem_name,
      "model": hparams.generative_model,
      "hparams_set": hparams.generative_model_params,
      "data_dir": epoch_data_dir,
      "output_dir": world_model_dir,
  }):
    gym_simulated_problem.generate_data(epoch_data_dir, tmp_dir)
  n = max(1., gym_simulated_problem.statistics.number_of_dones)
  model_reward_accuracy = (
      gym_simulated_problem.statistics.successful_episode_reward_predictions
      / float(n))
  old_path = os.path.join(epoch_data_dir, "debug_frames_sim")
  new_path = os.path.join(epoch_data_dir, "debug_frames_sim_eval")
  if not tf.gfile.Exists(new_path):
    tf.gfile.Rename(old_path, new_path)
  return model_reward_accuracy 
Example #21
Source File: datagen_with_agent.py    From tensor2tensor with Apache License 2.0
def main(_):

  tf.gfile.MakeDirs(FLAGS.data_dir)
  tf.gfile.MakeDirs(FLAGS.tmp_dir)

  # Create problem if not already defined
  problem_name = "gym_discrete_problem_with_agent_on_%s" % FLAGS.game
  if problem_name not in registry.Registries.problems:
    gym_env.register_game(FLAGS.game)

  # Generate
  tf.logging.info("Running %s environment for %d steps for trajectories.",
                  FLAGS.game, FLAGS.num_env_steps)
  problem = registry.problem(problem_name)
  problem.settable_num_steps = FLAGS.num_env_steps
  problem.settable_eval_phase = FLAGS.eval
  problem.generate_data(FLAGS.data_dir, FLAGS.tmp_dir)

  # Log stats
  if problem.statistics.number_of_dones:
    mean_reward = (problem.statistics.sum_of_rewards /
                   problem.statistics.number_of_dones)
    tf.logging.info("Mean reward: %.2f, Num dones: %d",
                    mean_reward,
                    problem.statistics.number_of_dones) 
Example #22
Source File: problems_colab.py    From tensor2tensor with Apache License 2.0
def problem(name):
  return registry.problem(name) 
Example #23
Source File: babi_qa.py    From tensor2tensor with Apache License 2.0
def __init__(self, *args, **kwargs):
    super(BabiQa, self).__init__(*args, **kwargs)
    assert not self._was_reversed, "This problem is not reversible!"
    assert not self._was_copy, "This problem is not copyable!" 
Example #24
Source File: tests_utils.py    From tensor2tensor with Apache License 2.0
def fill_hparams(hparams, in_frames, out_frames):
  hparams.video_num_input_frames = in_frames
  hparams.video_num_target_frames = out_frames
  problem = registry.problem("video_stochastic_shapes10k")
  p_hparams = problem.get_hparams(hparams)
  hparams.problem = problem
  hparams.problem_hparams = p_hparams
  hparams.tiny_mode = True
  hparams.reward_prediction = False
  return hparams 
Example #25
Source File: hparams_lib.py    From tensor2tensor with Apache License 2.0
def copy_hparams(hparams):
  hp_vals = hparams.values()
  new_hparams = hparam.HParams(**hp_vals)
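  # "problem" and "problem_hparams" are plain attributes rather than
  # registered hparams, so values() misses them; copy them explicitly.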
  other_attrs = ["problem", "problem_hparams"]
  for attr in other_attrs:
    attr_val = getattr(hparams, attr, None)
    if attr_val is not None:
      setattr(new_hparams, attr, attr_val)
  return new_hparams 
Example #26
Source File: t2t_decoder.py    From tensor2tensor with Apache License 2.0
def create_hparams():
  hparams_path = None
  if FLAGS.output_dir:
    hparams_path = os.path.join(FLAGS.output_dir, "hparams.json")
  return trainer_lib.create_hparams(
      FLAGS.hparams_set,
      FLAGS.hparams,
      data_dir=os.path.expanduser(FLAGS.data_dir),
      problem_name=FLAGS.problem,
      hparams_path=hparams_path) 
Example #27
Source File: t2t_trainer.py    From tensor2tensor with Apache License 2.0
def generate_data():
  # Generate data if requested.
  data_dir = os.path.expanduser(FLAGS.data_dir)
  tmp_dir = os.path.expanduser(FLAGS.tmp_dir)
  tf.gfile.MakeDirs(data_dir)
  tf.gfile.MakeDirs(tmp_dir)

  problem_name = FLAGS.problem
  tf.logging.info("Generating data for %s" % problem_name)
  registry.problem(problem_name).generate_data(data_dir, tmp_dir) 
Example #28
Source File: problems_colab.py    From tensor2tensor with Apache License 2.0
def available():
  return sorted(registry.list_problems())


# Import problem modules 
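Combined with the problem() wrapper from Example #22, these colab helpers are typically used as in this sketch (the problem name is illustrative):

names = available()                       # sorted registered problem names
ende = problem("translate_ende_wmt32k")   # look one problem up by name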
Example #29
Source File: t2t_attack.py    From tensor2tensor with Apache License 2.0
def prepare_data(problem, hparams, params, config):
  """Construct input pipeline."""
  input_fn = problem.make_estimator_input_fn(
      tf.estimator.ModeKeys.EVAL, hparams, force_repeat=True)
  dataset = input_fn(params, config)
  features, _ = dataset.make_one_shot_iterator().get_next()
  inputs, labels = features["targets"], features["inputs"]
  inputs = tf.to_float(inputs)
  input_shape = inputs.shape.as_list()
  inputs = tf.reshape(inputs, [hparams.batch_size] + input_shape[1:])
  labels = tf.reshape(labels, [hparams.batch_size])
  return inputs, labels, features 
Example #30
Source File: t2t_datagen.py    From tensor2tensor with Apache License 2.0
def generate_data_in_process(arg):
  problem_name, data_dir, tmp_dir, task_id = arg
  problem = registry.problem(problem_name)
  problem.generate_data(data_dir, tmp_dir, task_id)