Python absl.testing.parameterized.TestCase() Examples

The following are 30 code examples of absl.testing.parameterized.TestCase(). You can vote up the ones you like or vote down the ones you don't like, and go to the original project or source file by following the links above each example. You may also want to check out all available functions/classes of the module absl.testing.parameterized, or try the search function.
Example #1
Source File: dnn_test_fc_v1_v1.py    From estimator with Apache License 2.0 5 votes vote down vote up
def __init__(self, methodName='runTest'):  # pylint: disable=invalid-name
    """Initializes the instance via two explicit initializer calls.

    Calls `tf.test.TestCase.__init__` first, then
    `dnn_testing_utils_v1.BaseDNNRegressorEvaluateTest.__init__` with
    `_dnn_regressor_fn` and `fc_impl=feature_column`.

    Args:
      methodName: Forwarded unchanged to `tf.test.TestCase.__init__`.
    """
    tf.test.TestCase.__init__(self, methodName)
    dnn_testing_utils_v1.BaseDNNRegressorEvaluateTest.__init__(
        self, _dnn_regressor_fn, fc_impl=feature_column)
Example #2
Source File: dnn_test_fc_v2.py    From estimator with Apache License 2.0 5 votes vote down vote up
def __init__(self, methodName='runTest'):  # pylint: disable=invalid-name
    """Initializes the instance via two explicit initializer calls.

    Calls `tf.test.TestCase.__init__` first, then
    `dnn_testing_utils.BaseDNNModelFnTest.__init__` with
    `dnn.dnn_model_fn_v2` and `fc_impl=feature_column_v2`.

    Args:
      methodName: Forwarded unchanged to `tf.test.TestCase.__init__`.
    """
    tf.test.TestCase.__init__(self, methodName)
    dnn_testing_utils.BaseDNNModelFnTest.__init__(
        self, dnn.dnn_model_fn_v2, fc_impl=feature_column_v2)
Example #3
Source File: dnn_test_fc_v2.py    From estimator with Apache License 2.0 5 votes vote down vote up
def __init__(self, methodName='runTest'):  # pylint: disable=invalid-name
    """Initializes the instance via two explicit initializer calls.

    Calls `tf.test.TestCase.__init__` first, then
    `dnn_testing_utils.BaseDNNLogitFnTest.__init__` with
    `dnn.dnn_logit_fn_builder_v2` and `fc_impl=feature_column_v2`.

    Args:
      methodName: Forwarded unchanged to `tf.test.TestCase.__init__`.
    """
    tf.test.TestCase.__init__(self, methodName)
    dnn_testing_utils.BaseDNNLogitFnTest.__init__(
        self, dnn.dnn_logit_fn_builder_v2, fc_impl=feature_column_v2)
Example #4
Source File: dnn_test_fc_v2.py    From estimator with Apache License 2.0 5 votes vote down vote up
def __init__(self, methodName='runTest'):  # pylint: disable=invalid-name
    """Initializes the instance via two explicit initializer calls.

    Calls `tf.test.TestCase.__init__` first, then
    `dnn_testing_utils.BaseDNNWarmStartingTest.__init__` with
    `_dnn_classifier_fn`, `_dnn_regressor_fn` and
    `fc_impl=feature_column_v2`.

    Args:
      methodName: Forwarded unchanged to `tf.test.TestCase.__init__`.
    """
    tf.test.TestCase.__init__(self, methodName)
    dnn_testing_utils.BaseDNNWarmStartingTest.__init__(
        self, _dnn_classifier_fn, _dnn_regressor_fn, fc_impl=feature_column_v2)
Example #5
Source File: dnn_test_fc_v2.py    From estimator with Apache License 2.0 5 votes vote down vote up
def __init__(self, methodName='runTest'):  # pylint: disable=invalid-name
    """Initializes the instance via two explicit initializer calls.

    Calls `tf.test.TestCase.__init__` first, then
    `dnn_testing_utils.BaseDNNClassifierEvaluateTest.__init__` with
    `_dnn_classifier_fn` and `fc_impl=feature_column_v2`.

    Args:
      methodName: Forwarded unchanged to `tf.test.TestCase.__init__`.
    """
    tf.test.TestCase.__init__(self, methodName)
    dnn_testing_utils.BaseDNNClassifierEvaluateTest.__init__(
        self, _dnn_classifier_fn, fc_impl=feature_column_v2)
Example #6
Source File: dnn_test_fc_v2.py    From estimator with Apache License 2.0 5 votes vote down vote up
def __init__(self, methodName='runTest'):  # pylint: disable=invalid-name
    """Initializes the instance via two explicit initializer calls.

    Calls `tf.test.TestCase.__init__` first, then
    `dnn_testing_utils.BaseDNNClassifierPredictTest.__init__` with
    `_dnn_classifier_fn` and `fc_impl=feature_column_v2`.

    Args:
      methodName: Forwarded unchanged to `tf.test.TestCase.__init__`.
    """
    tf.test.TestCase.__init__(self, methodName)
    dnn_testing_utils.BaseDNNClassifierPredictTest.__init__(
        self, _dnn_classifier_fn, fc_impl=feature_column_v2)
Example #7
Source File: dnn_test_fc_v2.py    From estimator with Apache License 2.0 5 votes vote down vote up
def __init__(self, methodName='runTest'):  # pylint: disable=invalid-name
    """Initializes the instance via two explicit initializer calls.

    Calls `tf.test.TestCase.__init__` first, then
    `dnn_testing_utils.BaseDNNRegressorEvaluateTest.__init__` with
    `_dnn_regressor_fn` and `fc_impl=feature_column_v2`.

    Args:
      methodName: Forwarded unchanged to `tf.test.TestCase.__init__`.
    """
    tf.test.TestCase.__init__(self, methodName)
    dnn_testing_utils.BaseDNNRegressorEvaluateTest.__init__(
        self, _dnn_regressor_fn, fc_impl=feature_column_v2)
Example #8
Source File: dnn_test_fc_v2.py    From estimator with Apache License 2.0 5 votes vote down vote up
def __init__(self, methodName='runTest'):  # pylint: disable=invalid-name
    """Initializes the instance via two explicit initializer calls.

    Calls `tf.test.TestCase.__init__` first, then
    `dnn_testing_utils.BaseDNNRegressorPredictTest.__init__` with
    `_dnn_regressor_fn` and `fc_impl=feature_column_v2`.

    Args:
      methodName: Forwarded unchanged to `tf.test.TestCase.__init__`.
    """
    tf.test.TestCase.__init__(self, methodName)
    dnn_testing_utils.BaseDNNRegressorPredictTest.__init__(
        self, _dnn_regressor_fn, fc_impl=feature_column_v2)
Example #9
Source File: dnn_test_fc_v2.py    From estimator with Apache License 2.0 5 votes vote down vote up
def __init__(self, methodName='runTest'):  # pylint: disable=invalid-name
    """Initializes the instance via two explicit initializer calls.

    Calls `tf.test.TestCase.__init__` first, then
    `dnn_testing_utils.BaseDNNRegressorTrainTest.__init__` with
    `_dnn_regressor_fn` and `fc_impl=feature_column_v2`.

    Args:
      methodName: Forwarded unchanged to `tf.test.TestCase.__init__`.
    """
    tf.test.TestCase.__init__(self, methodName)
    dnn_testing_utils.BaseDNNRegressorTrainTest.__init__(
        self, _dnn_regressor_fn, fc_impl=feature_column_v2)
Example #10
Source File: dnn_test_fc_v1_v1.py    From estimator with Apache License 2.0 5 votes vote down vote up
def __init__(self, methodName='runTest'):  # pylint: disable=invalid-name
    """Initializes the instance via two explicit initializer calls.

    Calls `tf.test.TestCase.__init__` first, then
    `dnn_testing_utils_v1.BaseDNNModelFnTest.__init__` with
    `dnn._dnn_model_fn` and `fc_impl=feature_column`.

    Args:
      methodName: Forwarded unchanged to `tf.test.TestCase.__init__`.
    """
    tf.test.TestCase.__init__(self, methodName)
    dnn_testing_utils_v1.BaseDNNModelFnTest.__init__(
        self, dnn._dnn_model_fn, fc_impl=feature_column)
Example #11
Source File: dnn_test_fc_v1_v1.py    From estimator with Apache License 2.0 5 votes vote down vote up
def __init__(self, methodName='runTest'):  # pylint: disable=invalid-name
    """Initializes the instance via two explicit initializer calls.

    Calls `tf.test.TestCase.__init__` first, then
    `dnn_testing_utils_v1.BaseDNNLogitFnTest.__init__` with
    `dnn.dnn_logit_fn_builder` and `fc_impl=feature_column`.

    Args:
      methodName: Forwarded unchanged to `tf.test.TestCase.__init__`.
    """
    tf.test.TestCase.__init__(self, methodName)
    dnn_testing_utils_v1.BaseDNNLogitFnTest.__init__(
        self, dnn.dnn_logit_fn_builder, fc_impl=feature_column)
Example #12
Source File: dnn_test_fc_v1_v1.py    From estimator with Apache License 2.0 5 votes vote down vote up
def __init__(self, methodName='runTest'):  # pylint: disable=invalid-name
    """Initializes the instance via two explicit initializer calls.

    Calls `tf.test.TestCase.__init__` first, then
    `dnn_testing_utils_v1.BaseDNNClassifierEvaluateTest.__init__` with
    `_dnn_classifier_fn` and `fc_impl=feature_column`.

    Args:
      methodName: Forwarded unchanged to `tf.test.TestCase.__init__`.
    """
    tf.test.TestCase.__init__(self, methodName)
    dnn_testing_utils_v1.BaseDNNClassifierEvaluateTest.__init__(
        self, _dnn_classifier_fn, fc_impl=feature_column)
Example #13
Source File: dnn_test_fc_v1_v1.py    From estimator with Apache License 2.0 5 votes vote down vote up
def __init__(self, methodName='runTest'):  # pylint: disable=invalid-name
    """Initializes the instance via two explicit initializer calls.

    Calls `tf.test.TestCase.__init__` first, then
    `dnn_testing_utils_v1.BaseDNNClassifierPredictTest.__init__` with
    `_dnn_classifier_fn` and `fc_impl=feature_column`.

    Args:
      methodName: Forwarded unchanged to `tf.test.TestCase.__init__`.
    """
    tf.test.TestCase.__init__(self, methodName)
    dnn_testing_utils_v1.BaseDNNClassifierPredictTest.__init__(
        self, _dnn_classifier_fn, fc_impl=feature_column)
Example #14
Source File: dnn_test_fc_v1_v1.py    From estimator with Apache License 2.0 5 votes vote down vote up
def __init__(self, methodName='runTest'):  # pylint: disable=invalid-name
    """Initializes the instance via two explicit initializer calls.

    Calls `tf.test.TestCase.__init__` first, then
    `dnn_testing_utils_v1.BaseDNNClassifierTrainTest.__init__` with
    `_dnn_classifier_fn` and `fc_impl=feature_column`.

    Args:
      methodName: Forwarded unchanged to `tf.test.TestCase.__init__`.
    """
    tf.test.TestCase.__init__(self, methodName)
    dnn_testing_utils_v1.BaseDNNClassifierTrainTest.__init__(
        self, _dnn_classifier_fn, fc_impl=feature_column)
Example #15
Source File: loss_layers_test.py    From models with Apache License 2.0 5 votes vote down vote up
def run_lagrange_multiplier_test(global_objective,
                                 objective_kwargs,
                                 data_builder,
                                 test_object):
  """Checks one gradient step on the Lagrange multipliers of an objective.

  The objective is built on the tensors returned by `data_builder`, a single
  gradient-descent update (learning rate 0.1) is applied to
  `output_dict['lambdas']` only, and the multipliers are compared before and
  after. The data is expected to satisfy the constraint of `global_objective`
  on the first label but not on the second, so the first multiplier must
  decrease and the second must increase.

  Args:
    global_objective: One of the global objectives. It is invoked as
      `global_objective(**kwargs)` and must return `(loss, output_dict)`
      where `output_dict` has a 'lambdas' entry.
    objective_kwargs: A dictionary of keyword arguments for
      `global_objective`. Must contain an entry for the constraint argument
      of `global_objective`, e.g. 'target_rate' or 'target_precision'. It is
      copied before 'labels', 'logits' and 'label_priors' are added.
    data_builder: A zero-argument function returning tensors corresponding
      to labels, logits, and label priors, in that order.
    test_object: An instance of tf.test.TestCase.
  """
  # Work on a copy so the caller's dictionary is never modified.
  call_kwargs = dict(objective_kwargs)
  labels, logits, label_priors = data_builder()
  call_kwargs.update(labels=labels, logits=logits, label_priors=label_priors)

  loss, outputs = global_objective(**call_kwargs)
  multipliers = tf.squeeze(outputs['lambdas'])
  optimizer = tf.train.GradientDescentOptimizer(learning_rate=0.1)
  # Restrict the update to the multipliers; nothing else is trained here.
  single_step = optimizer.minimize(loss, var_list=[outputs['lambdas']])

  with test_object.test_session() as sess:
    tf.global_variables_initializer().run()
    before = sess.run(multipliers)
    sess.run(single_step)
    after = sess.run(multipliers)
    test_object.assertLess(after[0], before[0])
    test_object.assertGreater(after[1], before[1])
Example #16
Source File: dnn_test_fc_v1_v1.py    From estimator with Apache License 2.0 5 votes vote down vote up
def __init__(self, methodName='runTest'):  # pylint: disable=invalid-name
    """Initializes the instance via two explicit initializer calls.

    Calls `tf.test.TestCase.__init__` first, then
    `dnn_testing_utils_v1.BaseDNNRegressorPredictTest.__init__` with
    `_dnn_regressor_fn` and `fc_impl=feature_column`.

    Args:
      methodName: Forwarded unchanged to `tf.test.TestCase.__init__`.
    """
    tf.test.TestCase.__init__(self, methodName)
    dnn_testing_utils_v1.BaseDNNRegressorPredictTest.__init__(
        self, _dnn_regressor_fn, fc_impl=feature_column)
Example #17
Source File: dnn_test_fc_v2_v1.py    From estimator with Apache License 2.0 5 votes vote down vote up
def __init__(self, methodName='runTest'):  # pylint: disable=invalid-name
    """Initializes the instance via two explicit initializer calls.

    Calls `tf.test.TestCase.__init__` first, then
    `dnn_testing_utils_v1.BaseDNNModelFnTest.__init__` with
    `dnn._dnn_model_fn` and `fc_impl=feature_column_v2`.

    Args:
      methodName: Forwarded unchanged to `tf.test.TestCase.__init__`.
    """
    tf.test.TestCase.__init__(self, methodName)
    dnn_testing_utils_v1.BaseDNNModelFnTest.__init__(
        self, dnn._dnn_model_fn, fc_impl=feature_column_v2)
Example #18
Source File: dnn_test_fc_v2_v1.py    From estimator with Apache License 2.0 5 votes vote down vote up
def __init__(self, methodName='runTest'):  # pylint: disable=invalid-name
    """Initializes the instance via two explicit initializer calls.

    Calls `tf.test.TestCase.__init__` first, then
    `dnn_testing_utils_v1.BaseDNNLogitFnTest.__init__` with
    `dnn.dnn_logit_fn_builder` and `fc_impl=feature_column_v2`.

    Args:
      methodName: Forwarded unchanged to `tf.test.TestCase.__init__`.
    """
    tf.test.TestCase.__init__(self, methodName)
    dnn_testing_utils_v1.BaseDNNLogitFnTest.__init__(
        self, dnn.dnn_logit_fn_builder, fc_impl=feature_column_v2)
Example #19
Source File: dnn_test_fc_v2_v1.py    From estimator with Apache License 2.0 5 votes vote down vote up
def __init__(self, methodName='runTest'):  # pylint: disable=invalid-name
    """Initializes the instance via two explicit initializer calls.

    Calls `tf.test.TestCase.__init__` first, then
    `dnn_testing_utils_v1.BaseDNNWarmStartingTest.__init__` with
    `_dnn_classifier_fn`, `_dnn_regressor_fn` and
    `fc_impl=feature_column_v2`.

    Args:
      methodName: Forwarded unchanged to `tf.test.TestCase.__init__`.
    """
    tf.test.TestCase.__init__(self, methodName)
    dnn_testing_utils_v1.BaseDNNWarmStartingTest.__init__(
        self, _dnn_classifier_fn, _dnn_regressor_fn, fc_impl=feature_column_v2)
Example #20
Source File: dnn_test_fc_v2_v1.py    From estimator with Apache License 2.0 5 votes vote down vote up
def __init__(self, methodName='runTest'):  # pylint: disable=invalid-name
    """Initializes the instance via two explicit initializer calls.

    Calls `tf.test.TestCase.__init__` first, then
    `dnn_testing_utils_v1.BaseDNNClassifierEvaluateTest.__init__` with
    `_dnn_classifier_fn` and `fc_impl=feature_column_v2`.

    Args:
      methodName: Forwarded unchanged to `tf.test.TestCase.__init__`.
    """
    tf.test.TestCase.__init__(self, methodName)
    dnn_testing_utils_v1.BaseDNNClassifierEvaluateTest.__init__(
        self, _dnn_classifier_fn, fc_impl=feature_column_v2)
Example #21
Source File: dnn_test_fc_v2_v1.py    From estimator with Apache License 2.0 5 votes vote down vote up
def __init__(self, methodName='runTest'):  # pylint: disable=invalid-name
    """Initializes the instance via two explicit initializer calls.

    Calls `tf.test.TestCase.__init__` first, then
    `dnn_testing_utils_v1.BaseDNNClassifierPredictTest.__init__` with
    `_dnn_classifier_fn` and `fc_impl=feature_column_v2`.

    Args:
      methodName: Forwarded unchanged to `tf.test.TestCase.__init__`.
    """
    tf.test.TestCase.__init__(self, methodName)
    dnn_testing_utils_v1.BaseDNNClassifierPredictTest.__init__(
        self, _dnn_classifier_fn, fc_impl=feature_column_v2)
Example #22
Source File: dnn_test_fc_v2_v1.py    From estimator with Apache License 2.0 5 votes vote down vote up
def __init__(self, methodName='runTest'):  # pylint: disable=invalid-name
    """Initializes the instance via two explicit initializer calls.

    Calls `tf.test.TestCase.__init__` first, then
    `dnn_testing_utils_v1.BaseDNNRegressorEvaluateTest.__init__` with
    `_dnn_regressor_fn` and `fc_impl=feature_column_v2`.

    Args:
      methodName: Forwarded unchanged to `tf.test.TestCase.__init__`.
    """
    tf.test.TestCase.__init__(self, methodName)
    dnn_testing_utils_v1.BaseDNNRegressorEvaluateTest.__init__(
        self, _dnn_regressor_fn, fc_impl=feature_column_v2)
Example #23
Source File: dnn_test_fc_v2_v1.py    From estimator with Apache License 2.0 5 votes vote down vote up
def __init__(self, methodName='runTest'):  # pylint: disable=invalid-name
    """Initializes the instance via two explicit initializer calls.

    Calls `tf.test.TestCase.__init__` first, then
    `dnn_testing_utils_v1.BaseDNNRegressorPredictTest.__init__` with
    `_dnn_regressor_fn` and `fc_impl=feature_column_v2`.

    Args:
      methodName: Forwarded unchanged to `tf.test.TestCase.__init__`.
    """
    tf.test.TestCase.__init__(self, methodName)
    dnn_testing_utils_v1.BaseDNNRegressorPredictTest.__init__(
        self, _dnn_regressor_fn, fc_impl=feature_column_v2)
Example #24
Source File: dnn_test_fc_v2_v1.py    From estimator with Apache License 2.0 5 votes vote down vote up
def __init__(self, methodName='runTest'):  # pylint: disable=invalid-name
    """Initializes the instance via two explicit initializer calls.

    Calls `tf.test.TestCase.__init__` first, then
    `dnn_testing_utils_v1.BaseDNNRegressorTrainTest.__init__` with
    `_dnn_regressor_fn` and `fc_impl=feature_column_v2`.

    Args:
      methodName: Forwarded unchanged to `tf.test.TestCase.__init__`.
    """
    tf.test.TestCase.__init__(self, methodName)
    dnn_testing_utils_v1.BaseDNNRegressorTrainTest.__init__(
        self, _dnn_regressor_fn, fc_impl=feature_column_v2)
Example #25
Source File: rnn_test.py    From estimator with Apache License 2.0 5 votes vote down vote up
def __init__(self, methodName='runTest'):  # pylint: disable=invalid-name
    """Initializes the instance via two explicit initializer calls.

    Calls `tf.test.TestCase.__init__` first, then
    `BaseRNNClassificationIntegrationTest.__init__` with
    `_rnn_classifier_fn`.

    Args:
      methodName: Forwarded unchanged to `tf.test.TestCase.__init__`.
    """
    tf.test.TestCase.__init__(self, methodName)
    BaseRNNClassificationIntegrationTest.__init__(self, _rnn_classifier_fn)
Example #26
Source File: rnn_test.py    From estimator with Apache License 2.0 5 votes vote down vote up
def __init__(self, methodName='runTest'):  # pylint: disable=invalid-name
    """Initializes the instance via two explicit initializer calls.

    Calls `tf.test.TestCase.__init__` first, then
    `BaseRNNClassificationIntegrationTest.__init__` with
    `_rnn_classifier_dropout_fn`.

    Args:
      methodName: Forwarded unchanged to `tf.test.TestCase.__init__`.
    """
    tf.test.TestCase.__init__(self, methodName)
    BaseRNNClassificationIntegrationTest.__init__(self,
                                                  _rnn_classifier_dropout_fn)
Example #27
Source File: loss_layers_test.py    From multilabel-image-classification-tensorflow with MIT License 5 votes vote down vote up
def run_lagrange_multiplier_test(global_objective,
                                 objective_kwargs,
                                 data_builder,
                                 test_object):
  """Verifies the direction of a single Lagrange multiplier update.

  `global_objective` is evaluated on the output of `data_builder`, then one
  gradient-descent step (learning rate 0.1) is taken with respect to
  `output_dict['lambdas']` alone. The data is expected to satisfy the
  constraint for `global_objective` on the first label and violate it on the
  second; accordingly the first multiplier must shrink and the second grow.

  Args:
    global_objective: One of the global objectives; called with keyword
      arguments only and expected to return `(loss, output_dict)` with a
      'lambdas' key in `output_dict`.
    objective_kwargs: A dictionary of keyword arguments to pass to
      `global_objective`. Must contain an entry for the constraint argument
      of `global_objective`, e.g. 'target_rate' or 'target_precision'. The
      dictionary is copied, not modified.
    data_builder: A function taking no arguments which returns tensors
      corresponding to labels, logits, and label priors.
    test_object: An instance of tf.test.TestCase.
  """
  # Copy first: the three data entries below must not leak to the caller.
  objective_args = dict(objective_kwargs)
  label_tensor, logit_tensor, prior_tensor = data_builder()
  objective_args.update(
      labels=label_tensor, logits=logit_tensor, label_priors=prior_tensor)

  loss, result = global_objective(**objective_args)
  lambda_values = tf.squeeze(result['lambdas'])
  sgd = tf.train.GradientDescentOptimizer(learning_rate=0.1)
  # Only the multipliers are listed in var_list for this step.
  lambda_step = sgd.minimize(loss, var_list=[result['lambdas']])

  with test_object.test_session() as sess:
    tf.global_variables_initializer().run()
    initial = sess.run(lambda_values)
    sess.run(lambda_step)
    updated = sess.run(lambda_values)
    test_object.assertLess(updated[0], initial[0])
    test_object.assertGreater(updated[1], initial[1])
Example #28
Source File: parameterized_test.py    From abseil-py with Apache License 2.0 5 votes vote down vote up
def test_duplicate_dict_named_test_fails(self):
    """Two dict-style named parameters sharing a testcase_name must fail."""
    # Both parameter sets deliberately reuse the same 'testcase_name'.
    first_case = {'testcase_name': 'Interesting', 'unused_obj': 0}
    second_case = {'testcase_name': 'Interesting', 'unused_obj': 1}

    # The class definition itself is the statement under test, so it has to
    # live inside the assertRaises context.
    with self.assertRaises(parameterized.DuplicateTestNameError):

      class _(parameterized.TestCase):

        @parameterized.named_parameters(first_case, second_case)
        def test_dict_something(self, unused_obj):
          pass
Example #29
Source File: quantiles_util_test.py    From data-validation with Apache License 2.0 5 votes vote down vote up
def _run_quantiles_combiner_test(test: absltest.TestCase,
                                 q_combiner: quantiles_util.QuantilesCombiner,
                                 batches: List[List[np.ndarray]],
                                 expected_result: np.ndarray):
  """Runs `batches` through `q_combiner` and checks the extracted output.

  Every batch is added to its own newly created accumulator. The accumulators
  are then merged and the output extracted, after which its dtype, size and
  element values are compared with `expected_result`.

  Args:
    test: Test case supplying `assertEqual` and `assertAlmostEqual`.
    q_combiner: Combiner under test.
    batches: Inputs; each element is passed to one `add_input` call.
    expected_result: Expected output of `extract_output`.
  """
  accumulators = []
  for batch in batches:
    accumulators.append(
        q_combiner.add_input(q_combiner.create_accumulator(), batch))

  merged = q_combiner.merge_accumulators(accumulators)
  actual = q_combiner.extract_output(merged)

  test.assertEqual(actual.dtype, expected_result.dtype)
  test.assertEqual(actual.size, expected_result.size)
  # Values are compared approximately, one index at a time.
  for idx in range(expected_result.size):
    test.assertAlmostEqual(actual[idx], expected_result[idx])
Example #30
Source File: quantiles_util_test.py    From data-validation with Apache License 2.0 5 votes vote down vote up
def _assert_buckets_almost_equal(test: parameterized.TestCase,
                                 a: List[Tuple[float, float, float]],
                                 b: List[Tuple[float, float, float]]):
  """Asserts that two sequences of histogram buckets are almost equal.

  The sequences must have the same length, and buckets at the same position
  must agree approximately on `low_value`, `high_value` and `sample_count`.

  Args:
    test: Test case supplying `assertEqual` and `assertAlmostEqual`.
    a: First sequence of buckets.
    b: Second sequence of buckets.
  """
  # NOTE(review): the annotations say tuples, but the elements are read via
  # attributes (`low_value`, ...) - confirm the intended element type.
  test.assertEqual(len(a), len(b))
  # Lengths are already asserted equal, so pairing by position is complete.
  for bucket_a, bucket_b in zip(a, b):
    test.assertAlmostEqual(bucket_a.low_value, bucket_b.low_value)
    test.assertAlmostEqual(bucket_a.high_value, bucket_b.high_value)
    test.assertAlmostEqual(bucket_a.sample_count, bucket_b.sample_count)