Python tensorflow.python.training.session_run_hook.SessionRunHook() Examples
The following are 21 code examples of tensorflow.python.training.session_run_hook.SessionRunHook(), drawn from open-source projects; the source file, project, and license are noted above each example.
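Before the examples, it helps to see the hook lifecycle in one place. The sketch below is a minimal, illustrative SessionRunHook, not taken from any project on this page (the class name, loss tensor, and divergence threshold are invented; TensorFlow 1.x graph execution is assumed): begin runs before the graph is finalized, before_run can request extra fetches, after_run sees their results, and end fires when the session closes.

import tensorflow as tf
from tensorflow.python.training import session_run_hook

class LossWatcherHook(session_run_hook.SessionRunHook):
  """Illustrative hook: fetches a loss tensor each step, stops on divergence."""

  def __init__(self, loss_tensor):
    self._loss_tensor = loss_tensor

  def begin(self):
    # Called once before the graph is finalized; last chance to add ops.
    pass

  def before_run(self, run_context):
    # Ask the monitored session to fetch the loss alongside the main fetches.
    return session_run_hook.SessionRunArgs(self._loss_tensor)

  def after_run(self, run_context, run_values):
    # run_values.results holds whatever before_run requested.
    if run_values.results > 1e6:
      run_context.request_stop()

  def end(self, session):
    # Called once when the session is about to close.
    pass

Hooked into a monitored session (loss and train_op assumed built elsewhere):

with tf.train.MonitoredTrainingSession(hooks=[LossWatcherHook(loss)]) as sess:
  while not sess.should_stop():
    sess.run(train_op)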
Example #1
Source File: tpu_estimator.py From Chinese-XLNet with Apache License 2.0
def after_run(self, run_context, run_values):
  _ = run_context
  scalar_stopping_signal = run_values.results
  if _StopSignals.should_stop(scalar_stopping_signal):
    # NOTE(xiejw): In prediction, stopping signals are inserted for each
    # batch. And we append one more batch to signal the system it should
    # stop. The data flow might look like
    #
    #  batch   0: images, labels, stop = 0  (user provided)
    #  batch   1: images, labels, stop = 0  (user provided)
    #  ...
    #  batch  99: images, labels, stop = 0  (user provided)
    #  batch 100: images, labels, stop = 1  (TPUEstimator appended)
    #
    # where the final batch (id = 100) is appended by TPUEstimator, so we
    # should drop it before returning the predictions to user.
    # To achieve that, we throw the OutOfRangeError in after_run. Once
    # Monitored Session sees this error in SessionRunHook.after_run, the
    # "current" prediction, i.e., batch with id=100, will be discarded
    # immediately.
    raise errors.OutOfRangeError(None, None, 'Stopped by stopping signal.')
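Raising OutOfRangeError here is deliberate precisely because it discards the in-flight batch. A hook that wants to stop without losing the current step's results can instead call request_stop() on the run context. A minimal sketch for contrast (this step-counting hook is invented for illustration, not part of tpu_estimator.py):

from tensorflow.python.training import session_run_hook

class StopAfterNStepsHook(session_run_hook.SessionRunHook):
  """Illustrative: requests a clean stop after a fixed number of steps."""

  def __init__(self, num_steps):
    self._num_steps = num_steps
    self._steps_done = 0

  def after_run(self, run_context, run_values):
    self._steps_done += 1
    if self._steps_done >= self._num_steps:
      # Unlike raising OutOfRangeError, request_stop() lets the results of
      # the step that just ran reach the caller before the loop exits.
      run_context.request_stop()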
Example #2
Source File: linear.py From deep_image_model with Apache License 2.0
def fit(self, x=None, y=None, input_fn=None, steps=None, batch_size=None,
        monitors=None, max_steps=None):
  """See trainable.Trainable."""
  # TODO(roumposg): Remove when deprecated monitors are removed.
  if monitors is None:
    monitors = []
  deprecated_monitors = [
      m for m in monitors
      if not isinstance(m, session_run_hook.SessionRunHook)
  ]
  for monitor in deprecated_monitors:
    monitor.set_estimator(self)
    monitor._lock_estimator()  # pylint: disable=protected-access

  if self._additional_run_hook:
    monitors.append(self._additional_run_hook)
  result = self._estimator.fit(x=x, y=y, input_fn=input_fn, steps=steps,
                               batch_size=batch_size, monitors=monitors,
                               max_steps=max_steps)

  for monitor in deprecated_monitors:
    monitor._unlock_estimator()  # pylint: disable=protected-access

  return result
Example #3
Source File: tpu_estimator.py From xlnet with Apache License 2.0
def after_run(self, run_context, run_values):
  _ = run_context
  scalar_stopping_signal = run_values.results
  if _StopSignals.should_stop(scalar_stopping_signal):
    # NOTE(xiejw): In prediction, stopping signals are inserted for each
    # batch. And we append one more batch to signal the system it should
    # stop. The data flow might look like
    #
    #  batch   0: images, labels, stop = 0  (user provided)
    #  batch   1: images, labels, stop = 0  (user provided)
    #  ...
    #  batch  99: images, labels, stop = 0  (user provided)
    #  batch 100: images, labels, stop = 1  (TPUEstimator appended)
    #
    # where the final batch (id = 100) is appended by TPUEstimator, so we
    # should drop it before returning the predictions to user.
    # To achieve that, we throw the OutOfRangeError in after_run. Once
    # Monitored Session sees this error in SessionRunHook.after_run, the
    # "current" prediction, i.e., batch with id=100, will be discarded
    # immediately.
    raise errors.OutOfRangeError(None, None, 'Stopped by stopping signal.')
Example #4
Source File: tpu_estimator.py From embedding-as-service with MIT License
def after_run(self, run_context, run_values):
  _ = run_context
  scalar_stopping_signal = run_values.results
  if _StopSignals.should_stop(scalar_stopping_signal):
    # NOTE(xiejw): In prediction, stopping signals are inserted for each
    # batch. And we append one more batch to signal the system it should
    # stop. The data flow might look like
    #
    #  batch   0: images, labels, stop = 0  (user provided)
    #  batch   1: images, labels, stop = 0  (user provided)
    #  ...
    #  batch  99: images, labels, stop = 0  (user provided)
    #  batch 100: images, labels, stop = 1  (TPUEstimator appended)
    #
    # where the final batch (id = 100) is appended by TPUEstimator, so we
    # should drop it before returning the predictions to user.
    # To achieve that, we throw the OutOfRangeError in after_run. Once
    # Monitored Session sees this error in SessionRunHook.after_run, the
    # "current" prediction, i.e., batch with id=100, will be discarded
    # immediately.
    raise errors.OutOfRangeError(None, None, 'Stopped by stopping signal.')
Example #5
Source File: tpu_estimator.py From transformer-xl with Apache License 2.0
def after_run(self, run_context, run_values):
  _ = run_context
  scalar_stopping_signal = run_values.results
  if _StopSignals.should_stop(scalar_stopping_signal):
    # NOTE(xiejw): In prediction, stopping signals are inserted for each
    # batch. And we append one more batch to signal the system it should
    # stop. The data flow might look like
    #
    #  batch   0: images, labels, stop = 0  (user provided)
    #  batch   1: images, labels, stop = 0  (user provided)
    #  ...
    #  batch  99: images, labels, stop = 0  (user provided)
    #  batch 100: images, labels, stop = 1  (TPUEstimator appended)
    #
    # where the final batch (id = 100) is appended by TPUEstimator, so we
    # should drop it before returning the predictions to user.
    # To achieve that, we throw the OutOfRangeError in after_run. Once
    # Monitored Session sees this error in SessionRunHook.after_run, the
    # "current" prediction, i.e., batch with id=100, will be discarded
    # immediately.
    raise errors.OutOfRangeError(None, None, 'Stopped by stopping signal.')
Example #6
Source File: training.py From Serverless-Deep-Learning-with-TensorFlow-and-AWS-Lambda with MIT License
def __new__(cls, input_fn, max_steps=None, hooks=None):
  """Creates a validated `TrainSpec` instance.

  Args:
    input_fn: Training input function returning a tuple of:
        features - `Tensor` or dictionary of string feature name to `Tensor`.
        labels - `Tensor` or dictionary of `Tensor` with labels.
    max_steps: Int. Positive number of total steps for which to train model.
      If `None`, train forever. The training `input_fn` is not expected to
      generate `OutOfRangeError` or `StopIteration` exceptions. See the
      `train_and_evaluate` stop condition section for details.
    hooks: Iterable of `tf.train.SessionRunHook` objects to run on all
      workers (including chief) during training.

  Returns:
    A validated `TrainSpec` object.

  Raises:
    ValueError: If any of the input arguments is invalid.
    TypeError: If any of the arguments is not of the expected type.
  """
  # Validate input_fn.
  _validate_input_fn(input_fn)

  # Validate max_steps.
  if max_steps is not None and max_steps <= 0:
    raise ValueError(
        'Must specify max_steps > 0, given: {}'.format(max_steps))

  # Validate hooks.
  hooks = _validate_hooks(hooks)

  return super(TrainSpec, cls).__new__(
      cls, input_fn=input_fn, max_steps=max_steps, hooks=hooks)
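A hedged usage sketch of the validation above (train_input_fn is a placeholder; tf.train.ProfilerHook is just one stock SessionRunHook):

# Passes validation: a callable input_fn, positive max_steps, real hooks.
train_spec = tf.estimator.TrainSpec(
    input_fn=train_input_fn,
    max_steps=10000,
    hooks=[tf.train.ProfilerHook(save_steps=1000, output_dir='/tmp/profile')])

# Fails validation, per the checks above:
# tf.estimator.TrainSpec(input_fn=train_input_fn, max_steps=0)  # ValueError
# tf.estimator.TrainSpec(input_fn=train_input_fn, hooks=['x'])  # TypeError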
Example #7
Source File: training.py From Serverless-Deep-Learning-with-TensorFlow-and-AWS-Lambda with MIT License
def _validate_hooks(hooks):
  """Validates the `hooks`."""
  hooks = tuple(hooks or [])
  for hook in hooks:
    if not isinstance(hook, session_run_hook.SessionRunHook):
      raise TypeError(
          'All hooks must be `SessionRunHook` instances, given: {}'.format(
              hook))
  return hooks
Example #8
Source File: tpu_estimator.py From xlnet with Apache License 2.0
def should_stop(scalar_stopping_signal):
  """Detects whether scalar_stopping_signal indicates stopping."""
  if isinstance(scalar_stopping_signal, ops.Tensor):
    # STOPPING_SIGNAL is a constant True. Here, the logical_and is just the
    # TF way to express the bool check whether scalar_stopping_signal is True.
    return math_ops.logical_and(scalar_stopping_signal,
                                _StopSignals.STOPPING_SIGNAL)
  else:
    # For non Tensor case, it is used in SessionRunHook. So, we cannot modify
    # the graph anymore. Here, we use pure Python.
    return bool(scalar_stopping_signal)
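The branch on ops.Tensor matters because the same helper is called both while building the graph and from inside SessionRunHook.after_run, where the signal has already been fetched to a concrete value and the graph is finalized. The same dispatch pattern in isolation (names are invented; TF 1.x assumed):

import tensorflow as tf

def should_stop(signal):
  """Accepts either a graph Tensor or an already-fetched concrete value."""
  if isinstance(signal, tf.Tensor):
    # Graph mode: produce a bool Tensor for downstream control flow.
    return tf.logical_and(signal, tf.constant(True))
  # Hook mode: the graph is finalized, so stay in pure Python.
  return bool(signal)

graph_signal = should_stop(tf.constant(True))  # -> a bool Tensor
hook_signal = should_stop(1)                   # -> True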
Example #9
Source File: tpu_estimator.py From transformer-xl with Apache License 2.0
def should_stop(scalar_stopping_signal):
  """Detects whether scalar_stopping_signal indicates stopping."""
  if isinstance(scalar_stopping_signal, ops.Tensor):
    # STOPPING_SIGNAL is a constant True. Here, the logical_and is just the
    # TF way to express the bool check whether scalar_stopping_signal is True.
    return math_ops.logical_and(scalar_stopping_signal,
                                _StopSignals.STOPPING_SIGNAL)
  else:
    # For non Tensor case, it is used in SessionRunHook. So, we cannot modify
    # the graph anymore. Here, we use pure Python.
    return bool(scalar_stopping_signal)
Example #10
Source File: tpu_estimator.py From transformer-xl with Apache License 2.0
def dataset_initializer_hook(self):
  """Returns a `SessionRunHook` to initialize this dataset.

  This must be called before `features_and_labels`.
  """
  iterator = self._dataset.make_initializable_iterator()
  # pylint: disable=protected-access
  hook = estimator_util._DatasetInitializerHook(iterator)
  # pylint: enable=protected-access
  self._iterator = iterator
  return hook
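_DatasetInitializerHook is private to the estimator package; a hand-rolled equivalent is roughly the sketch below (an approximation, not the library's actual code): capture the initializer op before the graph is finalized, then run it as soon as the session exists.

from tensorflow.python.training import session_run_hook

class DatasetInitializerHook(session_run_hook.SessionRunHook):
  """Illustrative: runs an iterator's initializer once the session exists."""

  def __init__(self, iterator):
    self._iterator = iterator
    self._initializer = None

  def begin(self):
    # Capture the initializer op before the graph is finalized.
    self._initializer = self._iterator.initializer

  def after_create_session(self, session, coord):
    del coord  # unused
    session.run(self._initializer)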
Example #11
Source File: tpu_estimator.py From embedding-as-service with MIT License
def should_stop(scalar_stopping_signal):
  """Detects whether scalar_stopping_signal indicates stopping."""
  if isinstance(scalar_stopping_signal, ops.Tensor):
    # STOPPING_SIGNAL is a constant True. Here, the logical_and is just the
    # TF way to express the bool check whether scalar_stopping_signal is True.
    return math_ops.logical_and(scalar_stopping_signal,
                                _StopSignals.STOPPING_SIGNAL)
  else:
    # For non Tensor case, it is used in SessionRunHook. So, we cannot modify
    # the graph anymore. Here, we use pure Python.
    return bool(scalar_stopping_signal)
Example #12
Source File: tpu_estimator.py From Chinese-XLNet with Apache License 2.0
def should_stop(scalar_stopping_signal):
  """Detects whether scalar_stopping_signal indicates stopping."""
  if isinstance(scalar_stopping_signal, ops.Tensor):
    # STOPPING_SIGNAL is a constant True. Here, the logical_and is just the
    # TF way to express the bool check whether scalar_stopping_signal is True.
    return math_ops.logical_and(scalar_stopping_signal,
                                _StopSignals.STOPPING_SIGNAL)
  else:
    # For non Tensor case, it is used in SessionRunHook. So, we cannot modify
    # the graph anymore. Here, we use pure Python.
    return bool(scalar_stopping_signal)
Example #13
Source File: tpu_estimator.py From transformer-xl with Apache License 2.0
def __new__(cls,
            mode,
            predictions=None,
            loss=None,
            train_op=None,
            eval_metrics=None,
            export_outputs=None,
            scaffold_fn=None,
            host_call=None,
            training_hooks=None,
            evaluation_hooks=None,
            prediction_hooks=None):
  """Creates a validated `TPUEstimatorSpec` instance."""
  host_calls = {}
  if eval_metrics is not None:
    host_calls['eval_metrics'] = eval_metrics
  if host_call is not None:
    host_calls['host_call'] = host_call
  _OutfeedHostCall.validate(host_calls)

  training_hooks = list(training_hooks or [])
  evaluation_hooks = list(evaluation_hooks or [])
  prediction_hooks = list(prediction_hooks or [])

  for hook in training_hooks + evaluation_hooks + prediction_hooks:
    if not isinstance(hook, session_run_hook.SessionRunHook):
      raise TypeError(
          'All hooks must be SessionRunHook instances, given: {}'.format(
              hook))

  return super(TPUEstimatorSpec, cls).__new__(
      cls,
      mode=mode,
      predictions=predictions,
      loss=loss,
      train_op=train_op,
      eval_metrics=eval_metrics,
      export_outputs=export_outputs,
      scaffold_fn=scaffold_fn,
      host_call=host_call,
      training_hooks=training_hooks,
      evaluation_hooks=evaluation_hooks,
      prediction_hooks=prediction_hooks)
Example #14
Source File: tpu_estimator.py From xlnet with Apache License 2.0
def __new__(cls,
            mode,
            predictions=None,
            loss=None,
            train_op=None,
            eval_metrics=None,
            export_outputs=None,
            scaffold_fn=None,
            host_call=None,
            training_hooks=None,
            evaluation_hooks=None,
            prediction_hooks=None):
  """Creates a validated `TPUEstimatorSpec` instance."""
  host_calls = {}
  if eval_metrics is not None:
    host_calls['eval_metrics'] = eval_metrics
  if host_call is not None:
    host_calls['host_call'] = host_call
  _OutfeedHostCall.validate(host_calls)

  training_hooks = tuple(training_hooks or [])
  evaluation_hooks = tuple(evaluation_hooks or [])
  prediction_hooks = tuple(prediction_hooks or [])

  for hook in training_hooks + evaluation_hooks + prediction_hooks:
    if not isinstance(hook, session_run_hook.SessionRunHook):
      raise TypeError('All hooks must be SessionRunHook instances, given: {}'
                      .format(hook))

  return super(TPUEstimatorSpec, cls).__new__(
      cls,
      mode=mode,
      predictions=predictions,
      loss=loss,
      train_op=train_op,
      eval_metrics=eval_metrics,
      export_outputs=export_outputs,
      scaffold_fn=scaffold_fn,
      host_call=host_call,
      training_hooks=training_hooks,
      evaluation_hooks=evaluation_hooks,
      prediction_hooks=prediction_hooks)
Example #15
Source File: tpu_estimator.py From embedding-as-service with MIT License
def __new__(cls,
            mode,
            predictions=None,
            loss=None,
            train_op=None,
            eval_metrics=None,
            export_outputs=None,
            scaffold_fn=None,
            host_call=None,
            training_hooks=None,
            evaluation_hooks=None,
            prediction_hooks=None):
  """Creates a validated `TPUEstimatorSpec` instance."""
  host_calls = {}
  if eval_metrics is not None:
    host_calls['eval_metrics'] = eval_metrics
  if host_call is not None:
    host_calls['host_call'] = host_call
  _OutfeedHostCall.validate(host_calls)

  training_hooks = tuple(training_hooks or [])
  evaluation_hooks = tuple(evaluation_hooks or [])
  prediction_hooks = tuple(prediction_hooks or [])

  for hook in training_hooks + evaluation_hooks + prediction_hooks:
    if not isinstance(hook, session_run_hook.SessionRunHook):
      raise TypeError('All hooks must be SessionRunHook instances, given: {}'
                      .format(hook))

  return super(TPUEstimatorSpec, cls).__new__(
      cls,
      mode=mode,
      predictions=predictions,
      loss=loss,
      train_op=train_op,
      eval_metrics=eval_metrics,
      export_outputs=export_outputs,
      scaffold_fn=scaffold_fn,
      host_call=host_call,
      training_hooks=training_hooks,
      evaluation_hooks=evaluation_hooks,
      prediction_hooks=prediction_hooks)
Example #16
Source File: evaluation_test.py From tf-slim with Apache License 2.0
def testFinalOpsOnEvaluationLoop(self):
  value_op, update_op = metrics.accuracy(
      labels=self._labels, predictions=self._predictions)
  init_op = control_flow_ops.group(variables.global_variables_initializer(),
                                   variables.local_variables_initializer())
  # Create checkpoint and log directories:
  chkpt_dir = tempfile.mkdtemp('tmp_logs')
  logdir = tempfile.mkdtemp('tmp_logs2')

  # Save initialized variables to a checkpoint directory:
  saver = saver_lib.Saver()
  with self.cached_session() as sess:
    init_op.run()
    saver.save(sess, os.path.join(chkpt_dir, 'chkpt'))

  class Object(object):

    def __init__(self):
      self.hook_was_run = False

  obj = Object()

  # Create a custom session run hook.
  class CustomHook(session_run_hook.SessionRunHook):

    def __init__(self, obj):
      self.obj = obj

    def end(self, session):
      self.obj.hook_was_run = True

  # Now, run the evaluation loop:
  accuracy_value = evaluation.evaluation_loop(
      '',
      chkpt_dir,
      logdir,
      eval_op=update_op,
      final_op=value_op,
      hooks=[CustomHook(obj)],
      max_number_of_evaluations=1)
  self.assertAlmostEqual(accuracy_value, self._expected_accuracy)

  # Validate that the custom hook ran.
  self.assertTrue(obj.hook_was_run)
Example #17
Source File: early_stopping.py From estimator with Apache License 2.0
def stop_if_higher_hook(estimator,
                        metric_name,
                        threshold,
                        eval_dir=None,
                        min_steps=0,
                        run_every_secs=60,
                        run_every_steps=None):
  """Creates hook to stop if the given metric is higher than the threshold.

  Usage example:

  ```python
  estimator = ...
  # Hook to stop training if accuracy becomes higher than 0.9.
  hook = early_stopping.stop_if_higher_hook(estimator, "accuracy", 0.9)
  train_spec = tf.estimator.TrainSpec(..., hooks=[hook])
  tf.estimator.train_and_evaluate(estimator, train_spec, ...)
  ```

  Caveat: Current implementation supports early-stopping both training and
  evaluation in local mode. In distributed mode, training can be stopped but
  evaluation (where it's a separate job) will indefinitely wait for new model
  checkpoints to evaluate, so you will need other means to detect and stop
  it. Early-stopping evaluation in distributed mode requires changes in
  `train_and_evaluate` API and will be addressed in a future revision.

  Args:
    estimator: A `tf.estimator.Estimator` instance.
    metric_name: `str`, metric to track. "loss", "accuracy", etc.
    threshold: Numeric threshold for the given metric.
    eval_dir: If set, directory containing summary files with eval metrics.
      By default, `estimator.eval_dir()` will be used.
    min_steps: `int`, stop is never requested if global step is less than
      this value. Defaults to 0.
    run_every_secs: If specified, calls `should_stop_fn` at an interval of
      `run_every_secs` seconds. Defaults to 60 seconds. Either this or
      `run_every_steps` must be set.
    run_every_steps: If specified, calls `should_stop_fn` every
      `run_every_steps` steps. Either this or `run_every_secs` must be set.

  Returns:
    An early-stopping hook of type `SessionRunHook` that periodically checks
    if the given metric is higher than specified threshold and initiates
    early stopping if true.
  """
  return _stop_if_threshold_crossed_hook(
      estimator=estimator,
      metric_name=metric_name,
      threshold=threshold,
      higher_is_better=True,
      eval_dir=eval_dir,
      min_steps=min_steps,
      run_every_secs=run_every_secs,
      run_every_steps=run_every_steps)
Example #18
Source File: early_stopping.py From estimator with Apache License 2.0
def stop_if_lower_hook(estimator,
                       metric_name,
                       threshold,
                       eval_dir=None,
                       min_steps=0,
                       run_every_secs=60,
                       run_every_steps=None):
  """Creates hook to stop if the given metric is lower than the threshold.

  Usage example:

  ```python
  estimator = ...
  # Hook to stop training if loss becomes lower than 100.
  hook = early_stopping.stop_if_lower_hook(estimator, "loss", 100)
  train_spec = tf.estimator.TrainSpec(..., hooks=[hook])
  tf.estimator.train_and_evaluate(estimator, train_spec, ...)
  ```

  Caveat: Current implementation supports early-stopping both training and
  evaluation in local mode. In distributed mode, training can be stopped but
  evaluation (where it's a separate job) will indefinitely wait for new model
  checkpoints to evaluate, so you will need other means to detect and stop
  it. Early-stopping evaluation in distributed mode requires changes in
  `train_and_evaluate` API and will be addressed in a future revision.

  Args:
    estimator: A `tf.estimator.Estimator` instance.
    metric_name: `str`, metric to track. "loss", "accuracy", etc.
    threshold: Numeric threshold for the given metric.
    eval_dir: If set, directory containing summary files with eval metrics.
      By default, `estimator.eval_dir()` will be used.
    min_steps: `int`, stop is never requested if global step is less than
      this value. Defaults to 0.
    run_every_secs: If specified, calls `should_stop_fn` at an interval of
      `run_every_secs` seconds. Defaults to 60 seconds. Either this or
      `run_every_steps` must be set.
    run_every_steps: If specified, calls `should_stop_fn` every
      `run_every_steps` steps. Either this or `run_every_secs` must be set.

  Returns:
    An early-stopping hook of type `SessionRunHook` that periodically checks
    if the given metric is lower than specified threshold and initiates
    early stopping if true.
  """
  return _stop_if_threshold_crossed_hook(
      estimator=estimator,
      metric_name=metric_name,
      threshold=threshold,
      higher_is_better=False,
      eval_dir=eval_dir,
      min_steps=min_steps,
      run_every_secs=run_every_secs,
      run_every_steps=run_every_steps)
Example #19
Source File: early_stopping.py From estimator with Apache License 2.0
def stop_if_no_increase_hook(estimator,
                             metric_name,
                             max_steps_without_increase,
                             eval_dir=None,
                             min_steps=0,
                             run_every_secs=60,
                             run_every_steps=None):
  """Creates hook to stop if metric does not increase within given max steps.

  Usage example:

  ```python
  estimator = ...
  # Hook to stop training if accuracy does not increase in over 100000 steps.
  hook = early_stopping.stop_if_no_increase_hook(estimator, "accuracy",
                                                 100000)
  train_spec = tf.estimator.TrainSpec(..., hooks=[hook])
  tf.estimator.train_and_evaluate(estimator, train_spec, ...)
  ```

  Caveat: Current implementation supports early-stopping both training and
  evaluation in local mode. In distributed mode, training can be stopped but
  evaluation (where it's a separate job) will indefinitely wait for new model
  checkpoints to evaluate, so you will need other means to detect and stop
  it. Early-stopping evaluation in distributed mode requires changes in
  `train_and_evaluate` API and will be addressed in a future revision.

  Args:
    estimator: A `tf.estimator.Estimator` instance.
    metric_name: `str`, metric to track. "loss", "accuracy", etc.
    max_steps_without_increase: `int`, maximum number of training steps with
      no increase in the given metric.
    eval_dir: If set, directory containing summary files with eval metrics.
      By default, `estimator.eval_dir()` will be used.
    min_steps: `int`, stop is never requested if global step is less than
      this value. Defaults to 0.
    run_every_secs: If specified, calls `should_stop_fn` at an interval of
      `run_every_secs` seconds. Defaults to 60 seconds. Either this or
      `run_every_steps` must be set.
    run_every_steps: If specified, calls `should_stop_fn` every
      `run_every_steps` steps. Either this or `run_every_secs` must be set.

  Returns:
    An early-stopping hook of type `SessionRunHook` that periodically checks
    if the given metric shows no increase over given maximum number of
    training steps, and initiates early stopping if true.
  """
  return _stop_if_no_metric_improvement_hook(
      estimator=estimator,
      metric_name=metric_name,
      max_steps_without_improvement=max_steps_without_increase,
      higher_is_better=True,
      eval_dir=eval_dir,
      min_steps=min_steps,
      run_every_secs=run_every_secs,
      run_every_steps=run_every_steps)
Example #20
Source File: early_stopping.py From estimator with Apache License 2.0
def stop_if_no_decrease_hook(estimator,
                             metric_name,
                             max_steps_without_decrease,
                             eval_dir=None,
                             min_steps=0,
                             run_every_secs=60,
                             run_every_steps=None):
  """Creates hook to stop if metric does not decrease within given max steps.

  Usage example:

  ```python
  estimator = ...
  # Hook to stop training if loss does not decrease in over 100000 steps.
  hook = early_stopping.stop_if_no_decrease_hook(estimator, "loss", 100000)
  train_spec = tf.estimator.TrainSpec(..., hooks=[hook])
  tf.estimator.train_and_evaluate(estimator, train_spec, ...)
  ```

  Caveat: Current implementation supports early-stopping both training and
  evaluation in local mode. In distributed mode, training can be stopped but
  evaluation (where it's a separate job) will indefinitely wait for new model
  checkpoints to evaluate, so you will need other means to detect and stop
  it. Early-stopping evaluation in distributed mode requires changes in
  `train_and_evaluate` API and will be addressed in a future revision.

  Args:
    estimator: A `tf.estimator.Estimator` instance.
    metric_name: `str`, metric to track. "loss", "accuracy", etc.
    max_steps_without_decrease: `int`, maximum number of training steps with
      no decrease in the given metric.
    eval_dir: If set, directory containing summary files with eval metrics.
      By default, `estimator.eval_dir()` will be used.
    min_steps: `int`, stop is never requested if global step is less than
      this value. Defaults to 0.
    run_every_secs: If specified, calls `should_stop_fn` at an interval of
      `run_every_secs` seconds. Defaults to 60 seconds. Either this or
      `run_every_steps` must be set.
    run_every_steps: If specified, calls `should_stop_fn` every
      `run_every_steps` steps. Either this or `run_every_secs` must be set.

  Returns:
    An early-stopping hook of type `SessionRunHook` that periodically checks
    if the given metric shows no decrease over given maximum number of
    training steps, and initiates early stopping if true.
  """
  return _stop_if_no_metric_improvement_hook(
      estimator=estimator,
      metric_name=metric_name,
      max_steps_without_improvement=max_steps_without_decrease,
      higher_is_better=False,
      eval_dir=eval_dir,
      min_steps=min_steps,
      run_every_secs=run_every_secs,
      run_every_steps=run_every_steps)
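All four wrappers above funnel into the module's shared threshold and no-improvement helpers. Recent 1.x releases also expose a generic entry point, tf.estimator.experimental.make_early_stopping_hook, for arbitrary predicates. A hedged sketch with a wall-clock budget (estimator, train_input_fn, and eval_input_fn are placeholders):

import time
import tensorflow as tf

deadline = time.time() + 3600  # stop training after at most one hour

hook = tf.estimator.experimental.make_early_stopping_hook(
    estimator,
    should_stop_fn=lambda: time.time() > deadline,
    run_every_secs=60)

tf.estimator.train_and_evaluate(
    estimator,
    tf.estimator.TrainSpec(input_fn=train_input_fn, hooks=[hook]),
    tf.estimator.EvalSpec(input_fn=eval_input_fn))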
Example #21
Source File: tpu_estimator.py From Chinese-XLNet with Apache License 2.0
def __new__(cls,
            mode,
            predictions=None,
            loss=None,
            train_op=None,
            eval_metrics=None,
            export_outputs=None,
            scaffold_fn=None,
            host_call=None,
            training_hooks=None,
            evaluation_hooks=None,
            prediction_hooks=None):
  """Creates a validated `TPUEstimatorSpec` instance."""
  host_calls = {}
  if eval_metrics is not None:
    host_calls['eval_metrics'] = eval_metrics
  if host_call is not None:
    host_calls['host_call'] = host_call
  _OutfeedHostCall.validate(host_calls)

  training_hooks = tuple(training_hooks or [])
  evaluation_hooks = tuple(evaluation_hooks or [])
  prediction_hooks = tuple(prediction_hooks or [])

  for hook in training_hooks + evaluation_hooks + prediction_hooks:
    if not isinstance(hook, session_run_hook.SessionRunHook):
      raise TypeError('All hooks must be SessionRunHook instances, given: {}'
                      .format(hook))

  return super(TPUEstimatorSpec, cls).__new__(
      cls,
      mode=mode,
      predictions=predictions,
      loss=loss,
      train_op=train_op,
      eval_metrics=eval_metrics,
      export_outputs=export_outputs,
      scaffold_fn=scaffold_fn,
      host_call=host_call,
      training_hooks=training_hooks,
      evaluation_hooks=evaluation_hooks,
      prediction_hooks=prediction_hooks)
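To show where the hook validation above fires in practice, here is a rough model_fn sketch using the TF 1.x contrib TPU API (the layer size, optimizer, and feature key are arbitrary choices, not from the examples above):

import tensorflow as tf

def model_fn(features, labels, mode, params):
  logits = tf.layers.dense(features['x'], 10)
  loss = tf.losses.sparse_softmax_cross_entropy(labels=labels, logits=logits)
  optimizer = tf.contrib.tpu.CrossShardOptimizer(
      tf.train.GradientDescentOptimizer(learning_rate=0.01))
  train_op = optimizer.minimize(loss, global_step=tf.train.get_global_step())
  return tf.contrib.tpu.TPUEstimatorSpec(
      mode=mode,
      loss=loss,
      train_op=train_op,
      # Everything here must be a SessionRunHook, or __new__ raises the
      # TypeError shown above.
      training_hooks=[tf.train.StepCounterHook()])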