Python tensorflow.python.training.session_run_hook.SessionRunArgs() Examples

The following are 30 code examples of tensorflow.python.training.session_run_hook.SessionRunArgs(). You can vote up the ones you like or vote down the ones you don't like, and go to the original project or source file by following the links above each example. You may also want to check out all available functions/classes of the module tensorflow.python.training.session_run_hook, or try the search function.
Example #1
Source File: basic_session_run_hooks.py    From deep_image_model with Apache License 2.0 6 votes vote down vote up
def before_run(self, run_context):  # pylint: disable=unused-argument
    """Request the global step; on the first call also export the graph.

    On the first call (the timer has never triggered) the default graph is
    written as "graph.pbtxt" under `self._checkpoint_dir`, and both the graph
    and a MetaGraphDef (carrying `self._saver.saver_def` when a saver is set)
    are added to the summary writer.

    Args:
      run_context: Unused.

    Returns:
      A `SessionRunArgs` that fetches `self._global_step_tensor`.
    """
    if self._timer.last_triggered_step() is None:
      # Look the graph up and serialize it once; the same GraphDef feeds both
      # the pbtxt file and the MetaGraphDef instead of serializing twice.
      graph = ops.get_default_graph()
      graph_def = graph.as_graph_def(add_shapes=True)
      training_util.write_graph(graph_def, self._checkpoint_dir, "graph.pbtxt")
      saver_def = self._saver.saver_def if self._saver else None
      meta_graph_def = meta_graph.create_meta_graph_def(
          graph_def=graph_def,
          saver_def=saver_def)
      self._summary_writer.add_graph(graph)
      self._summary_writer.add_meta_graph(meta_graph_def)

    return SessionRunArgs(self._global_step_tensor)
Example #2
Source File: basic_session_run_hooks.py    From Serverless-Deep-Learning-with-TensorFlow-and-AWS-Lambda with MIT License 6 votes vote down vote up
def before_run(self, run_context):  # pylint: disable=unused-argument
    """Request the global step; on the first call also export the graph.

    The graph and saver_def are written here, on the first `before_run`, and
    not in `begin`: other hooks may still change the graph and add variables
    in their own `begin`, and the graph is finalized only after every `begin`
    call has run.

    On that first call (the timer has never triggered) the default graph is
    written as "graph.pbtxt" under `self._checkpoint_dir`, and both the graph
    and a MetaGraphDef (carrying the saver's `saver_def`, if a saver exists)
    are added to the summary writer.

    Args:
      run_context: Unused.

    Returns:
      A `SessionRunArgs` that fetches `self._global_step_tensor`.
    """
    if self._timer.last_triggered_step() is None:
      # Serialize the (possibly large) graph once and reuse the GraphDef for
      # both the pbtxt file and the MetaGraphDef.
      graph = ops.get_default_graph()
      graph_def = graph.as_graph_def(add_shapes=True)
      training_util.write_graph(graph_def, self._checkpoint_dir, "graph.pbtxt")
      # Resolve the saver once instead of calling _get_saver() twice.
      saver = self._get_saver()
      saver_def = saver.saver_def if saver else None
      meta_graph_def = meta_graph.create_meta_graph_def(
          graph_def=graph_def,
          saver_def=saver_def)
      self._summary_writer.add_graph(graph)
      self._summary_writer.add_meta_graph(meta_graph_def)

    return SessionRunArgs(self._global_step_tensor)
Example #3
Source File: basic_session_run_hooks.py    From lambda-packs with MIT License 6 votes vote down vote up
def before_run(self, run_context):  # pylint: disable=unused-argument
    """Request the global step; on the first call also export the graph.

    The graph and saver_def are written here, on the first `before_run`, and
    not in `begin`: other hooks may still change the graph and add variables
    in their own `begin`, and the graph is finalized only after every `begin`
    call has run.

    On that first call (the timer has never triggered) the default graph is
    written as "graph.pbtxt" under `self._checkpoint_dir`, and both the graph
    and a MetaGraphDef (carrying the saver's `saver_def`, if a saver exists)
    are added to the summary writer.

    Args:
      run_context: Unused.

    Returns:
      A `SessionRunArgs` that fetches `self._global_step_tensor`.
    """
    if self._timer.last_triggered_step() is None:
      # Serialize the (possibly large) graph once and reuse the GraphDef for
      # both the pbtxt file and the MetaGraphDef.
      graph = ops.get_default_graph()
      graph_def = graph.as_graph_def(add_shapes=True)
      training_util.write_graph(graph_def, self._checkpoint_dir, "graph.pbtxt")
      # Resolve the saver once instead of calling _get_saver() twice.
      saver = self._get_saver()
      saver_def = saver.saver_def if saver else None
      meta_graph_def = meta_graph.create_meta_graph_def(
          graph_def=graph_def,
          saver_def=saver_def)
      self._summary_writer.add_graph(graph)
      self._summary_writer.add_meta_graph(meta_graph_def)

    return SessionRunArgs(self._global_step_tensor)
Example #4
Source File: hooks.py    From keras-lambda with MIT License 5 votes vote down vote up
def before_run(self, run_context):
    """Decide how the upcoming `Session.run()` call is handled by the CLI.

    On the first call, initializes the `LocalCLIDebugWrapperSession` base
    with the session taken from `run_context`. Every call then increments
    the run counter, asks `self.on_run_start()` which action to take, stores
    that action in `self._performed_action`, and prepares this hook's
    `SessionRunArgs` for the run.

    Args:
      run_context: The `SessionRunContext` of the upcoming run; its
        `session` and `original_args` (fetches and feed_dict) are used.

    Returns:
      A `SessionRunArgs` with no extra fetches or feeds and a fresh
      `RunOptions`. For the DEBUG_RUN action those options are passed to
      `self._decorate_options_for_debug` together with the session's graph.
    """
    # Lazily initialize the wrapper base with the live session (presumably
    # because the session is not known when the hook is constructed).
    if not self._wrapper_initialized:
      local_cli_wrapper.LocalCLIDebugWrapperSession.__init__(
          self, run_context.session, ui_type=self._ui_type)
      self._wrapper_initialized = True

    # Increment run call counter.
    self._run_call_count += 1

    # Adapt run_context to an instance of OnRunStartRequest for invoking
    # superclass on_run_start().
    # NOTE(review): the two None arguments are presumably the run options and
    # run metadata, which a hook does not have here — confirm against
    # OnRunStartRequest's signature.
    on_run_start_request = framework.OnRunStartRequest(
        run_context.original_args.fetches, run_context.original_args.feed_dict,
        None, None, self._run_call_count)

    on_run_start_response = self.on_run_start(on_run_start_request)
    self._performed_action = on_run_start_response.action

    # No extra fetches/feeds; only RunOptions are contributed to the run.
    run_args = session_run_hook.SessionRunArgs(
        None, feed_dict=None, options=config_pb2.RunOptions())
    if self._performed_action == framework.OnRunStartAction.DEBUG_RUN:
      # Mutates run_args.options in place (presumably adding debug tensor
      # watches for the graph — see _decorate_options_for_debug).
      self._decorate_options_for_debug(run_args.options,
                                       run_context.session.graph)
    elif self._performed_action == framework.OnRunStartAction.INVOKE_STEPPER:
      # The _finalized property must be set to False so that the NodeStepper
      # can insert ops for retrieving TensorHandles.
      # NOTE(review): the graph is not re-finalized within this block —
      # confirm whether that happens elsewhere.
      # pylint: disable=protected-access
      run_context.session.graph._finalized = False
      # pylint: enable=protected-access

      # Hand control to the interactive node stepper for this run's original
      # fetches and feeds; variable values are to be restored on exit.
      self.invoke_node_stepper(
          stepper.NodeStepper(run_context.session, run_context.original_args.
                              fetches, run_context.original_args.feed_dict),
          restore_variable_values_on_exit=True)

    return run_args
Example #5
Source File: hooks.py    From ctc-asr with MIT License 5 votes vote down vote up
def before_run(self, run_context):
        """Ask the upcoming session run to also evaluate loss and global step.

        Returns:
            tf.train.SessionRunArgs whose fetches are
            ``[self.loss_op, self._global_step_tensor]``; their evaluated
            values are handed to this hook's ``after_run``.
        """
        requested = [self.loss_op, self._global_step_tensor]
        return tf.train.SessionRunArgs(fetches=requested)
Example #6
Source File: tpu_estimator.py    From xlnet with Apache License 2.0 5 votes vote down vote up
def after_run(self, run_context, run_values):
    """Stop at the last step, otherwise size the next loop of iterations.

    The global step is read with a dedicated `session.run` call rather than
    requested through `SessionRunArgs` in `before_run`, which would be subject
    to a race condition. Once it reaches `self._last_step` a stop is
    requested; otherwise the iteration count for the next loop is computed by
    `self._next_iterations` and loaded into `self._iterations_per_loop_var`.

    Args:
      run_context: `SessionRunContext` providing the session and stop request.
      run_values: Unused.
    """
    session = run_context.session
    current_step = session.run(self._global_step_tensor)
    if current_step >= self._last_step:
      run_context.request_stop()
      return
    iterations_this_loop = self._next_iterations(current_step, self._last_step)
    self._iterations_per_loop_var.load(iterations_this_loop, session=session)
Example #7
Source File: tpu_estimator.py    From xlnet with Apache License 2.0 5 votes vote down vote up
def before_run(self, run_context):
    """Request the scalar stopping signal on every run.

    Args:
      run_context: Unused.

    Returns:
      A `SessionRunArgs` fetching `self._scalar_stopping_signal`.
    """
    del run_context  # Not needed to build the request.
    stopping_signal = self._scalar_stopping_signal
    return session_run_hook.SessionRunArgs(stopping_signal)
Example #8
Source File: tpu_estimator.py    From xlnet with Apache License 2.0 5 votes vote down vote up
def before_run(self, run_context):
    """Add `self._tensors` to the fetches of the upcoming run.

    Args:
      run_context: Unused.

    Returns:
      A `SessionRunArgs` whose fetches are `self._tensors`.
    """
    del run_context  # Not needed to build the request.
    fetches = self._tensors
    return basic_session_run_hooks.SessionRunArgs(fetches)
Example #9
Source File: early_stopping.py    From estimator with Apache License 2.0 5 votes vote down vote up
def before_run(self, run_context):
    """Fetch the global step alongside the caller's fetches.

    Args:
      run_context: Unused; explicitly discarded.

    Returns:
      A `tf.compat.v1.train.SessionRunArgs` requesting
      `self._global_step_tensor`.
    """
    del run_context
    make_args = tf.compat.v1.train.SessionRunArgs
    return make_args(self._global_step_tensor)
Example #10
Source File: basic_session_run_hooks.py    From keras-lambda with MIT License 5 votes vote down vote up
def before_run(self, run_context):  # pylint: disable=unused-argument
    """Request the logged tensors only on steps where the timer fires.

    The timer's decision for `self._iter_count` is stored in
    `self._should_trigger`.

    Returns:
      `SessionRunArgs(self._current_tensors)` when the timer triggers,
      otherwise `None` (nothing extra is fetched).
    """
    triggered = self._timer.should_trigger_for_step(self._iter_count)
    self._should_trigger = triggered
    if not self._should_trigger:
      return None
    return SessionRunArgs(self._current_tensors)
Example #11
Source File: basic_session_run_hooks.py    From Serverless-Deep-Learning-with-TensorFlow-and-AWS-Lambda with MIT License 5 votes vote down vote up
def before_run(self, run_context):  # pylint: disable=unused-argument
    """Feed the values produced by `self.feed_fn()` into the upcoming run.

    `self.feed_fn` is called once per run; nothing extra is fetched.

    Returns:
      A `SessionRunArgs` with `fetches=None` and the generated feed_dict.
    """
    generated_feeds = self.feed_fn()
    return session_run_hook.SessionRunArgs(
        fetches=None,
        feed_dict=generated_feeds)
Example #12
Source File: hooks.py    From ctc-asr with MIT License 5 votes vote down vote up
def before_run(self, run_context):
        """Called before each ``session.run`` call (usually one training step).

        The graph is already finalized at this point, so no ops can be added.

        Arguments:
            run_context (tf.train.SessionRunContext):
                Describes the upcoming ``run()`` call: the originally
                requested ops/tensors and the TensorFlow session. A hook may
                end the loop by calling ``run_context.request_stop()``.
                Not used by this hook.

        Returns:
            tf.train.SessionRunArgs:
                Arguments merged into the ``Session.run()`` call (``None`` is
                also allowed by the hook contract). This hook requests only
                the global step tensor; its value, as it was before the
                ``session.run`` call executed, is passed to ``after_run``.

        See ``tensorflow/python/training/session_run_hook.py`` for the full
        ``SessionRunHook`` contract.
        """
        step_fetch = [self._global_step_tensor]
        return session_run_hook.SessionRunArgs(fetches=step_fetch)
Example #13
Source File: hooks.py    From keras-lambda with MIT License 5 votes vote down vote up
def before_run(self, run_context):
    """Configure tensor dumping for the upcoming run.

    On first use, initializes the `DumpingDebugWrapperSession` base with the
    session from `run_context`. Each call then increments the run counter,
    asks `self._prepare_run_watch_config` what to watch for the original
    fetches/feed_dict, and encodes that configuration into a fresh
    `RunOptions` via `debug_utils.watch_graph`.

    Args:
      run_context: `SessionRunContext` for the upcoming run.

    Returns:
      A `SessionRunArgs` that adds no fetches or feeds, only the debug
      `RunOptions`.
    """
    if not self._wrapper_initialized:
      dumping_wrapper.DumpingDebugWrapperSession.__init__(
          self,
          run_context.session,
          self._session_root,
          watch_fn=self._watch_fn,
          log_usage=self._log_usage)
      self._wrapper_initialized = True

    self._run_call_count += 1

    original = run_context.original_args
    watch_config = self._prepare_run_watch_config(
        original.fetches, original.feed_dict)
    debug_urls, debug_ops, node_regexes, op_type_regexes = watch_config

    options = config_pb2.RunOptions()
    debug_utils.watch_graph(
        options,
        run_context.session.graph,
        debug_urls=debug_urls,
        debug_ops=debug_ops,
        node_name_regex_whitelist=node_regexes,
        op_type_regex_whitelist=op_type_regexes)

    return session_run_hook.SessionRunArgs(
        None, feed_dict=None, options=options)
Example #14
Source File: monitored_session.py    From keras-lambda with MIT License 5 votes vote down vote up
def run(self, fetches, feed_dict=None, options=None, run_metadata=None):
    """See base class.

    The caller's fetches are wrapped under the 'caller' key, every hook may
    extend the run via `self._call_hook_before_run`, the wrapped session is
    run, and each hook's `after_run` receives its own slice of the results.

    Raises:
      RuntimeError: If called after `should_stop()` has returned True.
    """
    if self.should_stop():
      raise RuntimeError('Run called even after should_stop requested.')

    merged_fetches = {'caller': fetches}
    run_context = session_run_hook.SessionRunContext(
        original_args=session_run_hook.SessionRunArgs(fetches, feed_dict),
        session=self._sess)

    run_options = options or config_pb2.RunOptions()
    merged_feeds = self._call_hook_before_run(
        run_context, merged_fetches, feed_dict, run_options)

    # Do session run.
    metadata = run_metadata or config_pb2.RunMetadata()
    results = _WrappedSession.run(
        self,
        fetches=merged_fetches,
        feed_dict=merged_feeds,
        options=run_options,
        run_metadata=metadata)

    for hook in self._hooks:
      # A hook with no entry in the results receives None.
      hook_results = results[hook] if hook in results else None
      hook_values = session_run_hook.SessionRunValues(
          results=hook_results,
          options=run_options,
          run_metadata=metadata)
      hook.after_run(run_context, hook_values)
    self._should_stop = self._should_stop or run_context.stop_requested

    return results['caller']
Example #15
Source File: basic_session_run_hooks.py    From keras-lambda with MIT License 5 votes vote down vote up
def before_run(self, run_context):  # pylint: disable=unused-argument
    """Ask the upcoming run to also fetch the global step.

    Args:
      run_context: Unused.

    Returns:
      A `SessionRunArgs` whose only fetch is `self._global_step_tensor`.
    """
    step_tensor = self._global_step_tensor
    return SessionRunArgs(step_tensor)
Example #16
Source File: basic_session_run_hooks.py    From keras-lambda with MIT License 5 votes vote down vote up
def before_run(self, run_context):  # pylint: disable=unused-argument
    """Request the global step and, when due, the summary op.

    `self._request_summary` is set to True on the first call (when
    `self._next_step` is None) or when the timer triggers for
    `self._next_step`.

    Args:
      run_context: Unused.

    Returns:
      `SessionRunArgs` over a dict with key "global_step" and, when a summary
      is requested and `self._get_summary_op()` returns a non-None value, key
      "summary".
    """
    self._request_summary = (
        self._next_step is None or
        self._timer.should_trigger_for_step(self._next_step))
    requests = {"global_step": self._global_step_tensor}
    if self._request_summary:
      # Resolve the summary op once instead of calling the getter twice.
      summary_op = self._get_summary_op()
      if summary_op is not None:
        requests["summary"] = summary_op

    return SessionRunArgs(requests)
Example #17
Source File: basic_session_run_hooks.py    From keras-lambda with MIT License 5 votes vote down vote up
def before_run(self, run_context):  # pylint: disable=unused-argument
    """Request the loss tensor on every run.

    Args:
      run_context: Unused.

    Returns:
      A `SessionRunArgs` fetching `self._loss_tensor`.
    """
    loss = self._loss_tensor
    return SessionRunArgs(loss)
Example #18
Source File: hooks.py    From ctc-asr with MIT License 5 votes vote down vote up
def before_run(self, run_context):
        """Fetch the global step, adding trace options when tracing is enabled.

        Returns:
            tf.train.SessionRunArgs fetching ``self._global_step_tensor``. Its
            ``options`` is a ``tf.RunOptions`` at ``self.trace_level`` when
            ``self._trace`` is truthy, otherwise ``None``.
        """
        run_options = (tf.RunOptions(trace_level=self.trace_level)
                       if self._trace else None)
        return tf.train.SessionRunArgs(fetches=self._global_step_tensor,
                                       options=run_options)
Example #19
Source File: early_stopping.py    From estimator with Apache License 2.0 5 votes vote down vote up
def before_run(self, run_context):
    """Request the global step and the stop variable on every run.

    Args:
      run_context: Unused; discarded immediately.

    Returns:
      A `SessionRunArgs` over a dict with keys 'global_step' and 'stop_var'.
    """
    del run_context
    fetches = dict(
        global_step=self._global_step_tensor,
        stop_var=self._stop_var,
    )
    return session_run_hook.SessionRunArgs(fetches)
Example #20
Source File: tpu_estimator.py    From transformer-xl with Apache License 2.0 5 votes vote down vote up
def before_run(self, run_context):
    """Add `self._tensors` to the fetches of the upcoming run.

    Args:
      run_context: Unused.

    Returns:
      A `SessionRunArgs` whose fetches are `self._tensors`.
    """
    del run_context  # Not needed to build the request.
    fetches = self._tensors
    return basic_session_run_hooks.SessionRunArgs(fetches)
Example #21
Source File: tpu_estimator.py    From transformer-xl with Apache License 2.0 5 votes vote down vote up
def before_run(self, run_context):
    """Request the scalar stopping signal on every run.

    Args:
      run_context: Unused.

    Returns:
      A `SessionRunArgs` fetching `self._scalar_stopping_signal`.
    """
    del run_context  # Not needed to build the request.
    stopping_signal = self._scalar_stopping_signal
    return session_run_hook.SessionRunArgs(stopping_signal)
Example #22
Source File: tpu_estimator.py    From transformer-xl with Apache License 2.0 5 votes vote down vote up
def after_run(self, run_context, run_values):
    """Stop at the last step, otherwise size the next loop of iterations.

    The global step is read with a dedicated `session.run` call rather than
    requested through `SessionRunArgs` in `before_run`, which would be subject
    to a race condition. Once it reaches `self._last_step` a stop is
    requested; otherwise the iteration count for the next loop is computed by
    `self._next_iterations` and loaded into `self._iterations_per_loop_var`.

    Args:
      run_context: `SessionRunContext` providing the session and stop request.
      run_values: Unused.
    """
    session = run_context.session
    current_step = session.run(self._global_step_tensor)
    if current_step >= self._last_step:
      run_context.request_stop()
      return
    iterations_this_loop = self._next_iterations(current_step, self._last_step)
    self._iterations_per_loop_var.load(iterations_this_loop, session=session)
Example #23
Source File: tpu_estimator.py    From embedding-as-service with MIT License 5 votes vote down vote up
def before_run(self, run_context):
    """Add `self._tensors` to the fetches of the upcoming run.

    Args:
      run_context: Unused.

    Returns:
      A `SessionRunArgs` whose fetches are `self._tensors`.
    """
    del run_context  # Not needed to build the request.
    fetches = self._tensors
    return basic_session_run_hooks.SessionRunArgs(fetches)
Example #24
Source File: tpu_estimator.py    From embedding-as-service with MIT License 5 votes vote down vote up
def before_run(self, run_context):
    """Request the scalar stopping signal on every run.

    Args:
      run_context: Unused.

    Returns:
      A `SessionRunArgs` fetching `self._scalar_stopping_signal`.
    """
    del run_context  # Not needed to build the request.
    stopping_signal = self._scalar_stopping_signal
    return session_run_hook.SessionRunArgs(stopping_signal)
Example #25
Source File: tpu_estimator.py    From embedding-as-service with MIT License 5 votes vote down vote up
def after_run(self, run_context, run_values):
    """Stop at the last step, otherwise size the next loop of iterations.

    The global step is read with a dedicated `session.run` call rather than
    requested through `SessionRunArgs` in `before_run`, which would be subject
    to a race condition. Once it reaches `self._last_step` a stop is
    requested; otherwise the iteration count for the next loop is computed by
    `self._next_iterations` and loaded into `self._iterations_per_loop_var`.

    Args:
      run_context: `SessionRunContext` providing the session and stop request.
      run_values: Unused.
    """
    session = run_context.session
    current_step = session.run(self._global_step_tensor)
    if current_step >= self._last_step:
      run_context.request_stop()
      return
    iterations_this_loop = self._next_iterations(current_step, self._last_step)
    self._iterations_per_loop_var.load(iterations_this_loop, session=session)
Example #26
Source File: random_forest.py    From deep_image_model with Apache License 2.0 5 votes vote down vote up
def before_run(self, run_context):
    """Request the global step and the current loss for this run.

    The loss tensor is the first output of the operation named `LOSS_NAME`
    in the session's graph.

    Returns:
      A `SessionRunArgs` over a dict with keys 'global_step' and
      'current_loss'.
    """
    global_step = contrib_framework.get_global_step()
    loss_op = run_context.session.graph.get_operation_by_name(LOSS_NAME)
    return session_run_hook.SessionRunArgs(
        {'global_step': global_step, 'current_loss': loss_op.outputs[0]})
Example #27
Source File: linear.py    From deep_image_model with Apache License 2.0 5 votes vote down vote up
def before_run(self, run_context):
    """Schedule `self._update_op` (the weight update) to execute in this run.

    Args:
      run_context: Unused.

    Returns:
      A `SessionRunArgs` whose fetch is `self._update_op`.
    """
    update_op = self._update_op
    return session_run_hook.SessionRunArgs(update_op)
Example #28
Source File: basic_session_run_hooks.py    From deep_image_model with Apache License 2.0 5 votes vote down vote up
def before_run(self, run_context):  # pylint: disable=unused-argument
    """Request the global step and, when due, a summary op.

    A summary is due on the first call (`self._next_step` is None) or when
    the timer triggers for `self._next_step`; the decision is stored in
    `self._request_summary`. The summary op is `self._summary_op` if set,
    falling back to `self._scaffold.summary_op`.

    Returns:
      `SessionRunArgs` over a dict keyed "global_step" and, optionally,
      "summary".
    """
    first_call = self._next_step is None
    self._request_summary = (
        first_call or self._timer.should_trigger_for_step(self._next_step))
    requests = {"global_step": self._global_step_tensor}
    if not self._request_summary:
      return SessionRunArgs(requests)

    summary_op = self._summary_op
    if summary_op is None:
      summary_op = self._scaffold.summary_op
    if summary_op is not None:
      requests["summary"] = summary_op
    return SessionRunArgs(requests)
Example #29
Source File: basic_session_run_hooks.py    From deep_image_model with Apache License 2.0 5 votes vote down vote up
def before_run(self, run_context):  # pylint: disable=unused-argument
    """Ask the upcoming run to also fetch the global step.

    Args:
      run_context: Unused.

    Returns:
      A `SessionRunArgs` whose only fetch is `self._global_step_tensor`.
    """
    step_tensor = self._global_step_tensor
    return SessionRunArgs(step_tensor)
Example #30
Source File: basic_session_run_hooks.py    From deep_image_model with Apache License 2.0 5 votes vote down vote up
def before_run(self, run_context):  # pylint: disable=unused-argument
    """Ask the upcoming run to also fetch the global step.

    Args:
      run_context: Unused.

    Returns:
      A `SessionRunArgs` whose only fetch is `self._global_step_tensor`.
    """
    step_tensor = self._global_step_tensor
    return SessionRunArgs(step_tensor)