Python absl.flags.multi_flags_validator() Examples

The following are 15 code examples of absl.flags.multi_flags_validator(). You can vote up the ones you like or vote down the ones you don't like, and go to the original project or source file named above each example. You may also want to check out all available functions/classes of the module absl.flags, or try the search function.
Example #1
Source file: _device.py (from the models project, Apache License 2.0; 6 votes)
def require_cloud_storage(flag_names):
  """Register a validator requiring GCS paths for the given flags on TPU.

  When the ``tpu`` flag is set (not None), every flag listed in
  ``flag_names`` must hold a string starting with ``gs://``. Each offending
  flag is logged, and flag validation fails with a combined message. When
  ``tpu`` is None, the check always passes.

  Args:
    flag_names: An iterable of strings containing the names of flags to be
      checked. It is copied into a list once, so tuples and generators are
      accepted, and later changes to the caller's container do not affect the
      registered validator.
  """
  # Materialize once: a generator would be exhausted by the join below, and
  # `["tpu"] + flag_names` raises TypeError for anything but a list.
  flag_names = list(flag_names)
  msg = "TPU requires GCS path for {}".format(", ".join(flag_names))

  @flags.multi_flags_validator(["tpu"] + flag_names, message=msg)
  def _path_check(flag_values):
    """Return False if --tpu is set and any listed flag is not a gs:// path."""
    if flag_values["tpu"] is None:
      return True

    valid_flags = True
    for key in flag_names:
      value = flag_values[key]
      # An unset flag (None) is not a GCS path either; report it as invalid
      # rather than crashing with AttributeError on None.startswith.
      if value is None or not value.startswith("gs://"):
        tf.compat.v1.logging.error("{} must be a GCS path.".format(key))
        valid_flags = False

    return valid_flags
Example #2
Source file: _device.py (from the nsfw project, Apache License 2.0; 6 votes)
def require_cloud_storage(flag_names):
  """Register a validator requiring GCS paths for the given flags on TPU.

  When the ``tpu`` flag is set (not None), every flag listed in
  ``flag_names`` must hold a string starting with ``gs://``. Each offending
  flag is logged, and flag validation fails with a combined message. When
  ``tpu`` is None, the check always passes.

  Args:
    flag_names: An iterable of strings containing the names of flags to be
      checked. It is copied into a list once, so tuples and generators are
      accepted, and later changes to the caller's container do not affect the
      registered validator.
  """
  # Materialize once: a generator would be exhausted by the join below, and
  # `["tpu"] + flag_names` raises TypeError for anything but a list.
  flag_names = list(flag_names)
  msg = "TPU requires GCS path for {}".format(", ".join(flag_names))

  @flags.multi_flags_validator(["tpu"] + flag_names, message=msg)
  def _path_check(flag_values):
    """Return False if --tpu is set and any listed flag is not a gs:// path."""
    if flag_values["tpu"] is None:
      return True

    valid_flags = True
    for key in flag_names:
      value = flag_values[key]
      # An unset flag (None) is not a GCS path either; report it as invalid
      # rather than crashing with AttributeError on None.startswith.
      if value is None or not value.startswith("gs://"):
        tf.logging.error("{} must be a GCS path.".format(key))
        valid_flags = False

    return valid_flags
Example #3
Source file: _device.py (from the ml-on-gcp project, Apache License 2.0; 6 votes)
def require_cloud_storage(flag_names):
  """Register a validator requiring GCS paths for the given flags on TPU.

  When the ``tpu`` flag is set (not None), every flag listed in
  ``flag_names`` must hold a string starting with ``gs://``. Each offending
  flag is logged, and flag validation fails with a combined message. When
  ``tpu`` is None, the check always passes.

  Args:
    flag_names: An iterable of strings containing the names of flags to be
      checked. It is copied into a list once, so tuples and generators are
      accepted, and later changes to the caller's container do not affect the
      registered validator.
  """
  # Materialize once: a generator would be exhausted by the join below, and
  # `["tpu"] + flag_names` raises TypeError for anything but a list.
  flag_names = list(flag_names)
  msg = "TPU requires GCS path for {}".format(", ".join(flag_names))

  @flags.multi_flags_validator(["tpu"] + flag_names, message=msg)
  def _path_check(flag_values):
    """Return False if --tpu is set and any listed flag is not a gs:// path."""
    if flag_values["tpu"] is None:
      return True

    valid_flags = True
    for key in flag_names:
      value = flag_values[key]
      # An unset flag (None) is not a GCS path either; report it as invalid
      # rather than crashing with AttributeError on None.startswith.
      if value is None or not value.startswith("gs://"):
        tf.logging.error("{} must be a GCS path.".format(key))
        valid_flags = False

    return valid_flags
Example #4
Source file: _device.py (from the ml-on-gcp project, Apache License 2.0; 6 votes)
def require_cloud_storage(flag_names):
  """Register a validator requiring GCS paths for the given flags on TPU.

  When the ``tpu`` flag is set (not None), every flag listed in
  ``flag_names`` must hold a string starting with ``gs://``. Each offending
  flag is logged, and flag validation fails with a combined message. When
  ``tpu`` is None, the check always passes.

  Args:
    flag_names: An iterable of strings containing the names of flags to be
      checked. It is copied into a list once, so tuples and generators are
      accepted, and later changes to the caller's container do not affect the
      registered validator.
  """
  # Materialize once: a generator would be exhausted by the join below, and
  # `["tpu"] + flag_names` raises TypeError for anything but a list.
  flag_names = list(flag_names)
  msg = "TPU requires GCS path for {}".format(", ".join(flag_names))

  @flags.multi_flags_validator(["tpu"] + flag_names, message=msg)
  def _path_check(flag_values):
    """Return False if --tpu is set and any listed flag is not a gs:// path."""
    if flag_values["tpu"] is None:
      return True

    valid_flags = True
    for key in flag_names:
      value = flag_values[key]
      # An unset flag (None) is not a GCS path either; report it as invalid
      # rather than crashing with AttributeError on None.startswith.
      if value is None or not value.startswith("gs://"):
        tf.logging.error("{} must be a GCS path.".format(key))
        valid_flags = False

    return valid_flags
Example #5
Source file: _device.py (from the ml-on-gcp project, Apache License 2.0; 6 votes)
def require_cloud_storage(flag_names):
  """Register a validator requiring GCS paths for the given flags on TPU.

  When the ``tpu`` flag is set (not None), every flag listed in
  ``flag_names`` must hold a string starting with ``gs://``. Each offending
  flag is logged, and flag validation fails with a combined message. When
  ``tpu`` is None, the check always passes.

  Args:
    flag_names: An iterable of strings containing the names of flags to be
      checked. It is copied into a list once, so tuples and generators are
      accepted, and later changes to the caller's container do not affect the
      registered validator.
  """
  # Materialize once: a generator would be exhausted by the join below, and
  # `["tpu"] + flag_names` raises TypeError for anything but a list.
  flag_names = list(flag_names)
  msg = "TPU requires GCS path for {}".format(", ".join(flag_names))

  @flags.multi_flags_validator(["tpu"] + flag_names, message=msg)
  def _path_check(flag_values):
    """Return False if --tpu is set and any listed flag is not a gs:// path."""
    if flag_values["tpu"] is None:
      return True

    valid_flags = True
    for key in flag_names:
      value = flag_values[key]
      # An unset flag (None) is not a GCS path either; report it as invalid
      # rather than crashing with AttributeError on None.startswith.
      if value is None or not value.startswith("gs://"):
        tf.logging.error("{} must be a GCS path.".format(key))
        valid_flags = False

    return valid_flags
Example #6
Source file: _device.py (from the ml-on-gcp project, Apache License 2.0; 6 votes)
def require_cloud_storage(flag_names):
  """Register a validator requiring GCS paths for the given flags on TPU.

  When the ``tpu`` flag is set (not None), every flag listed in
  ``flag_names`` must hold a string starting with ``gs://``. Each offending
  flag is logged, and flag validation fails with a combined message. When
  ``tpu`` is None, the check always passes.

  Args:
    flag_names: An iterable of strings containing the names of flags to be
      checked. It is copied into a list once, so tuples and generators are
      accepted, and later changes to the caller's container do not affect the
      registered validator.
  """
  # Materialize once: a generator would be exhausted by the join below, and
  # `["tpu"] + flag_names` raises TypeError for anything but a list.
  flag_names = list(flag_names)
  msg = "TPU requires GCS path for {}".format(", ".join(flag_names))

  @flags.multi_flags_validator(["tpu"] + flag_names, message=msg)
  def _path_check(flag_values):
    """Return False if --tpu is set and any listed flag is not a gs:// path."""
    if flag_values["tpu"] is None:
      return True

    valid_flags = True
    for key in flag_names:
      value = flag_values[key]
      # An unset flag (None) is not a GCS path either; report it as invalid
      # rather than crashing with AttributeError on None.startswith.
      if value is None or not value.startswith("gs://"):
        tf.logging.error("{} must be a GCS path.".format(key))
        valid_flags = False

    return valid_flags
Example #7
Source file: _device.py (from the models project, Apache License 2.0; 6 votes)
def require_cloud_storage(flag_names):
  """Register a validator requiring GCS paths for the given flags on TPU.

  When the ``tpu`` flag is set (not None), every flag listed in
  ``flag_names`` must hold a string starting with ``gs://``. Each offending
  flag is logged, and flag validation fails with a combined message. When
  ``tpu`` is None, the check always passes.

  Args:
    flag_names: An iterable of strings containing the names of flags to be
      checked. It is copied into a list once, so tuples and generators are
      accepted, and later changes to the caller's container do not affect the
      registered validator.
  """
  # Materialize once: a generator would be exhausted by the join below, and
  # `["tpu"] + flag_names` raises TypeError for anything but a list.
  flag_names = list(flag_names)
  msg = "TPU requires GCS path for {}".format(", ".join(flag_names))

  @flags.multi_flags_validator(["tpu"] + flag_names, message=msg)
  def _path_check(flag_values):
    """Return False if --tpu is set and any listed flag is not a gs:// path."""
    if flag_values["tpu"] is None:
      return True

    valid_flags = True
    for key in flag_names:
      value = flag_values[key]
      # An unset flag (None) is not a GCS path either; report it as invalid
      # rather than crashing with AttributeError on None.startswith.
      if value is None or not value.startswith("gs://"):
        tf.logging.error("{} must be a GCS path.".format(key))
        valid_flags = False

    return valid_flags
Example #8
Source file: _device.py (from the models project, Apache License 2.0; 6 votes)
def require_cloud_storage(flag_names):
  """Register a validator requiring GCS paths for the given flags on TPU.

  When the ``tpu`` flag is set (not None), every flag listed in
  ``flag_names`` must hold a string starting with ``gs://``. Each offending
  flag is logged, and flag validation fails with a combined message. When
  ``tpu`` is None, the check always passes.

  Args:
    flag_names: An iterable of strings containing the names of flags to be
      checked. It is copied into a list once, so tuples and generators are
      accepted, and later changes to the caller's container do not affect the
      registered validator.
  """
  # Materialize once: a generator would be exhausted by the join below, and
  # `["tpu"] + flag_names` raises TypeError for anything but a list.
  flag_names = list(flag_names)
  msg = "TPU requires GCS path for {}".format(", ".join(flag_names))

  @flags.multi_flags_validator(["tpu"] + flag_names, message=msg)
  def _path_check(flag_values):
    """Return False if --tpu is set and any listed flag is not a gs:// path."""
    if flag_values["tpu"] is None:
      return True

    valid_flags = True
    for key in flag_names:
      value = flag_values[key]
      # An unset flag (None) is not a GCS path either; report it as invalid
      # rather than crashing with AttributeError on None.startswith.
      if value is None or not value.startswith("gs://"):
        tf.logging.error("{} must be a GCS path.".format(key))
        valid_flags = False

    return valid_flags
Example #9
Source file: _device.py (from the models project, Apache License 2.0; 6 votes)
def require_cloud_storage(flag_names):
  """Register a validator requiring GCS paths for the given flags on TPU.

  When the ``tpu`` flag is set (not None), every flag listed in
  ``flag_names`` must hold a string starting with ``gs://``. Each offending
  flag is logged, and flag validation fails with a combined message. When
  ``tpu`` is None, the check always passes.

  Args:
    flag_names: An iterable of strings containing the names of flags to be
      checked. It is copied into a list once, so tuples and generators are
      accepted, and later changes to the caller's container do not affect the
      registered validator.
  """
  # Materialize once: a generator would be exhausted by the join below, and
  # `["tpu"] + flag_names` raises TypeError for anything but a list.
  flag_names = list(flag_names)
  msg = "TPU requires GCS path for {}".format(", ".join(flag_names))

  @flags.multi_flags_validator(["tpu"] + flag_names, message=msg)
  def _path_check(flag_values):
    """Return False if --tpu is set and any listed flag is not a gs:// path."""
    if flag_values["tpu"] is None:
      return True

    valid_flags = True
    for key in flag_names:
      value = flag_values[key]
      # An unset flag (None) is not a GCS path either; report it as invalid
      # rather than crashing with AttributeError on None.startswith.
      if value is None or not value.startswith("gs://"):
        tf.logging.error("{} must be a GCS path.".format(key))
        valid_flags = False

    return valid_flags
Example #10
Source file: _device.py (from the Live-feed-object-device-identification-using-Tensorflow-and-OpenCV project, Apache License 2.0; 6 votes)
def require_cloud_storage(flag_names):
  """Register a validator requiring GCS paths for the given flags on TPU.

  When the ``tpu`` flag is set (not None), every flag listed in
  ``flag_names`` must hold a string starting with ``gs://``. Each offending
  flag is logged, and flag validation fails with a combined message. When
  ``tpu`` is None, the check always passes.

  Args:
    flag_names: An iterable of strings containing the names of flags to be
      checked. It is copied into a list once, so tuples and generators are
      accepted, and later changes to the caller's container do not affect the
      registered validator.
  """
  # Materialize once: a generator would be exhausted by the join below, and
  # `["tpu"] + flag_names` raises TypeError for anything but a list.
  flag_names = list(flag_names)
  msg = "TPU requires GCS path for {}".format(", ".join(flag_names))

  @flags.multi_flags_validator(["tpu"] + flag_names, message=msg)
  def _path_check(flag_values):
    """Return False if --tpu is set and any listed flag is not a gs:// path."""
    if flag_values["tpu"] is None:
      return True

    valid_flags = True
    for key in flag_names:
      value = flag_values[key]
      # An unset flag (None) is not a GCS path either; report it as invalid
      # rather than crashing with AttributeError on None.startswith.
      if value is None or not value.startswith("gs://"):
        tf.compat.v1.logging.error("{} must be a GCS path.".format(key))
        valid_flags = False

    return valid_flags
Example #11
Source file: _device.py (from the g-tensorflow-models project, Apache License 2.0; 6 votes)
def require_cloud_storage(flag_names):
  """Register a validator requiring GCS paths for the given flags on TPU.

  When the ``tpu`` flag is set (not None), every flag listed in
  ``flag_names`` must hold a string starting with ``gs://``. Each offending
  flag is logged, and flag validation fails with a combined message. When
  ``tpu`` is None, the check always passes.

  Args:
    flag_names: An iterable of strings containing the names of flags to be
      checked. It is copied into a list once, so tuples and generators are
      accepted, and later changes to the caller's container do not affect the
      registered validator.
  """
  # Materialize once: a generator would be exhausted by the join below, and
  # `["tpu"] + flag_names` raises TypeError for anything but a list.
  flag_names = list(flag_names)
  msg = "TPU requires GCS path for {}".format(", ".join(flag_names))

  @flags.multi_flags_validator(["tpu"] + flag_names, message=msg)
  def _path_check(flag_values):
    """Return False if --tpu is set and any listed flag is not a gs:// path."""
    if flag_values["tpu"] is None:
      return True

    valid_flags = True
    for key in flag_names:
      value = flag_values[key]
      # An unset flag (None) is not a GCS path either; report it as invalid
      # rather than crashing with AttributeError on None.startswith.
      if value is None or not value.startswith("gs://"):
        tf.logging.error("{} must be a GCS path.".format(key))
        valid_flags = False

    return valid_flags
Example #12
Source file: _device.py (from the models project, Apache License 2.0; 6 votes)
def require_cloud_storage(flag_names):
  """Register a validator requiring GCS paths for the given flags on TPU.

  When the ``tpu`` flag is set (not None), every flag listed in
  ``flag_names`` must hold a string starting with ``gs://``. Each offending
  flag is logged, and flag validation fails with a combined message. When
  ``tpu`` is None, the check always passes.

  Args:
    flag_names: An iterable of strings containing the names of flags to be
      checked. It is copied into a list once, so tuples and generators are
      accepted, and later changes to the caller's container do not affect the
      registered validator.
  """
  # Materialize once: a generator would be exhausted by the join below, and
  # `["tpu"] + flag_names` raises TypeError for anything but a list.
  flag_names = list(flag_names)
  msg = "TPU requires GCS path for {}".format(", ".join(flag_names))

  @flags.multi_flags_validator(["tpu"] + flag_names, message=msg)
  def _path_check(flag_values):
    """Return False if --tpu is set and any listed flag is not a gs:// path."""
    if flag_values["tpu"] is None:
      return True

    valid_flags = True
    for key in flag_names:
      value = flag_values[key]
      # An unset flag (None) is not a GCS path either; report it as invalid
      # rather than crashing with AttributeError on None.startswith.
      if value is None or not value.startswith("gs://"):
        logging.error("%s must be a GCS path.", key)
        valid_flags = False

    return valid_flags
Example #13
Source file: _device.py (from the multilabel-image-classification-tensorflow project, MIT License; 6 votes)
def require_cloud_storage(flag_names):
  """Register a validator requiring GCS paths for the given flags on TPU.

  When the ``tpu`` flag is set (not None), every flag listed in
  ``flag_names`` must hold a string starting with ``gs://``. Each offending
  flag is logged, and flag validation fails with a combined message. When
  ``tpu`` is None, the check always passes.

  Args:
    flag_names: An iterable of strings containing the names of flags to be
      checked. It is copied into a list once, so tuples and generators are
      accepted, and later changes to the caller's container do not affect the
      registered validator.
  """
  # Materialize once: a generator would be exhausted by the join below, and
  # `["tpu"] + flag_names` raises TypeError for anything but a list.
  flag_names = list(flag_names)
  msg = "TPU requires GCS path for {}".format(", ".join(flag_names))

  @flags.multi_flags_validator(["tpu"] + flag_names, message=msg)
  def _path_check(flag_values):
    """Return False if --tpu is set and any listed flag is not a gs:// path."""
    if flag_values["tpu"] is None:
      return True

    valid_flags = True
    for key in flag_names:
      value = flag_values[key]
      # An unset flag (None) is not a GCS path either; report it as invalid
      # rather than crashing with AttributeError on None.startswith.
      if value is None or not value.startswith("gs://"):
        tf.logging.error("{} must be a GCS path.".format(key))
        valid_flags = False

    return valid_flags
Example #14
Source file: benchmark_main.py (from the g-tensorflow-models project, Apache License 2.0; 4 votes)
def define_keras_benchmark_flags():
  """Define the command-line flags for the Keras application benchmarks.

  This function:

  * registers the shared ``flags_core`` groups (base without hooks,
    performance, image, benchmark) and adopts their key flags;
  * overrides several of their defaults;
  * defines the benchmark-specific flags ``--model``, ``--num_train_images``,
    ``--num_eval_images``, ``--eager``, ``--dist_strat`` and ``--callbacks``;
  * registers a validator that rejects ``--eager`` together with
    ``--dist_strat``.
  """
  # Shared flag groups provided by flags_core.
  flags_core.define_base(hooks=False)
  flags_core.define_performance()
  flags_core.define_image()
  flags_core.define_benchmark()
  flags.adopt_module_key_flags(flags_core)

  benchmark_defaults = dict(
      data_format="channels_last",
      use_synthetic_data=True,
      batch_size=32,
      train_epochs=2,
  )
  flags_core.set_defaults(**benchmark_defaults)

  help_wrap = flags_core.help_wrap

  flags.DEFINE_enum(
      name="model",
      default=None,
      enum_values=MODELS.keys(),
      case_sensitive=False,
      help=help_wrap("Model to be benchmarked."))

  # (name, default, help) for the synthetic-data size flags.
  integer_specs = (
      ("num_train_images", 1000,
       "The number of synthetic images for training. The default value is "
       "1000."),
      ("num_eval_images", 50,
       "The number of synthetic images for evaluation. The default value is "
       "50."),
  )
  for spec_name, spec_default, spec_help in integer_specs:
    flags.DEFINE_integer(
        name=spec_name, default=spec_default, help=help_wrap(spec_help))

  # (name, help) for the execution-mode switches; both default to False.
  boolean_specs = (
      ("eager",
       "To enable eager execution. Note that if eager execution is enabled, "
       "only one GPU is utilized even if multiple GPUs are provided and "
       "multi_gpu_model is used."),
      ("dist_strat",
       "To enable distribution strategy for model training and evaluation. "
       "Number of GPUs used for distribution strategy can be set by the "
       "argument --num_gpus."),
  )
  for spec_name, spec_help in boolean_specs:
    flags.DEFINE_boolean(
        name=spec_name, default=False, help=help_wrap(spec_help))

  flags.DEFINE_list(
      name="callbacks",
      default=["ExamplesPerSecondCallback", "LoggingMetricCallback"],
      help=help_wrap(
          "A list of (case insensitive) strings to specify the names of "
          "callbacks. For example: `--callbacks ExamplesPerSecondCallback,"
          "LoggingMetricCallback`"))

  conflict_message = (
      "Both --eager and --dist_strat were set. Only one can be "
      "defined, as DistributionStrategy is not supported in Eager "
      "execution currently.")

  @flags.multi_flags_validator(
      ["eager", "dist_strat"], message=conflict_message)
  def _reject_eager_with_dist_strat(flag_dict):  # pylint: disable=unused-variable
    # Valid unless both switches are on (De Morgan form of "not both").
    return not flag_dict["eager"] or not flag_dict["dist_strat"]
Example #15
Source file: benchmark_main.py (from the multilabel-image-classification-tensorflow project, MIT License; 4 votes)
def define_keras_benchmark_flags():
  """Define the command-line flags for the Keras application benchmarks.

  This function:

  * registers the shared ``flags_core`` groups (base without hooks,
    performance, image, benchmark) and adopts their key flags;
  * overrides several of their defaults;
  * defines the benchmark-specific flags ``--model``, ``--num_train_images``,
    ``--num_eval_images``, ``--eager``, ``--dist_strat`` and ``--callbacks``;
  * registers a validator that rejects ``--eager`` together with
    ``--dist_strat``.
  """
  # Shared flag groups provided by flags_core.
  flags_core.define_base(hooks=False)
  flags_core.define_performance()
  flags_core.define_image()
  flags_core.define_benchmark()
  flags.adopt_module_key_flags(flags_core)

  benchmark_defaults = dict(
      data_format="channels_last",
      use_synthetic_data=True,
      batch_size=32,
      train_epochs=2,
  )
  flags_core.set_defaults(**benchmark_defaults)

  help_wrap = flags_core.help_wrap

  flags.DEFINE_enum(
      name="model",
      default=None,
      enum_values=MODELS.keys(),
      case_sensitive=False,
      help=help_wrap("Model to be benchmarked."))

  # (name, default, help) for the synthetic-data size flags.
  integer_specs = (
      ("num_train_images", 1000,
       "The number of synthetic images for training. The default value is "
       "1000."),
      ("num_eval_images", 50,
       "The number of synthetic images for evaluation. The default value is "
       "50."),
  )
  for spec_name, spec_default, spec_help in integer_specs:
    flags.DEFINE_integer(
        name=spec_name, default=spec_default, help=help_wrap(spec_help))

  # (name, help) for the execution-mode switches; both default to False.
  boolean_specs = (
      ("eager",
       "To enable eager execution. Note that if eager execution is enabled, "
       "only one GPU is utilized even if multiple GPUs are provided and "
       "multi_gpu_model is used."),
      ("dist_strat",
       "To enable distribution strategy for model training and evaluation. "
       "Number of GPUs used for distribution strategy can be set by the "
       "argument --num_gpus."),
  )
  for spec_name, spec_help in boolean_specs:
    flags.DEFINE_boolean(
        name=spec_name, default=False, help=help_wrap(spec_help))

  flags.DEFINE_list(
      name="callbacks",
      default=["ExamplesPerSecondCallback", "LoggingMetricCallback"],
      help=help_wrap(
          "A list of (case insensitive) strings to specify the names of "
          "callbacks. For example: `--callbacks ExamplesPerSecondCallback,"
          "LoggingMetricCallback`"))

  conflict_message = (
      "Both --eager and --dist_strat were set. Only one can be "
      "defined, as DistributionStrategy is not supported in Eager "
      "execution currently.")

  @flags.multi_flags_validator(
      ["eager", "dist_strat"], message=conflict_message)
  def _reject_eager_with_dist_strat(flag_dict):  # pylint: disable=unused-variable
    # Valid unless both switches are on (De Morgan form of "not both").
    return not flag_dict["eager"] or not flag_dict["dist_strat"]