Python absl.logging.warn() Examples

The following are 30 code examples of absl.logging.warn(). You can vote up the ones you like or vote down the ones you don't like, and go to the original project or source file by following the links above each example. You may also want to check out all available functions/classes of the module absl.logging, or try the search function.
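Before the project examples, here is a minimal, self-contained sketch of calling absl.logging.warn() from a script. The app.run(main) scaffolding and the message text are illustrative assumptions, not taken from any project below. absl.logging.warn() takes printf-style arguments like absl.logging.warning(), and newer absl-py releases treat warn() as a deprecated alias for warning().

from absl import app
from absl import logging


def main(argv):
  del argv  # Unused.
  # absl.logging defers printf-style formatting until the record is emitted.
  logging.warn('Loaded %d of %d expected checkpoints', 3, 5)
  # Equivalent call using the non-deprecated spelling.
  logging.warning('Loaded %d of %d expected checkpoints', 3, 5)


if __name__ == '__main__':
  app.run(main)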
Example #1
Source File: util.py    From models with Apache License 2.0 6 votes
def get_vars_to_restore(ckpt=None):
  """Returns list of variables that should be saved/restored.

  Args:
    ckpt: Path to existing checkpoint.  If present, returns only the subset of
        variables that exist in given checkpoint.

  Returns:
    List of all variables that need to be saved/restored.
  """
  model_vars = tf.trainable_variables()
  # Add batchnorm variables.
  bn_vars = [v for v in tf.global_variables()
             if 'moving_mean' in v.op.name or 'moving_variance' in v.op.name]
  model_vars.extend(bn_vars)
  model_vars = sorted(model_vars, key=lambda x: x.op.name)
  if ckpt is not None:
    ckpt_var_names = tf.contrib.framework.list_variables(ckpt)
    ckpt_var_names = [name for (name, unused_shape) in ckpt_var_names]
    for v in model_vars:
      if v.op.name not in ckpt_var_names:
        logging.warn('Missing var %s in checkpoint: %s', v.op.name,
                     os.path.basename(ckpt))
    model_vars = [v for v in model_vars if v.op.name in ckpt_var_names]
  return model_vars 
Example #2
Source File: util.py    From g-tensorflow-models with Apache License 2.0 6 votes
def get_vars_to_restore(ckpt=None):
  """Returns list of variables that should be saved/restored.

  Args:
    ckpt: Path to existing checkpoint.  If present, returns only the subset of
        variables that exist in given checkpoint.

  Returns:
    List of all variables that need to be saved/restored.
  """
  model_vars = tf.trainable_variables()
  # Add batchnorm variables.
  bn_vars = [v for v in tf.global_variables()
             if 'moving_mean' in v.op.name or 'moving_variance' in v.op.name]
  model_vars.extend(bn_vars)
  model_vars = sorted(model_vars, key=lambda x: x.op.name)
  if ckpt is not None:
    ckpt_var_names = tf.contrib.framework.list_variables(ckpt)
    ckpt_var_names = [name for (name, unused_shape) in ckpt_var_names]
    for v in model_vars:
      if v.op.name not in ckpt_var_names:
        logging.warn('Missing var %s in checkpoint: %s', v.op.name,
                     os.path.basename(ckpt))
    model_vars = [v for v in model_vars if v.op.name in ckpt_var_names]
  return model_vars 
Example #3
Source File: env_problem.py    From tensor2tensor with Apache License 2.0 6 votes
def num_rewards(self):
    """Returns the number of distinct rewards.

    Returns:
      Returns None if the reward range is infinite or the processed rewards
      aren't discrete, otherwise returns the number of distinct rewards.
    """

    # Pre-conditions: reward range is finite.
    #               : processed rewards are discrete.
    if not self.is_reward_range_finite:
      logging.warn("Infinite reward range, `num_rewards returning None`")
      return None
    if not self.is_processed_rewards_discrete:
      logging.warn(
          "Processed rewards are not discrete, `num_rewards` returning None")
      return None

    min_reward, max_reward = self.reward_range
    return max_reward - min_reward + 1 
Example #4
Source File: dataset_factory.py    From models with Apache License 2.0 6 votes
def build(self, strategy: tf.distribute.Strategy = None) -> tf.data.Dataset:
    """Construct a dataset end-to-end and return it using an optional strategy.

    Args:
      strategy: a strategy that, if passed, will distribute the dataset
        according to that strategy. If passed and `num_devices > 1`,
        `use_per_replica_batch_size` must be set to `True`.

    Returns:
      A TensorFlow dataset outputting batched images and labels.
    """
    if strategy:
      if strategy.num_replicas_in_sync != self.config.num_devices:
        logging.warn('Passed a strategy with %d devices, but expected '
                     '%d devices.',
                     strategy.num_replicas_in_sync,
                     self.config.num_devices)
      dataset = strategy.experimental_distribute_datasets_from_function(
          self._build)
    else:
      dataset = self._build()

    return dataset 
Example #5
Source File: engine.py    From dm_control with Apache License 2.0 6 votes
def check_invalid_state(self):
    """Checks whether the physics state is invalid at exit.

    Yields:
      None

    Raises:
      PhysicsError: if the simulation state is invalid at exit, unless this
        context is nested inside a `suppress_physics_errors` context, in which
        case a warning will be logged instead.
    """
    # `np.copyto(dst, src)` is marginally faster than `dst[:] = src`.
    np.copyto(self._warnings_before, self._warnings)
    yield
    np.greater(self._warnings, self._warnings_before, out=self._new_warnings)
    if any(self._new_warnings):
      warning_names = np.compress(self._new_warnings, enums.mjtWarning._fields)
      message = _INVALID_PHYSICS_STATE.format(
          warning_names=', '.join(warning_names))
      if self._warnings_cause_exception:
        raise _control.PhysicsError(message)
      else:
        logging.warn(message) 
Example #6
Source File: util.py    From multilabel-image-classification-tensorflow with MIT License 6 votes
def get_vars_to_restore(ckpt=None):
  """Returns list of variables that should be saved/restored.

  Args:
    ckpt: Path to existing checkpoint.  If present, returns only the subset of
        variables that exist in given checkpoint.

  Returns:
    List of all variables that need to be saved/restored.
  """
  model_vars = tf.trainable_variables()
  # Add batchnorm variables.
  bn_vars = [v for v in tf.global_variables()
             if 'moving_mean' in v.op.name or 'moving_variance' in v.op.name]
  model_vars.extend(bn_vars)
  model_vars = sorted(model_vars, key=lambda x: x.op.name)
  if ckpt is not None:
    ckpt_var_names = tf.contrib.framework.list_variables(ckpt)
    ckpt_var_names = [name for (name, unused_shape) in ckpt_var_names]
    for v in model_vars:
      if v.op.name not in ckpt_var_names:
        logging.warn('Missing var %s in checkpoint: %s', v.op.name,
                     os.path.basename(ckpt))
    model_vars = [v for v in model_vars if v.op.name in ckpt_var_names]
  return model_vars 
Example #7
Source File: device_model.py    From loaner with Apache License 2.0 6 votes
def disassociate_tag(self, user_email, tag_name):
    """Disassociates a tag from a device.

    Args:
      user_email: str, the email of the user taking the action.
      tag_name: str, the name of the tag to be disassociated.

    Raises:
      ValueError: If the tag requested to be disassociated from the device is
        not currently associated with the device.
    """

    for tag_reference in self.tags:
      if tag_reference.tag.name == tag_name:
        self.tags.remove(tag_reference)
        self.put()
        self.stream_to_bq(
            user_email, 'Removed tag %s from device %s' %
            (tag_reference.tag.name, self.identifier))
        return
    logging.warn(
        'Tag with name %s is not associated with device %s',
        tag_name, self.identifier) 
Example #8
Source File: dataset.py    From tfx with Apache License 2.0 6 votes
def generate_raw_dataset(self, args):
    logging.warn(
        "Not actually regenerating the raw dataset.\n"
        "To regenerate the raw CSV dataset, see the TFX Chicago Taxi example "
        "for details as to how to do so. "
        "tfx/examples/chicago_taxi_pipeline/taxi_pipeline_kubeflow_gcp.py "
        "has the BigQuery query used to generate the dataset.\n"
        "After regenerating the raw CSV dataset, you should also regenerate "
        "the derived TFRecords dataset. You can do so by passing "
        "--generate_dataset_args=/path/to/csv_dataset.csv to "
        "regenerate_datasets.py.")

    if args:
      logging.info("Converting CSV at %s to TFRecords", args)
      self.convert_csv_to_tf_examples(args, self.dataset_path())
      logging.info("TFRecords written to %s", self.dataset_path()) 
Example #9
Source File: qa4mre.py    From datasets with Apache License 2.0 5 votes
def __init__(self, year, track='main', language='EN', **kwargs):
    """BuilderConfig for Qa4Mre.

    Args:
      year: string, year of dataset
      track: string, the task track from PATHS[year]['_TRACKS'].
      language: string, Acronym for language in the main task.
      **kwargs: keyword arguments forwarded to super.
    """
    if track.lower() not in PATHS[year]['_TRACKS']:
      raise ValueError(
          'Incorrect track. Track should be one of the following: ',
          PATHS[year]['_TRACKS'])

    if track.lower() != 'main' and language.upper() != 'EN':
      logging.warn('Only English documents available for pilot '
                   'tracks. Setting English by default.')
      language = 'EN'

    if track.lower() == 'main' and language.upper(
    ) not in PATHS[year]['_LANGUAGES_MAIN']:
      raise ValueError(
          'Incorrect language for the main track. Correct options: ',
          PATHS[year]['_LANGUAGES_MAIN'])

    self.year = year
    self.track = track.lower()
    self.lang = language.upper()

    name = self.year + '.' + self.track + '.' + self.lang

    description = _DESCRIPTION
    description += ('This configuration includes the {} track for {} language '
                    'in {} year.').format(self.track, self.lang, self.year)

    super(Qa4mreConfig, self).__init__(
        name=name,
        description=description,
        version=tfds.core.Version('0.1.0'),
        **kwargs) 
Example #10
Source File: io.py    From compare-mt with BSD 3-Clause "New" or "Revised" License 5 votes
def _write_scores_to_csv(output_filename, scores):
  """Writes scores for each individual example to an output CSV file.

  Output file is a comma separated where each line has the format:
    id,score1,score2,score3,...

  The header row indicates the type of each score column.

  Args:
    output_filename: Name of file to write results to.
    scores: A list of dicts mapping each score_type to a Score object.
  """

  if len(scores) < 1:
    logging.warn("No scores to write")
    return
  rouge_types = sorted(scores[0].keys())

  logging.info("Writing results to %s.", output_filename)
  with _open(output_filename, "w") as out_file:
    out_file.write("id")
    for rouge_type in rouge_types:
      out_file.write(",{t}-P,{t}-R,{t}-F".format(t=rouge_type))
    out_file.write("\n")
    for i, result in enumerate(scores):
      out_file.write("%d" % i)
      for rouge_type in rouge_types:
        out_file.write(",%f,%f,%f" %
                       (result[rouge_type].precision, result[rouge_type].recall,
                        result[rouge_type].fmeasure))
      out_file.write("\n")
  logging.info("Finished writing results.") 
Example #11
Source File: io.py    From compare-mt with BSD 3-Clause "New" or "Revised" License 5 votes
def _record_gen(filename, delimiter):
  """Opens file and yields records separated by delimiter."""
  with _open(filename) as f:
    records = f.read().split(delimiter)
  if records[-1]:
    # Need a final delimiter at end of file to be able to detect an empty last
    # record.
    logging.warn("Expected delimiter at end of file")
  else:
    records = records[:-1]
  for record in records:
    yield record 
Example #12
Source File: tf_utils.py    From hub with Apache License 2.0 5 votes
def garbage_collect_exports(export_dir_base, exports_to_keep):
  """Deletes older exports, retaining only a given number of the most recent.

  Export subdirectories are assumed to be named with monotonically increasing
  integers; the most recent are taken to be those with the largest values.

  Args:
    export_dir_base: the base directory under which each export is in a
      versioned subdirectory.
    exports_to_keep: Number of exports to keep. Older exports will be garbage
      collected. Set to None to disable.
  """
  if exports_to_keep is None:
    return
  version_paths = []  # List of tuples (version, path)
  for filename in tf_v1.gfile.ListDirectory(export_dir_base):
    path = os.path.join(
        tf.compat.as_bytes(export_dir_base),
        tf.compat.as_bytes(filename))
    if len(filename) == 10 and filename.isdigit():
      version_paths.append((int(filename), path))

  oldest_version_path = sorted(version_paths)[:-exports_to_keep]
  for _, path in oldest_version_path:
    try:
      tf_v1.gfile.DeleteRecursively(path)
    except tf.errors.NotFoundError as e:
      logging.warn("Can not delete %s recursively: %s", path, e) 
Example #13
Source File: tf_utils.py    From hub with Apache License 2.0 5 votes
def get_timestamped_export_dir(export_dir_base):
  """Builds a path to a new subdirectory within the base directory.

  Each export is written into a new subdirectory named using the
  current time.  This guarantees monotonically increasing version
  numbers even across multiple runs of the pipeline.
  The timestamp used is the number of seconds since epoch UTC.

  Args:
    export_dir_base: A string containing a directory to write the exported
        graph and checkpoints.
  Returns:
    The full path of the new subdirectory (which is not actually created yet).

  Raises:
    RuntimeError: if repeated attempts fail to obtain a unique timestamped
      directory name.
  """
  attempts = 0
  while attempts < MAX_DIRECTORY_CREATION_ATTEMPTS:
    export_timestamp = int(time.time())

    export_dir = os.path.join(
        tf.compat.as_bytes(export_dir_base),
        tf.compat.as_bytes(str(export_timestamp)))
    if not tf_v1.gfile.Exists(export_dir):
      # Collisions are still possible (though extremely unlikely): this
      # directory is not actually created yet, but it will be almost
      # instantly on return from this function.
      return export_dir
    time.sleep(1)
    attempts += 1
    logging.warn(
        "Export directory %s already exists; retrying (attempt %d/%d)",
        export_dir, attempts, MAX_DIRECTORY_CREATION_ATTEMPTS)
  raise RuntimeError("Failed to obtain a unique export directory name after "
                     "%d attempts.".MAX_DIRECTORY_CREATION_ATTEMPTS) 
Example #14
Source File: juniper_test.py    From capirca with Apache License 2.0 5 votes
def testLongPolicer(self):
    with mock.patch.object(juniper.logging, 'warning',
                           spec=logging.warn) as warn:
      policy_text = GOOD_HEADER + LONG_POLICER_TERM_1
      jcl = juniper.Juniper(policy.ParsePolicy(policy_text, self.naming),
                            EXP_INFO)
      _ = str(jcl)
      warn.assert_called_with('WARNING: %s is longer than %d bytes. Due to'
                              ' limitation in JUNOS, OIDs longer than %dB'
                              ' can cause SNMP timeout issues.', 'this-is-very'
                              '-very-very-very-very-very-very-very-very-very'
                              '-very-very-very-very-very-very-very-very-very'
                              '-very-very-very-very-very-very-very-very-very'
                              '-very-very-long', 128, 128) 
Example #15
Source File: __init__.py    From language with Apache License 2.0 5 votes
def declare_relation(self,
                       rel_name,
                       domain_type,
                       range_type,
                       trainable=False,
                       dense=False):
    """Declare the domain and range types for a relation.

    Args:
      rel_name: string naming a relation
      domain_type: string naming the type of subject entities for the relation
      range_type: string naming the type of object entities for the relation
      trainable: boolean, true if the weights for this relation will be trained
      dense: if true, store data as a dense tensor instead of a SparseTensor

    Raises:
      RelationNameError: If a relation with this name already exists.
    """
    if rel_name in self._declaration:
      raise RelationNameError(rel_name, 'Multiple declarations for relation.')
    reserved = dir(NeuralQueryExpression)
    if rel_name in reserved:
      logging.warn(
          'rel_name prohibits expr.%s() as it matches a reserved word in: %r',
          rel_name, reserved)
    self._declaration[rel_name] = RelationDeclaration(rel_name, domain_type,
                                                      range_type, trainable,
                                                      dense)
    for type_name in [domain_type, range_type]:
      if type_name not in self._symtab:
        self._symtab[type_name] = symbol.SymbolTable()
    self._rel_name_symtab.insert(rel_name) 
Example #16
Source File: tensorspec_utils.py    From tensor2robot with Apache License 2.0 5 votes
def is_encoded_image_spec(tensor_spec):
  """Determines whether the passed tensor_spec speficies an encoded image."""
  if hasattr(tensor_spec, 'data_format'):
    # If tensor_spec is an ExtendedTensorSpec, use the data_format to check.
    return (tensor_spec.data_format is not None) and (
        tensor_spec.data_format.upper() in ['JPEG', 'PNG'])
  else:
    # Otherwise default to the old "name contains 'image'" logic.
    logging.warn('Using a deprecated tensor specification. '
                 'Use ExtendedTensorSpec.')
    return 'image' in tensor_spec.name 
Example #17
Source File: tensorspec_utils.py    From tensor2robot with Apache License 2.0 5 votes
def map_feed_dict_unsafe(feature_placeholders_spec, np_inputs_spec):
  """Deprecated function to create a feed_dict to be passed to session.run.

  tensorspec_utils.map_feed_dict should be used instead.  map_feed_dict_unsafe
  does not check that there is actually any agreement between
  feature_placeholders_spec or np_inputs spec in terms of dtype, shape
  or additional unused attributes within np_inputs_spec.

  Args:
    feature_placeholders_spec: An TensorSpecStruct containing
      {str: tf.placeholder}.
    np_inputs_spec: The numpy input according to the same spec.

  Returns:
    A mapping {placeholder: np.ndarray} which can be fed to a tensorflow
      session.run.
  """
  logging.warning('map_feed_dict_unsafe is deprecated. '
                  'Please update to map_feed_dict.')
  flat_spec = flatten_spec_structure(feature_placeholders_spec)
  flat_np_inputs = flatten_spec_structure(np_inputs_spec)
  for key, value in flat_np_inputs.items():
    if key not in flat_spec:
      logging.warn(
          'np_inputs has an input: %s, not found in the tensorspec.', key)
  feed_dict = {}
  for key, value in flat_spec.items():
    feed_dict[value] = flat_np_inputs[key]
  return feed_dict 
Example #18
Source File: fastlin.py    From interval-bound-propagation with Apache License 2.0 5 votes
def apply_increasing_monotonic_fn(self, wrapper, fn, *args, **parameters):
    if fn.__name__ != 'relu':
      # Fallback to regular interval bound propagation for unsupported
      # operations.
      logging.warn('"%s" is not supported by SymbolicBounds. '
                   'Fallback on IntervalBounds.', fn.__name__)
      interval_bounds = basic_bounds.IntervalBounds.convert(self)
      converted_args = [basic_bounds.IntervalBounds.convert(b) for b in args]
      interval_bounds = interval_bounds._increasing_monotonic_fn(  # pylint: disable=protected-access
          fn, *converted_args)
      return self.convert(interval_bounds)

    concrete = self.concretize()
    lb, ub = concrete.lower, concrete.upper
    is_ambiguous = tf.logical_and(ub > 0, lb < 0)
    # Ensure denominator is always positive, even when not needed.
    ambiguous_denom = tf.where(is_ambiguous, ub - lb, tf.ones_like(ub))
    scale = tf.where(
        is_ambiguous, ub / ambiguous_denom,
        tf.where(lb >= 0, tf.ones_like(lb), tf.zeros_like(lb)))
    bias = tf.where(is_ambiguous, -lb, tf.zeros_like(lb))
    lb_out = LinearExpression(
        w=tf.expand_dims(scale, 1) * self.lower.w,
        b=scale * self.lower.b,
        lower=self.lower.lower, upper=self.lower.upper)
    ub_out = LinearExpression(
        w=tf.expand_dims(scale, 1) * self.upper.w,
        b=scale * (self.upper.b + bias),
        lower=self.upper.lower, upper=self.upper.upper)
    return SymbolicBounds(lb_out, ub_out).with_priors(wrapper.output_bounds) 
Example #19
Source File: verifiable_wrapper.py    From interval-bound-propagation with Apache License 2.0 5 votes
def _propagate_through(self, module, input_bounds):
    if isinstance(module, layers.BatchNorm):
      # This IBP-specific batch-norm implementation exposes stats recorded
      # the most recent time the BatchNorm module was connected.
      # These will be either the batch stats (e.g. if training) or the moving
      # averages, depending on how the module was called.
      mean = module.mean
      variance = module.variance
      epsilon = module.epsilon
      scale = module.scale
      bias = module.bias

    else:
      # This plain Sonnet batch-norm implementation only exposes the
      # moving averages.
      logging.warn('Sonnet BatchNorm module encountered: %s. '
                   'IBP will always use its moving averages, not the local '
                   'batch stats, even in training mode.', str(module))
      mean = module.moving_mean
      variance = module.moving_variance
      epsilon = module._eps  # pylint: disable=protected-access
      try:
        bias = module.beta
      except snt.Error:
        bias = None
      try:
        scale = module.gamma
      except snt.Error:
        scale = None

    return input_bounds.apply_batch_norm(self, mean, variance,
                                         scale, bias, epsilon) 
Example #20
Source File: model.py    From interval-bound-propagation with Apache License 2.0 5 votes
def _observer(self, subgraph):
    input_nodes = self._inputs_for_observed_module(subgraph)
    if input_nodes is None:
      # We do not fail as we want to allow higher-level Sonnet components.
      # In practice, the rest of the logic will fail if we are unable to
      # connect all low-level modules.
      logging.warn('Unprocessed module "%s"', str(subgraph.module))
      return
    if subgraph.outputs in input_nodes:
      # The Sonnet module is just returning its input as its output.
      # This may happen with a reshape in which the shape does not change.
      return

    self._add_module(self._wrapper_for_observed_module(subgraph),
                     subgraph.outputs, *input_nodes) 
Example #21
Source File: rnn_agent.py    From ml-fairness-gym with Apache License 2.0 5 votes
def _choose_rec_from_softmax(self, softmax_probs, deterministic):
    if deterministic:
      rec = np.argmax(softmax_probs)
    else:
      # Fix the probability vector to avoid np.random.choice exception.
      softmax_probs = np.nan_to_num(softmax_probs)
      softmax_probs += 1e-10
      if not np.any(softmax_probs):
        logging.warn('All zeros in the softmax prediction.')
      softmax_probs = softmax_probs / np.sum(softmax_probs)
      # TODO(): Use epsilon for exploration at the model level.
      rec = self._rng.choice(self.action_space_size, p=softmax_probs)
    return rec

  # TODO(): Move the simulation function to a runner class. 
Example #22
Source File: sc2_env.py    From pysc2 with Apache License 2.0 5 votes
def _get_observations(self, target_game_loop):
    # Transform in the thread so it runs while waiting for other observations.
    def parallel_observe(c, f):
      obs = c.observe(target_game_loop=target_game_loop)
      agent_obs = f.transform_obs(obs)
      return obs, agent_obs

    with self._metrics.measure_observation_time():
      self._obs, self._agent_obs = zip(*self._parallel.run(
          (parallel_observe, c, f)
          for c, f in zip(self._controllers, self._features)))

    game_loop = self._agent_obs[0].game_loop[0]
    if (game_loop < target_game_loop and
        not any(o.player_result for o in self._obs)):
      raise ValueError(
          ("The game didn't advance to the expected game loop. "
           "Expected: %s, got: %s") % (target_game_loop, game_loop))
    elif game_loop > target_game_loop and target_game_loop > 0:
      logging.warn("Received observation %d step(s) late: %d rather than %d.",
                   game_loop - target_game_loop, game_loop, target_game_loop)

    if self._realtime:
      # Track delays on executed actions.
      # Note that this will underestimate e.g. action sent, new observation
      # taken before action executes, action executes, observation taken
      # with action. This is difficult to avoid without changing the SC2
      # binary - e.g. send the observation game loop with each action,
      # return them in the observation action proto.
      if self._last_obs_game_loop is not None:
        for i, obs in enumerate(self._obs):
          for action in obs.actions:
            if action.HasField("game_loop"):
              delay = action.game_loop - self._last_obs_game_loop
              if delay > 0:
                num_slots = len(self._action_delays[i])
                delay = min(delay, num_slots - 1)  # Cap to num buckets.
                self._action_delays[i][delay] += 1
                break
      self._last_obs_game_loop = game_loop 
Example #23
Source File: lm.py    From lamb with Apache License 2.0 5 votes
def _check_budget(self, config):
    num_trainables = utils.log_trainables()
    if config.num_params > -1:
      assert num_trainables <= config.num_params, (
          'The number of trainable parameters ({}) exceeds the budget ({}). '
          .format(num_trainables, config.num_params))
      if num_trainables < 0.98*(config.num_params-500):
        logging.warn('Number of parameters (%s) is way below the budget (%s)',
                     num_trainables, config.num_params) 
Example #24
Source File: run_rnnt.py    From rnnt-speech-recognition with MIT License 5 votes
def configure_environment(gpu_names,
                          fp16_run):

    if fp16_run:
        print('Using 16-bit float precision.')
        policy = mixed_precision.Policy('mixed_float16')
        mixed_precision.set_policy(policy)

    gpus = tf.config.experimental.list_physical_devices('GPU')

    if gpu_names is not None and len(gpu_names) > 0:
        gpus = [x for x in gpus if x.name[len('/physical_device:'):] in gpu_names]

    if gpus:
        try:
            for gpu in gpus:
                tf.config.experimental.set_memory_growth(gpu, True)
            # tf.config.experimental.set_virtual_device_configuration(
            #     gpus[0],
            #     [tf.config.experimental.VirtualDeviceConfiguration(memory_limit=4096),
            #         tf.config.experimental.VirtualDeviceConfiguration(memory_limit=4096)])
            logical_gpus = tf.config.experimental.list_logical_devices('GPU')
            print(len(gpus), "Physical GPU,", len(logical_gpus), "Logical GPUs")
        except RuntimeError as e:
            logging.warn(str(e))

    if len(gpus) > 1:
        print('Running multi gpu: {}'.format(', '.join(gpu_names)))
        strategy = tf.distribute.MirroredStrategy(
            devices=gpu_names)
    else:
        device = gpus[0].name[len('/physical_device:'):]
        print('Running single gpu: {}'.format(device))
        strategy = tf.distribute.OneDeviceStrategy(
            device=device)

    dtype = tf.float16 if fp16_run else tf.float32

    return strategy, dtype 
Example #25
Source File: speech_cls_task.py    From delta with Apache License 2.0 4 votes
def get_class_files_duration(self):
    '''Directory names under the dataset are the class names;
     all data paths contain the same directory names.'''
    classes = None
    for root, dirnames, filenames in os.walk(self._data_path[0]):
      classes = dirnames
      break

    assert classes, 'cannot access {}'.format(self._data_path[0])
    assert set(classes) == set(self._classes.keys()), '{} {}'.format(
        classes, self._classes.keys())

    def _get_class(path):
      ret = None
      for cls in self._classes:
        if cls in path:
          ret = cls
      return ret

    # to exclude some data under some dir
    excludes = []
    #pylint: disable=too-many-nested-blocks
    for data_path in self._data_path:
      logging.debug("data path: {}".format(data_path))
      for root, dirname, filenames in os.walk(data_path):
        del dirname
        for filename in filenames:
          if filename.endswith(self._file_suffix):
            class_name = _get_class(root)  # 'conflict' or 'normal' str
            assert class_name is not None
            filename = os.path.join(root, filename)

            if excludes:
              for exclude in excludes:
                if exclude in filename:
                  pass

            duration = self.get_duration(
                filename=filename, sr=self._sample_rate)
            self._class_file[class_name].append(
                (filename, duration, class_name))
          else:
            pass

    if not self._class_file:
      logging.debug("class file: {}".format(self._class_file))
      logging.warn("maybe the suffix {} file not exits".format(
          self._file_suffix)) 
Example #26
Source File: util.py    From multilabel-image-classification-tensorflow with MIT License 4 votes
def get_vars_to_save_and_restore(ckpt=None):
  """Returns list of variables that should be saved/restored.

  Args:
    ckpt: Path to existing checkpoint.  If present, returns only the subset of
        variables that exist in given checkpoint.

  Returns:
    List of all variables that need to be saved/restored.
  """
  model_vars = tf.trainable_variables()
  # Add batchnorm variables.
  bn_vars = [v for v in tf.global_variables()
             if 'moving_mean' in v.op.name or 'moving_variance' in v.op.name or
             'mu' in v.op.name or 'sigma' in v.op.name or
             'global_scale_var' in v.op.name]
  model_vars.extend(bn_vars)
  model_vars = sorted(model_vars, key=lambda x: x.op.name)
  mapping = {}
  if ckpt is not None:
    ckpt_var = tf.contrib.framework.list_variables(ckpt)
    ckpt_var_names = [name for (name, unused_shape) in ckpt_var]
    ckpt_var_shapes = [shape for (unused_name, shape) in ckpt_var]
    not_loaded = list(ckpt_var_names)
    for v in model_vars:
      if v.op.name not in ckpt_var_names:
        # For backward compatibility, try additional matching.
        v_additional_name = v.op.name.replace('egomotion_prediction/', '')
        if v_additional_name in ckpt_var_names:
          # Check if shapes match.
          ind = ckpt_var_names.index(v_additional_name)
          if ckpt_var_shapes[ind] == v.get_shape():
            mapping[v_additional_name] = v
            not_loaded.remove(v_additional_name)
            continue
          else:
            logging.warn('Shape mismatch, will not restore %s.', v.op.name)
        logging.warn('Did not find var %s in checkpoint: %s', v.op.name,
                     os.path.basename(ckpt))
      else:
        # Check if shapes match.
        ind = ckpt_var_names.index(v.op.name)
        if ckpt_var_shapes[ind] == v.get_shape():
          mapping[v.op.name] = v
          not_loaded.remove(v.op.name)
        else:
          logging.warn('Shape mismatch, will not restore %s.', v.op.name)
    if not_loaded:
      logging.warn('The following variables in the checkpoint were not loaded:')
      for varname_not_loaded in not_loaded:
        logging.info('%s', varname_not_loaded)
  else:  # just get model vars.
    for v in model_vars:
      mapping[v.op.name] = v
  return mapping 
Example #27
Source File: util.py    From models with Apache License 2.0 4 votes
def get_vars_to_save_and_restore(ckpt=None):
  """Returns list of variables that should be saved/restored.

  Args:
    ckpt: Path to existing checkpoint.  If present, returns only the subset of
        variables that exist in given checkpoint.

  Returns:
    List of all variables that need to be saved/restored.
  """
  model_vars = tf.trainable_variables()
  # Add batchnorm variables.
  bn_vars = [v for v in tf.global_variables()
             if 'moving_mean' in v.op.name or 'moving_variance' in v.op.name or
             'mu' in v.op.name or 'sigma' in v.op.name or
             'global_scale_var' in v.op.name]
  model_vars.extend(bn_vars)
  model_vars = sorted(model_vars, key=lambda x: x.op.name)
  mapping = {}
  if ckpt is not None:
    ckpt_var = tf.contrib.framework.list_variables(ckpt)
    ckpt_var_names = [name for (name, unused_shape) in ckpt_var]
    ckpt_var_shapes = [shape for (unused_name, shape) in ckpt_var]
    not_loaded = list(ckpt_var_names)
    for v in model_vars:
      if v.op.name not in ckpt_var_names:
        # For backward compatibility, try additional matching.
        v_additional_name = v.op.name.replace('egomotion_prediction/', '')
        if v_additional_name in ckpt_var_names:
          # Check if shapes match.
          ind = ckpt_var_names.index(v_additional_name)
          if ckpt_var_shapes[ind] == v.get_shape():
            mapping[v_additional_name] = v
            not_loaded.remove(v_additional_name)
            continue
          else:
            logging.warn('Shape mismatch, will not restore %s.', v.op.name)
        logging.warn('Did not find var %s in checkpoint: %s', v.op.name,
                     os.path.basename(ckpt))
      else:
        # Check if shapes match.
        ind = ckpt_var_names.index(v.op.name)
        if ckpt_var_shapes[ind] == v.get_shape():
          mapping[v.op.name] = v
          not_loaded.remove(v.op.name)
        else:
          logging.warn('Shape mismatch, will not restore %s.', v.op.name)
    if not_loaded:
      logging.warn('The following variables in the checkpoint were not loaded:')
      for varname_not_loaded in not_loaded:
        logging.info('%s', varname_not_loaded)
  else:  # just get model vars.
    for v in model_vars:
      mapping[v.op.name] = v
  return mapping 
Example #28
Source File: util.py    From g-tensorflow-models with Apache License 2.0 4 votes
def get_vars_to_save_and_restore(ckpt=None):
  """Returns list of variables that should be saved/restored.

  Args:
    ckpt: Path to existing checkpoint.  If present, returns only the subset of
        variables that exist in given checkpoint.

  Returns:
    List of all variables that need to be saved/restored.
  """
  model_vars = tf.trainable_variables()
  # Add batchnorm variables.
  bn_vars = [v for v in tf.global_variables()
             if 'moving_mean' in v.op.name or 'moving_variance' in v.op.name or
             'mu' in v.op.name or 'sigma' in v.op.name or
             'global_scale_var' in v.op.name]
  model_vars.extend(bn_vars)
  model_vars = sorted(model_vars, key=lambda x: x.op.name)
  mapping = {}
  if ckpt is not None:
    ckpt_var = tf.contrib.framework.list_variables(ckpt)
    ckpt_var_names = [name for (name, unused_shape) in ckpt_var]
    ckpt_var_shapes = [shape for (unused_name, shape) in ckpt_var]
    not_loaded = list(ckpt_var_names)
    for v in model_vars:
      if v.op.name not in ckpt_var_names:
        # For backward compatibility, try additional matching.
        v_additional_name = v.op.name.replace('egomotion_prediction/', '')
        if v_additional_name in ckpt_var_names:
          # Check if shapes match.
          ind = ckpt_var_names.index(v_additional_name)
          if ckpt_var_shapes[ind] == v.get_shape():
            mapping[v_additional_name] = v
            not_loaded.remove(v_additional_name)
            continue
          else:
            logging.warn('Shape mismatch, will not restore %s.', v.op.name)
        logging.warn('Did not find var %s in checkpoint: %s', v.op.name,
                     os.path.basename(ckpt))
      else:
        # Check if shapes match.
        ind = ckpt_var_names.index(v.op.name)
        if ckpt_var_shapes[ind] == v.get_shape():
          mapping[v.op.name] = v
          not_loaded.remove(v.op.name)
        else:
          logging.warn('Shape mismatch, will not restore %s.', v.op.name)
    if not_loaded:
      logging.warn('The following variables in the checkpoint were not loaded:')
      for varname_not_loaded in not_loaded:
        logging.info('%s', varname_not_loaded)
  else:  # just get model vars.
    for v in model_vars:
      mapping[v.op.name] = v
  return mapping 
Example #29
Source File: estimator.py    From hub with Apache License 2.0 4 votes
def export(self, estimator, export_path, checkpoint_path=None,
             eval_result=None, is_the_final_export=None):
    """Actually performs the export of registered Modules.

    This method creates a timestamped directory under `export_path`
    with one sub-directory (named `export_name`) per module registered
    via `register_module_for_export`.

    Example use:

    ```python
      estimator = ... (Create estimator with modules registered for export)...
      exporter = hub.LatestModuleExporter("tf_hub", serving_input_fn)
      exporter.export(estimator, export_path, estimator.latest_checkpoint())
    ```

    Args:
      estimator: the `Estimator` from which to export modules.
      export_path: A string containing a directory where to write the export
        timestamped directories.
      checkpoint_path: The checkpoint path to export. If `None`,
        `estimator.latest_checkpoint()` is used.
      eval_result: Unused.
      is_the_final_export: Unused.

    Returns:
      The path to the created timestamped directory containing the exported
      modules.
    """
    if checkpoint_path is None:
      checkpoint_path = estimator.latest_checkpoint()

    export_dir = tf_utils.get_timestamped_export_dir(export_path)
    temp_export_dir = tf_utils.get_temp_export_dir(export_dir)

    session = _make_estimator_serving_session(estimator, self._serving_input_fn,
                                              checkpoint_path)
    with session:
      export_modules = tf_v1.get_collection(_EXPORT_MODULES_COLLECTION)
      if export_modules:
        for export_name, module in export_modules:
          module_export_path = os.path.join(temp_export_dir,
                                            tf.compat.as_bytes(export_name))
          module.export(module_export_path, session)
        tf_v1.gfile.Rename(temp_export_dir, export_dir)
        tf_utils.garbage_collect_exports(export_path, self._exports_to_keep)
        return export_dir
      else:
        logging.warn("LatestModuleExporter found zero modules to export. "
                     "Use hub.register_module_for_export() if needed.")
        # No export_dir has been created.
        return None 
Example #30
Source File: train_eval.py    From tensor2robot with Apache License 2.0 4 votes
def save_copy(src_filename,
              dest_filename,
              overwrite=False,
              num_retries=3,
              sleep_time=0.5):
  """Copy a file while catching errors and retrying a set amount of times.

  Args:
    src_filename: Source file name.
    dest_filename: Destination file name.
    overwrite: Whether to overwrite an existing destination.
    num_retries: Number of times to try again on failure.
    sleep_time: Time to sleep between trials in seconds.

  Returns:
    True if no errors occurred, False otherwise.
  """
  if tf.io.gfile.exists(dest_filename):
    logging.warn('Could not copy file "%s" to "%s", because the destination'
                 ' already exists.', src_filename, dest_filename)
    return False

  for _ in range(num_retries):
    try:
      logging.info('Copying %s to %s', src_filename, dest_filename)
      tf.io.gfile.copy(src_filename, dest_filename, overwrite=overwrite)
    except tf.errors.NotFoundError:
      # This should not happen, but it does. tf.io.gfile.glob gave us a
      # filename, but somehow it is not yet available anyway.
      logging.warn('Copying %s to %s failed with NotFoundError',
                   src_filename, dest_filename)
      time.sleep(sleep_time)
    except tf.errors.AlreadyExistsError:
      # This should not happen, but it does. tf.io.gfile.exists said the
      # destination did not exist right before, but now it did.
      logging.warn('Copying %s to %s failed with AlreadyExistsError',
                   src_filename, dest_filename)
      return False
    else:
      break
  else:
    logging.warn(
        'Could not copy file "%s", because source file was not found. ',
        src_filename)
    return False

  return True