Python tensorflow.HParams() Examples

The following are 30 code examples of tensorflow.HParams(). You can vote up the ones you like or vote down the ones you don't like, and go to the original project or source file by following the links above each example. You may also want to check out all available functions/classes of the tensorflow module, or try the search function.
Example #1
Source File: cifar10_main.py    From uai-sdk with Apache License 2.0 6 votes vote down vote up
def main(output_dir, data_dir, num_gpus, variable_strategy,
         use_distortion_for_training, log_device_placement, num_intra_threads,
         **hparams):
  """Sets up the TF session config and runs the experiment via learn_runner.

  Args:
    output_dir: Used as `model_dir` of the run config (UAI SDK supplies it
      through --output_dir).
    data_dir: Forwarded to `get_experiment_fn` (UAI SDK supplies it through
      --data_dir).
    num_gpus: Forwarded to `get_experiment_fn`.
    variable_strategy: Forwarded to `get_experiment_fn`.
    use_distortion_for_training: Forwarded to `get_experiment_fn`.
    log_device_placement: Passed to the session config under the same name.
    num_intra_threads: Passed to the session config as
      `intra_op_parallelism_threads`.
    **hparams: Extra keyword arguments; they become fields of the HParams
      object given to the runner.
  """
  # TF_SYNC_ON_FINISH is on a deprecation path; the default is off.
  for env_name, env_value in (('TF_SYNC_ON_FINISH', '0'),
                              ('TF_ENABLE_WINOGRAD_NONFUSED', '1')):
    os.environ[env_name] = env_value

  # Session configuration.
  gpu_opts = tf.GPUOptions(force_gpu_compatible=True)
  session_settings = tf.ConfigProto(
      allow_soft_placement=True,
      log_device_placement=log_device_placement,
      intra_op_parallelism_threads=num_intra_threads,
      gpu_options=gpu_opts)

  # UAI SDK use --output_dir as model_dir
  # UAI SDK use --data_dir as data_dir
  run_config = cifar10_utils.RunConfig(
      session_config=session_settings, model_dir=output_dir)
  experiment_fn = get_experiment_fn(
      data_dir, num_gpus, variable_strategy, use_distortion_for_training)
  runner_hparams = tf.contrib.training.HParams(
      is_chief=run_config.is_chief, **hparams)
  tf.contrib.learn.learn_runner.run(
      experiment_fn, run_config=run_config, hparams=runner_hparams)
Example #2
Source File: model_hparams.py    From multilabel-image-classification-tensorflow with MIT License 6 votes vote down vote up
def create_hparams(hparams_overrides=None):
  """Builds the default hyperparameters and applies optional overrides.

  Args:
    hparams_overrides: Optional string of comma-separated
      hparam_name=value pairs. When it is empty or None the defaults are
      returned untouched.

  Returns:
    A tf.HParams object holding the (possibly overridden) hyperparameters.
  """
  defaults = tf.contrib.training.HParams(
      # Controls whether the fine tuning checkpoint (provided in the
      # pipeline config) is loaded for training.
      load_pretrained=True)
  if not hparams_overrides:
    return defaults
  # Apply the overrides on top of the defaults declared above.
  return defaults.parse(hparams_overrides)
Example #3
Source File: policies.py    From multilabel-image-classification-tensorflow with MIT License 6 votes vote down vote up
def __init__(self, task_config, model_hparams=None, embedder_hparams=None,
             train_hparams=None):
    """Sets up a policy that operates on tasks (see tasks.py).

    The policy can read task history, goal and outputs in a way that is
    consistent with the task config.

    Args:
      task_config: A tasks.TaskIOConfig instance (see tasks.py).
      model_hparams: Optional tf.HParams with implementation-specific model
        parameters.
      embedder_hparams: Optional tf.HParams with implementation-specific
        parameters for the history and goal embedders.
      train_hparams: Optional tf.HParams with implementation-specific
        training parameters.
    """
    super(TaskPolicy, self).__init__(None, None)
    self._task_config = task_config
    self._extra_train_ops = []
    self._train_hparams = train_hparams
    self._embedder_hparams = embedder_hparams
    self._model_hparams = model_hparams
Example #4
Source File: policies.py    From models with Apache License 2.0 6 votes vote down vote up
def __init__(self, task_config, model_hparams=None, embedder_hparams=None,
             train_hparams=None):
    """Creates a task-aware policy (tasks are described in tasks.py).

    Task history, goal and outputs are read consistently with the supplied
    task config.

    Args:
      task_config: Object of type tasks.TaskIOConfig (see tasks.py).
      model_hparams: tf.HParams carrying the model parameters; their meaning
        is implementation specific.
      embedder_hparams: tf.HParams carrying the parameters of the history
        and goal embedders; their meaning is implementation specific.
      train_hparams: tf.HParams carrying the training parameters; their
        meaning is implementation specific.
    """
    super(TaskPolicy, self).__init__(None, None)
    self._extra_train_ops = []
    self._task_config = task_config
    self._model_hparams = model_hparams
    self._train_hparams = train_hparams
    self._embedder_hparams = embedder_hparams
Example #5
Source File: cifar10_main.py    From g-tensorflow-models with Apache License 2.0 6 votes vote down vote up
def main(job_dir, data_dir, num_gpus, variable_strategy,
         use_distortion_for_training, log_device_placement, num_intra_threads,
         **hparams):
  """Sets up the TF session config and runs the experiment via learn_runner.

  Args:
    job_dir: Used as `model_dir` of the run config.
    data_dir: Forwarded to `get_experiment_fn`.
    num_gpus: Forwarded to `get_experiment_fn`.
    variable_strategy: Forwarded to `get_experiment_fn`.
    use_distortion_for_training: Forwarded to `get_experiment_fn`.
    log_device_placement: Passed to the session config under the same name.
    num_intra_threads: Passed to the session config as
      `intra_op_parallelism_threads`.
    **hparams: Extra keyword arguments; they become fields of the HParams
      object given to the runner.
  """
  # TF_SYNC_ON_FINISH is on a deprecation path; the default is off.
  for env_name, env_value in (('TF_SYNC_ON_FINISH', '0'),
                              ('TF_ENABLE_WINOGRAD_NONFUSED', '1')):
    os.environ[env_name] = env_value

  # Session configuration.
  gpu_opts = tf.GPUOptions(force_gpu_compatible=True)
  session_settings = tf.ConfigProto(
      allow_soft_placement=True,
      log_device_placement=log_device_placement,
      intra_op_parallelism_threads=num_intra_threads,
      gpu_options=gpu_opts)

  run_config = cifar10_utils.RunConfig(
      session_config=session_settings, model_dir=job_dir)
  experiment_fn = get_experiment_fn(
      data_dir, num_gpus, variable_strategy, use_distortion_for_training)
  runner_hparams = tf.contrib.training.HParams(
      is_chief=run_config.is_chief, **hparams)
  tf.contrib.learn.learn_runner.run(
      experiment_fn, run_config=run_config, hparams=runner_hparams)
Example #6
Source File: model_hparams.py    From g-tensorflow-models with Apache License 2.0 6 votes vote down vote up
def create_hparams(hparams_overrides=None):
  """Builds the default hyperparameters and applies optional overrides.

  Args:
    hparams_overrides: Optional string of comma-separated
      hparam_name=value pairs. When it is empty or None the defaults are
      returned untouched.

  Returns:
    A tf.HParams object holding the (possibly overridden) hyperparameters.
  """
  defaults = tf.contrib.training.HParams(
      # Controls whether the fine tuning checkpoint (provided in the
      # pipeline config) is loaded for training.
      load_pretrained=True)
  if not hparams_overrides:
    return defaults
  # Apply the overrides on top of the defaults declared above.
  return defaults.parse(hparams_overrides)
Example #7
Source File: policies.py    From g-tensorflow-models with Apache License 2.0 6 votes vote down vote up
def __init__(self, task_config, model_hparams=None, embedder_hparams=None,
             train_hparams=None):
    """Initialises a policy able to work with tasks (see tasks.py).

    Reading of task history, goal and outputs follows the given task config.

    Args:
      task_config: A tasks.TaskIOConfig object (see tasks.py).
      model_hparams: tf.HParams holding implementation-specific parameters
        of the model.
      embedder_hparams: tf.HParams holding implementation-specific
        parameters of the history and goal embedders.
      train_hparams: tf.HParams holding implementation-specific parameters
        used for training.
    """
    super(TaskPolicy, self).__init__(None, None)
    self._task_config = task_config
    self._model_hparams = model_hparams
    self._embedder_hparams = embedder_hparams
    self._train_hparams = train_hparams
    # No extra train ops are registered at construction time.
    self._extra_train_ops = list()
Example #8
Source File: model_hparams.py    From MAX-Object-Detector with Apache License 2.0 6 votes vote down vote up
def create_hparams(hparams_overrides=None):
  """Creates the hyperparameter set, honouring flag-supplied overrides.

  Args:
    hparams_overrides: Optional overrides given as a string of
      comma-separated hparam_name=value pairs.

  Returns:
    The resulting hyperparameters as a tf.HParams object.
  """
  # Whether a fine tuning checkpoint (provided in the pipeline config)
  # should be loaded for training.
  result = tf.contrib.training.HParams(load_pretrained=True)
  # Only parse when overrides were actually supplied.
  return result.parse(hparams_overrides) if hparams_overrides else result
Example #9
Source File: cifar10_main.py    From object_detection_with_tensorflow with MIT License 6 votes vote down vote up
def main(job_dir, data_dir, num_gpus, variable_strategy,
         use_distortion_for_training, log_device_placement, num_intra_threads,
         **hparams):
  """Prepares environment and session settings, then starts learn_runner.

  Args:
    job_dir: Handed to the run config as `model_dir`.
    data_dir: Passed on to `get_experiment_fn`.
    num_gpus: Passed on to `get_experiment_fn`.
    variable_strategy: Passed on to `get_experiment_fn`.
    use_distortion_for_training: Passed on to `get_experiment_fn`.
    log_device_placement: Session config value of the same name.
    num_intra_threads: Session config value for
      `intra_op_parallelism_threads`.
    **hparams: Additional keyword arguments, wrapped into the HParams object
      given to the runner together with `is_chief`.
  """
  # The env variable is on deprecation path, default is set to off.
  os.environ['TF_SYNC_ON_FINISH'] = '0'
  os.environ['TF_ENABLE_WINOGRAD_NONFUSED'] = '1'

  # Session configuration.
  session_kwargs = dict(
      allow_soft_placement=True,
      log_device_placement=log_device_placement,
      intra_op_parallelism_threads=num_intra_threads,
      gpu_options=tf.GPUOptions(force_gpu_compatible=True))
  sess_config = tf.ConfigProto(**session_kwargs)

  config = cifar10_utils.RunConfig(
      session_config=sess_config, model_dir=job_dir)
  experiment_fn = get_experiment_fn(
      data_dir, num_gpus, variable_strategy, use_distortion_for_training)
  run_hparams = tf.contrib.training.HParams(
      is_chief=config.is_chief, **hparams)
  tf.contrib.learn.learn_runner.run(
      experiment_fn, run_config=config, hparams=run_hparams)
Example #10
Source File: model_hparams.py    From Elphas with Apache License 2.0 6 votes vote down vote up
def create_hparams(hparams_overrides=None):
    """Builds the default hyperparameters and applies optional overrides.

    Args:
      hparams_overrides: Optional string of comma-separated
        hparam_name=value pairs. When it is empty or None the defaults are
        returned untouched.

    Returns:
      A tf.HParams object holding the (possibly overridden) hyperparameters.
    """
    defaults = tf.contrib.training.HParams(
        # Controls whether the fine tuning checkpoint (provided in the
        # pipeline config) is loaded for training.
        load_pretrained=True)
    if not hparams_overrides:
        return defaults
    # Apply the overrides on top of the defaults declared above.
    return defaults.parse(hparams_overrides)
Example #11
Source File: cifar10_main.py    From object_detection_kitti with Apache License 2.0 6 votes vote down vote up
def main(job_dir, data_dir, num_gpus, variable_strategy,
         use_distortion_for_training, log_device_placement, num_intra_threads,
         **hparams):
  """Sets up the TF session config and runs the experiment via learn_runner.

  Args:
    job_dir: Used as `model_dir` of the run config.
    data_dir: Forwarded to `get_experiment_fn`.
    num_gpus: Forwarded to `get_experiment_fn`.
    variable_strategy: Forwarded to `get_experiment_fn`.
    use_distortion_for_training: Forwarded to `get_experiment_fn`.
    log_device_placement: Passed to the session config under the same name.
    num_intra_threads: Passed to the session config as
      `intra_op_parallelism_threads`.
    **hparams: Extra keyword arguments; they are the only fields of the
      HParams object given to the runner.
  """
  # TF_SYNC_ON_FINISH is on a deprecation path; the default is off.
  for env_name, env_value in (('TF_SYNC_ON_FINISH', '0'),
                              ('TF_ENABLE_WINOGRAD_NONFUSED', '1')):
    os.environ[env_name] = env_value

  # Session configuration.
  gpu_opts = tf.GPUOptions(force_gpu_compatible=True)
  session_settings = tf.ConfigProto(
      allow_soft_placement=True,
      log_device_placement=log_device_placement,
      intra_op_parallelism_threads=num_intra_threads,
      gpu_options=gpu_opts)

  run_config = cifar10_utils.RunConfig(
      session_config=session_settings, model_dir=job_dir)
  experiment_fn = get_experiment_fn(
      data_dir, num_gpus, variable_strategy, use_distortion_for_training)
  runner_hparams = tf.contrib.training.HParams(**hparams)
  tf.contrib.learn.learn_runner.run(
      experiment_fn, run_config=run_config, hparams=runner_hparams)
Example #12
Source File: cifar10_main.py    From Live-feed-object-device-identification-using-Tensorflow-and-OpenCV with Apache License 2.0 6 votes vote down vote up
def main(job_dir, data_dir, num_gpus, variable_strategy,
         use_distortion_for_training, log_device_placement, num_intra_threads,
         **hparams):
  """Sets up the TF session config and runs the experiment via learn_runner.

  Args:
    job_dir: Used as `model_dir` of the run config.
    data_dir: Forwarded to `get_experiment_fn`.
    num_gpus: Forwarded to `get_experiment_fn`.
    variable_strategy: Forwarded to `get_experiment_fn`.
    use_distortion_for_training: Forwarded to `get_experiment_fn`.
    log_device_placement: Passed to the session config under the same name.
    num_intra_threads: Passed to the session config as
      `intra_op_parallelism_threads`.
    **hparams: Extra keyword arguments; they become fields of the HParams
      object given to the runner.
  """
  # TF_SYNC_ON_FINISH is on a deprecation path; the default is off.
  for env_name, env_value in (('TF_SYNC_ON_FINISH', '0'),
                              ('TF_ENABLE_WINOGRAD_NONFUSED', '1')):
    os.environ[env_name] = env_value

  # Session configuration.
  gpu_opts = tf.GPUOptions(force_gpu_compatible=True)
  session_settings = tf.ConfigProto(
      allow_soft_placement=True,
      log_device_placement=log_device_placement,
      intra_op_parallelism_threads=num_intra_threads,
      gpu_options=gpu_opts)

  run_config = cifar10_utils.RunConfig(
      session_config=session_settings, model_dir=job_dir)
  experiment_fn = get_experiment_fn(
      data_dir, num_gpus, variable_strategy, use_distortion_for_training)
  runner_hparams = tf.contrib.training.HParams(
      is_chief=run_config.is_chief, **hparams)
  tf.contrib.learn.learn_runner.run(
      experiment_fn, run_config=run_config, hparams=runner_hparams)
Example #13
Source File: tf_utils.py    From synvae with MIT License 6 votes vote down vote up
def merge_hparams(hparams_1, hparams_2):
  """Combines the hyperparameters of two tf.HParams objects.

  Whenever both objects define the same key, the value found in `hparams_2`
  wins.

  Args:
    hparams_1: First tf.HParams object; supplies the base values.
    hparams_2: Second tf.HParams object; its values take precedence.

  Returns:
    A new tf.HParams object containing the hyperparameters of both
    `hparams_1` and `hparams_2`.
  """
  merged = hparams_1.values()
  overriding = hparams_2.values()
  # Later entries replace earlier ones, giving `hparams_2` priority.
  merged.update(overriding)
  return tf.contrib.training.HParams(**merged)
Example #14
Source File: model_hparams.py    From vehicle_counting_tensorflow with MIT License 6 votes vote down vote up
def create_hparams(hparams_overrides=None):
  """Builds the default hyperparameters and applies optional overrides.

  Args:
    hparams_overrides: Optional string of comma-separated
      hparam_name=value pairs. When it is empty or None the defaults are
      returned untouched.

  Returns:
    A tf.HParams object holding the (possibly overridden) hyperparameters.
  """
  defaults = tf.contrib.training.HParams(
      # Controls whether the fine tuning checkpoint (provided in the
      # pipeline config) is loaded for training.
      load_pretrained=True)
  if not hparams_overrides:
    return defaults
  # Apply the overrides on top of the defaults declared above.
  return defaults.parse(hparams_overrides)
Example #15
Source File: model_hparams.py    From BMW-TensorFlow-Training-GUI with Apache License 2.0 6 votes vote down vote up
def create_hparams(hparams_overrides=None):
  """Creates the hyperparameter set, honouring flag-supplied overrides.

  Args:
    hparams_overrides: Optional overrides given as a string of
      comma-separated hparam_name=value pairs.

  Returns:
    The resulting hyperparameters as a tf.HParams object.
  """
  # Whether a fine tuning checkpoint (provided in the pipeline config)
  # should be loaded for training.
  result = tf.contrib.training.HParams(load_pretrained=True)
  # Only parse when overrides were actually supplied.
  return result.parse(hparams_overrides) if hparams_overrides else result
Example #16
Source File: model_hparams.py    From ros_tensorflow with Apache License 2.0 6 votes vote down vote up
def create_hparams(hparams_overrides=None):
  """Builds the default hyperparameters and applies optional overrides.

  Args:
    hparams_overrides: Optional string of comma-separated
      hparam_name=value pairs. When it is empty or None the defaults are
      returned untouched.

  Returns:
    A tf.HParams object holding the (possibly overridden) hyperparameters.
  """
  defaults = tf.contrib.training.HParams(
      # Controls whether the fine tuning checkpoint (provided in the
      # pipeline config) is loaded for training.
      load_pretrained=True)
  if not hparams_overrides:
    return defaults
  # Apply the overrides on top of the defaults declared above.
  return defaults.parse(hparams_overrides)
Example #17
Source File: cifar10_main.py    From yolo_v2 with Apache License 2.0 6 votes vote down vote up
def main(job_dir, data_dir, num_gpus, variable_strategy,
         use_distortion_for_training, log_device_placement, num_intra_threads,
         **hparams):
  """Prepares environment and session settings, then starts learn_runner.

  Args:
    job_dir: Handed to the run config as `model_dir`.
    data_dir: Passed on to `get_experiment_fn`.
    num_gpus: Passed on to `get_experiment_fn`.
    variable_strategy: Passed on to `get_experiment_fn`.
    use_distortion_for_training: Passed on to `get_experiment_fn`.
    log_device_placement: Session config value of the same name.
    num_intra_threads: Session config value for
      `intra_op_parallelism_threads`.
    **hparams: Additional keyword arguments, wrapped into the HParams object
      given to the runner together with `is_chief`.
  """
  # The env variable is on deprecation path, default is set to off.
  os.environ['TF_SYNC_ON_FINISH'] = '0'
  os.environ['TF_ENABLE_WINOGRAD_NONFUSED'] = '1'

  # Session configuration.
  session_kwargs = dict(
      allow_soft_placement=True,
      log_device_placement=log_device_placement,
      intra_op_parallelism_threads=num_intra_threads,
      gpu_options=tf.GPUOptions(force_gpu_compatible=True))
  sess_config = tf.ConfigProto(**session_kwargs)

  config = cifar10_utils.RunConfig(
      session_config=sess_config, model_dir=job_dir)
  experiment_fn = get_experiment_fn(
      data_dir, num_gpus, variable_strategy, use_distortion_for_training)
  run_hparams = tf.contrib.training.HParams(
      is_chief=config.is_chief, **hparams)
  tf.contrib.learn.learn_runner.run(
      experiment_fn, run_config=config, hparams=run_hparams)
Example #18
Source File: model_hparams.py    From ros_people_object_detection_tensorflow with Apache License 2.0 6 votes vote down vote up
def create_hparams(hparams_overrides=None):
  """Creates the hyperparameter set, honouring flag-supplied overrides.

  Args:
    hparams_overrides: Optional overrides given as a string of
      comma-separated hparam_name=value pairs.

  Returns:
    The resulting hyperparameters as a tf.HParams object.
  """
  # Whether a fine tuning checkpoint (provided in the pipeline config)
  # should be loaded for training.
  result = tf.contrib.training.HParams(load_pretrained=True)
  # Only parse when overrides were actually supplied.
  return result.parse(hparams_overrides) if hparams_overrides else result
Example #19
Source File: model_hparams.py    From Person-Detection-and-Tracking with MIT License 6 votes vote down vote up
def create_hparams(hparams_overrides=None):
  """Builds the default hyperparameters and applies optional overrides.

  Args:
    hparams_overrides: Optional string of comma-separated
      hparam_name=value pairs. When it is empty or None the defaults are
      returned untouched.

  Returns:
    A tf.HParams object holding the (possibly overridden) hyperparameters.
  """
  defaults = tf.contrib.training.HParams(
      # Controls whether the fine tuning checkpoint (provided in the
      # pipeline config) is loaded for training.
      load_pretrained=True)
  if not hparams_overrides:
    return defaults
  # Apply the overrides on top of the defaults declared above.
  return defaults.parse(hparams_overrides)
Example #20
Source File: model_hparams.py    From Traffic-Rule-Violation-Detection-System with MIT License 6 votes vote down vote up
def create_hparams(hparams_overrides=None):
  """Creates the hyperparameter set, honouring flag-supplied overrides.

  Args:
    hparams_overrides: Optional overrides given as a string of
      comma-separated hparam_name=value pairs.

  Returns:
    The resulting hyperparameters as a tf.HParams object.
  """
  # Whether a fine tuning checkpoint (provided in the pipeline config)
  # should be loaded for training.
  result = tf.contrib.training.HParams(load_pretrained=True)
  # Only parse when overrides were actually supplied.
  return result.parse(hparams_overrides) if hparams_overrides else result
Example #21
Source File: model_hparams.py    From Gun-Detector with Apache License 2.0 6 votes vote down vote up
def create_hparams(hparams_overrides=None):
  """Builds the default hyperparameters and applies optional overrides.

  Args:
    hparams_overrides: Optional string of comma-separated
      hparam_name=value pairs. When it is empty or None the defaults are
      returned untouched.

  Returns:
    A tf.HParams object holding the (possibly overridden) hyperparameters.
  """
  defaults = tf.contrib.training.HParams(
      # Controls whether the fine tuning checkpoint (provided in the
      # pipeline config) is loaded for training.
      load_pretrained=True)
  if not hparams_overrides:
    return defaults
  # Apply the overrides on top of the defaults declared above.
  return defaults.parse(hparams_overrides)
Example #22
Source File: cifar10_main.py    From Gun-Detector with Apache License 2.0 6 votes vote down vote up
def main(job_dir, data_dir, num_gpus, variable_strategy,
         use_distortion_for_training, log_device_placement, num_intra_threads,
         **hparams):
  """Sets up the TF session config and runs the experiment via learn_runner.

  Args:
    job_dir: Used as `model_dir` of the run config.
    data_dir: Forwarded to `get_experiment_fn`.
    num_gpus: Forwarded to `get_experiment_fn`.
    variable_strategy: Forwarded to `get_experiment_fn`.
    use_distortion_for_training: Forwarded to `get_experiment_fn`.
    log_device_placement: Passed to the session config under the same name.
    num_intra_threads: Passed to the session config as
      `intra_op_parallelism_threads`.
    **hparams: Extra keyword arguments; they become fields of the HParams
      object given to the runner.
  """
  # TF_SYNC_ON_FINISH is on a deprecation path; the default is off.
  for env_name, env_value in (('TF_SYNC_ON_FINISH', '0'),
                              ('TF_ENABLE_WINOGRAD_NONFUSED', '1')):
    os.environ[env_name] = env_value

  # Session configuration.
  gpu_opts = tf.GPUOptions(force_gpu_compatible=True)
  session_settings = tf.ConfigProto(
      allow_soft_placement=True,
      log_device_placement=log_device_placement,
      intra_op_parallelism_threads=num_intra_threads,
      gpu_options=gpu_opts)

  run_config = cifar10_utils.RunConfig(
      session_config=session_settings, model_dir=job_dir)
  experiment_fn = get_experiment_fn(
      data_dir, num_gpus, variable_strategy, use_distortion_for_training)
  runner_hparams = tf.contrib.training.HParams(
      is_chief=run_config.is_chief, **hparams)
  tf.contrib.learn.learn_runner.run(
      experiment_fn, run_config=run_config, hparams=runner_hparams)
Example #23
Source File: mpnn.py    From mpnn with Apache License 2.0 5 votes vote down vote up
def __init__(self, hparams):
    """GRU update function used in GG-NN.

    Implements h_v^{t+1} = GRU(h_v^t, m_v^{t+1}).

    Args:
      hparams (tf.HParams object): only relevant hparam is node_dim which is the
        dimension of the node states.
    """
    # Hand the same hparams to the parent constructor first.
    super(GRUUpdate, self).__init__(hparams)
    # Keep the node-state dimension as an attribute of the instance.
    self.node_dim = hparams.node_dim

    # NOTE(review): init_fprop is defined outside this snippet; presumably it
    # builds the variables used by fprop and relies on node_dim being set
    # above -- confirm against its definition.
    self.init_fprop()
Example #24
Source File: amoeba_net_model.py    From tpu_models with Apache License 2.0 5 votes vote down vote up
def formatted_hparams(hparams):
  """Formats the hparams into a readable string.

  Also looks for attributes that have not correctly been added to the hparams
  and prints the keys as "bad keys". These bad keys may be left out of
  iterators and circumvent type checking.

  Args:
    hparams: an HParams instance. Only its `values()` method (returning a
      dict of the registered hparams) and its instance `__dict__` are used.

  Returns:
    A string with one "name: value" line per registered hparam, sorted by
    name, followed by a final "Bad keys: ..." line listing the comma-joined
    bad keys (empty when there are none).
  """
  registered = hparams.values()

  # Look for bad keys (see docstring): public instance attributes that are
  # not among the registered hparams. Names starting with '_' are skipped.
  bad_keys = sorted(
      key for key in hparams.__dict__
      if key not in registered and not key.startswith('_'))

  # Format hparams. `items()` is used instead of `iteritems()` because the
  # latter exists only on Python 2 dicts and raises AttributeError on
  # Python 3; `items()` works on both.
  readable_items = [
      '%s: %s' % (k, v) for k, v in sorted(registered.items())]
  readable_items.append('Bad keys: %s' % ','.join(bad_keys))
  return '\n'.join(readable_items)
Example #25
Source File: multi_problem_v2.py    From BERT with Apache License 2.0 5 votes vote down vote up
def dataset(self, mode, hparams=None, global_step=None, **kwargs):
    """Builds one dataset that draws examples from all of the problems.

    Args:
      mode: A member of problem.DatasetSplit.
      hparams: A tf.HParams object, the model hparams.
      global_step: Scalar tensor from which the sampling distribution is
        computed. When it is None, tf.train.get_or_create_global_step is
        called instead (training mode only).
      **kwargs: Keywords for problem.Problem.Dataset.

    Returns:
      A dataset containing examples from multiple problems.
    """
    per_problem = [prob.dataset(mode, **kwargs) for prob in self.problems]

    # Tag examples with a problem_id. The index is bound through a default
    # argument so that every lambda keeps its own value.
    tagged = []
    for index, source in enumerate(per_problem):
      tagged.append(source.map(
          lambda x, i=index: self.normalize_example(  # pylint: disable=g-long-lambda
              dict(x, problem_id=tf.constant([i])), hparams)))

    if mode is problem.DatasetSplit.TRAIN:
      step = global_step
      if step is None:
        step = tf.train.get_or_create_global_step()
      pmf = get_schedule_distribution(self.schedule, step)
      return get_multi_dataset(tagged, pmf)

    if self.only_eval_first_problem:
      return tagged[0]

    # Remaining case: repeat every dataset, zip them together and flatten
    # each zipped tuple into a run of single-example datasets.
    repeated = tuple(ds.repeat() for ds in tagged)
    zipped = tf.data.Dataset.zip(repeated)
    return zipped.flat_map(
        lambda *x: functools.reduce(  # pylint: disable=g-long-lambda
            tf.data.Dataset.concatenate,
            map(tf.data.Dataset.from_tensors, x)))
Example #26
Source File: shake_shake.py    From multilabel-image-classification-tensorflow with MIT License 5 votes vote down vote up
def build_shake_shake_model(images, num_classes, hparams, is_training):
  """Constructs the Shake-Shake network and returns its logits.

  The architecture follows https://arxiv.org/abs/1705.07485.

  Args:
    images: Tensor of images that will be fed into the Wide ResNet Model.
    num_classes: Number of classes that the model needs to predict.
    hparams: tf.HParams object with the extra hparams needed to construct
      the model. Only `shake_shake_widen_factor` is read here; it determines
      how many filters the model has.
    is_training: Is the model training or not.

  Returns:
    The logits of the Shake-Shake model.
  """
  depth = 26
  widen = hparams.shake_shake_widen_factor  # The widen factor
  # NOTE(review): presumably the number of blocks per stage -- confirm
  # against _shake_shake_layer.
  blocks = int((depth - 2) / 6)

  net = ops.conv2d(images, 16, 3, scope='init_conv')
  net = ops.batch_norm(net, scope='init_bn')

  # (variable scope, width before widening, fourth positional argument of
  # _shake_shake_layer -- presumably the stride; confirm).
  stages = (('L1', 16, 1), ('L2', 32, 2), ('L3', 64, 2))
  for scope_name, base_width, stride in stages:
    with tf.variable_scope(scope_name):
      net = _shake_shake_layer(
          net, base_width * widen, blocks, stride, is_training)

  net = tf.nn.relu(net)
  net = ops.global_avg_pool(net)

  # Fully connected
  return ops.fc(net, num_classes)
Example #27
Source File: shake_shake.py    From models with Apache License 2.0 5 votes vote down vote up
def build_shake_shake_model(images, num_classes, hparams, is_training):
  """Constructs the Shake-Shake network and returns its logits.

  The architecture follows https://arxiv.org/abs/1705.07485.

  Args:
    images: Tensor of images that will be fed into the Wide ResNet Model.
    num_classes: Number of classes that the model needs to predict.
    hparams: tf.HParams object with the extra hparams needed to construct
      the model. Only `shake_shake_widen_factor` is read here; it determines
      how many filters the model has.
    is_training: Is the model training or not.

  Returns:
    The logits of the Shake-Shake model.
  """
  depth = 26
  widen = hparams.shake_shake_widen_factor  # The widen factor
  # NOTE(review): presumably the number of blocks per stage -- confirm
  # against _shake_shake_layer.
  blocks = int((depth - 2) / 6)

  net = ops.conv2d(images, 16, 3, scope='init_conv')
  net = ops.batch_norm(net, scope='init_bn')

  # (variable scope, width before widening, fourth positional argument of
  # _shake_shake_layer -- presumably the stride; confirm).
  stages = (('L1', 16, 1), ('L2', 32, 2), ('L3', 64, 2))
  for scope_name, base_width, stride in stages:
    with tf.variable_scope(scope_name):
      net = _shake_shake_layer(
          net, base_width * widen, blocks, stride, is_training)

  net = tf.nn.relu(net)
  net = ops.global_avg_pool(net)

  # Fully connected
  return ops.fc(net, num_classes)
Example #28
Source File: mpnn.py    From mpnn with Apache License 2.0 5 votes vote down vote up
def __init__(self, hparams):
    """Constructor.

    Stores the given hyperparameters on the instance as `self.hparams`.

    Args:
      hparams: tf.HParams object.
    """
    self.hparams = hparams
Example #29
Source File: mpnn.py    From mpnn with Apache License 2.0 5 votes vote down vote up
def __init__(self, hparams):
    """Build all of the variables.

    Passes `hparams` to the parent constructor, stores `hparams.node_dim` on
    the instance and then calls `self.init_fprop()`.

    Args:
      hparams: tf.HParams object, only node_dim is relevant to this function.
    """
    super(GGNNMsgPass, self).__init__(hparams)
    # Keep the node dimension as an attribute of the instance.
    self.node_dim = hparams.node_dim
    # TODO(gilmer): sub class should just call
    # super(self, GGNNMsgPass).__init__()
    # NOTE: init_fprop will set two member variables of the class, a_in
    # and a_out, these will be overwritten the first time fprop is called.
    self.init_fprop()
Example #30
Source File: mpnn.py    From mpnn with Apache License 2.0 5 votes vote down vote up
def __init__(self, hparams):
    """Build all of the variables.

    Passes `hparams` to the parent constructor and then calls
    `self.init_fprop()`.

    Args:
      hparams: tf.HParams object, only node_dim is relevant to this function.
    """
    super(EdgeNetwork, self).__init__(hparams)
    # NOTE(review): init_fprop is defined outside this snippet; presumably it
    # creates the variables mentioned in the docstring -- confirm.
    self.init_fprop()