Python absl.testing.parameterized.parameters() Examples

The following are 30 code examples of absl.testing.parameterized.parameters(), taken from projects that use it (not every snippet calls it directly). You can vote up the ones you like or vote down the ones you don't like, and go to the original project or source file by following the links above each example. You may also want to check out all available functions/classes of the module absl.testing.parameterized, or try the search function.
Example #1
Source File: ssgan_test.py    From compare_gan with Apache License 2.0 7 votes vote down vote up
def _runSingleTrainingStep(self, architecture, loss_fn, penalty_fn):
    """Trains an SSGAN on CIFAR-10 for a single step.

    Args:
      architecture: Stored under "architecture" in the model parameters.
      loss_fn: Bound to the gin parameter "loss.fn".
      penalty_fn: Bound to the gin parameter "penalty.fn".
    """
    with gin.unlock_config():
      gin.bind_parameter("penalty.fn", penalty_fn)
      gin.bind_parameter("loss.fn", loss_fn)

    # The run uses the directory from `_get_empty_model_dir()` and one
    # iteration per TPU loop.
    model_dir = self._get_empty_model_dir()
    tpu_config = tf.contrib.tpu.TPUConfig(iterations_per_loop=1)
    run_config = tf.contrib.tpu.RunConfig(model_dir=model_dir,
                                          tpu_config=tpu_config)

    gan = SSGAN(
        dataset=datasets.get_dataset("cifar10"),
        parameters={"architecture": architecture, "lambda": 1, "z_dim": 128},
        model_dir=model_dir,
        g_optimizer_fn=tf.train.AdamOptimizer,
        g_lr=0.0002,
        rotated_batch_size=4)
    gan.as_estimator(run_config, batch_size=2, use_tpu=False).train(
        gan.input_fn, steps=1)
Example #2
Source File: parameterized.py    From abseil-py with Apache License 2.0 6 votes vote down vote up
def parameters(*testcases):
  """Decorates a test method or test class to generate parameterized tests.

  See the module docstring for a usage example.

  Args:
    *testcases: Parameters for the decorated method, either a single
        iterable, or a list of tuples/dicts/objects (for tests with only one
        argument).

  Raises:
    NoTestsError: Raised when the decorator generates no tests.

  Returns:
    A test generator to be handled by TestGeneratorMetaclass.
  """
  # _ARGUMENT_REPR is presumably the naming scheme for generated tests;
  # confirm against _parameter_decorator.
  decorator = _parameter_decorator(_ARGUMENT_REPR, testcases)
  return decorator
Example #3
Source File: test_utils_test.py    From model-optimization with Apache License 2.0 6 votes vote down vote up
def test_basic_encode_decode_tf_constructor_parameters(self):
    """Tests the core functionality with `tf.Variable` constructor parameters."""
    a_var = tf.compat.v1.get_variable('a_var', initializer=self._DEFAULT_A)
    b_var = tf.compat.v1.get_variable('b_var', initializer=self._DEFAULT_B)
    stage = test_utils.SimpleLinearEncodingStage(a_var, b_var)

    with self.cached_session() as sess:
      sess.run(tf.compat.v1.global_variables_initializer())

    x = self.default_input()
    encode_params, decode_params = stage.get_params()
    encoded_x, decoded_x = self.encode_decode_x(stage, x, encode_params,
                                                decode_params)
    initial = self.evaluate_test_data(
        test_utils.TestData(x, encoded_x, decoded_x))
    self.common_asserts_for_test_data(initial)

    # After reassigning a=5.0 and b=6.0, the encoded values are expected to be
    # x * 5.0 + 6.0, i.e. the stage reads the variables' current values.
    reassign_ops = [tf.compat.v1.assign(a_var, 5.0),
                    tf.compat.v1.assign(b_var, 6.0)]
    self.evaluate(reassign_ops)
    updated = self.evaluate_test_data(
        test_utils.TestData(x, encoded_x, decoded_x))
    self.assertAllClose(updated.x * 5.0 + 6.0,
                        updated.encoded_x[self._ENCODED_VALUES_KEY])
Example #4
Source File: modular_gan_test.py    From compare_gan with Apache License 2.0 6 votes vote down vote up
def _runSingleTrainingStep(self, architecture, loss_fn, penalty_fn):
    """Trains a ModularGAN on CIFAR-10 for a single step.

    Args:
      architecture: Stored under "architecture" in the model parameters.
        Names containing "biggan" are trained as conditional GANs.
      loss_fn: Bound to the gin parameter "loss.fn".
      penalty_fn: Bound to the gin parameter "penalty.fn".
    """
    with gin.unlock_config():
      gin.bind_parameter("penalty.fn", penalty_fn)
      gin.bind_parameter("loss.fn", loss_fn)

    is_conditional = "biggan" in architecture
    gan = ModularGAN(
        dataset=datasets.get_dataset("cifar10"),
        parameters={"architecture": architecture, "lambda": 1, "z_dim": 128},
        model_dir=self.model_dir,
        conditional=is_conditional)
    gan.as_estimator(self.run_config, batch_size=2, use_tpu=False).train(
        gan.input_fn, steps=1)
Example #5
Source File: parameterized.py    From abseil-py with Apache License 2.0 6 votes vote down vote up
def _modify_class(class_object, testcases, naming_type):
  assert not getattr(class_object, '_test_method_ids', None), (
      'Cannot add parameters to %s. Either it already has parameterized '
      'methods, or its super class is also a parameterized class.' % (
          class_object,))
  class_object._test_method_ids = test_method_ids = {}
  for name, obj in six.iteritems(class_object.__dict__.copy()):
    if (name.startswith(unittest.TestLoader.testMethodPrefix)
        and isinstance(obj, types.FunctionType)):
      delattr(class_object, name)
      methods = {}
      _update_class_dict_for_param_test_case(
          class_object.__name__, methods, test_method_ids, name,
          _ParameterizedTestIter(obj, testcases, naming_type, name))
      for name, meth in six.iteritems(methods):
        setattr(class_object, name, meth) 
Example #6
Source File: modular_gan_test.py    From compare_gan with Apache License 2.0 6 votes vote down vote up
def testSingleTrainingStepWithJointGenForDisc(self):
    """Trains a conditional ModularGAN for one step with joint gen-for-disc."""
    model_params = {
        "architecture": c.DUMMY_ARCH,
        "lambda": 1,
        "z_dim": 120,
        "disc_iters": 2,
    }
    # Both experimental flags are switched on together in this test.
    gan = ModularGAN(
        dataset=datasets.get_dataset("cifar10"),
        parameters=model_params,
        model_dir=self.model_dir,
        experimental_joint_gen_for_disc=True,
        experimental_force_graph_unroll=True,
        conditional=True)
    gan.as_estimator(self.run_config, batch_size=2, use_tpu=False).train(
        gan.input_fn, steps=1)
Example #7
Source File: modular_gan_test.py    From compare_gan with Apache License 2.0 6 votes vote down vote up
def testSingleTrainingStepDiscItersWithEma(self, disc_iters):
    """Trains one step with generator EMA and checks the EMA checkpoint vars.

    Args:
      disc_iters: Passed to the model as the "disc_iters" parameter.
    """
    parameters = {
        "architecture": c.DUMMY_ARCH,
        "lambda": 1,
        "z_dim": 128,
        "disc_iters": disc_iters,
    }
    gin.bind_parameter("ModularGAN.g_use_ema", True)
    dataset = datasets.get_dataset("cifar10")
    gan = ModularGAN(
        dataset=dataset,
        parameters=parameters,
        model_dir=self.model_dir)
    estimator = gan.as_estimator(self.run_config, batch_size=2, use_tpu=False)
    estimator.train(gan.input_fn, steps=1)
    # Check for moving average variables in checkpoint: exactly the generator's
    # fc_noise kernel and bias are expected to have EMA copies.
    checkpoint_path = tf.train.latest_checkpoint(self.model_dir)
    ema_vars = sorted(v[0] for v in tf.train.list_variables(checkpoint_path)
                      if v[0].endswith("ExponentialMovingAverage"))
    tf.logging.info("ema_vars=%s", ema_vars)
    expected_ema_vars = sorted([
        "generator/fc_noise/kernel/ExponentialMovingAverage",
        "generator/fc_noise/bias/ExponentialMovingAverage",
    ])
    self.assertAllEqual(ema_vars, expected_ema_vars)
Example #8
Source File: core_test.py    From dm_control with Apache License 2.0 6 votes vote down vote up
def _get_attributes_test_params():
  """Groups the public attributes of MjModel/MjData by type for testing.

  The humanoid model at HUMANOID_XML_PATH is loaded. The names of the public
  (non-underscore) attributes of its MjModel and MjData are found through
  introspection. These are passed as parameters to each of the test methods
  in AttributesTest.

  Returns:
    A tuple `(array_args, scalar_args, skipped_args)`. Each is a list of
    `(parent_name, attr_name)` tuples, where `parent_name` is "model" or
    "data". Attributes that are neither ARRAY_TYPES nor SCALAR_TYPES instances
    and are not callable go into `skipped_args`. Callables are omitted.
  """
  model = core.MjModel.from_xml_path(HUMANOID_XML_PATH)
  data = core.MjData(model)
  array_args, scalar_args, skipped_args = [], [], []
  for owner_name, owner in (("model", model), ("data", data)):
    for attr_name in dir(owner):
      if attr_name.startswith("_"):
        continue  # Skip 'private' attributes.
      entry = (owner_name, attr_name)
      value = getattr(owner, attr_name)
      if isinstance(value, ARRAY_TYPES):
        array_args.append(entry)
      elif isinstance(value, SCALAR_TYPES):
        scalar_args.append(entry)
      elif not callable(value):
        # Methods etc. are covered specifically in CoreTest, so callables are
        # dropped; everything else is recorded as skipped.
        skipped_args.append(entry)
  return array_args, scalar_args, skipped_args
Example #9
Source File: distribution_ops_test.py    From trfl with Apache License 2.0 6 votes vote down vote up
def testFactorisedKLGaussian(self, dist1_type, dist2_type):
    """Tests that the factorised KL terms sum up to the true KL."""
    dist1, dist1_mean, dist1_cov = self._create_gaussian(dist1_type)
    dist2, dist2_mean, dist2_cov = self._create_gaussian(dist2_type)
    both_diagonal = _is_diagonal(dist1.scale) and _is_diagonal(dist2.scale)
    if both_diagonal:
      # For two diagonal Gaussians the factorisation is fed the `scale_diag`
      # parameters instead of the covariances from `_create_gaussian`.
      dist1_cov, dist2_cov = (dist1.parameters['scale_diag'],
                              dist2.parameters['scale_diag'])
    true_kl = tfp.distributions.kl_divergence(dist1, dist2)
    mean_term, cov_term = distribution_ops.factorised_kl_gaussian(
        dist1_mean, dist1_cov, dist2_mean, dist2_cov,
        both_diagonal=both_diagonal)
    with self.test_session() as sess:
      sess.run(tf.global_variables_initializer())
      true_kl_val, mean_val, cov_val = sess.run([true_kl, mean_term, cov_term])
      self.assertAllClose(true_kl_val, mean_val + cov_val, rtol=1e-4)
Example #10
Source File: modular_gan_conditional_test.py    From compare_gan with Apache License 2.0 6 votes vote down vote up
def _runSingleTrainingStep(self, architecture, loss_fn, penalty_fn,
                           labeled_dataset):
    """Trains a conditional ModularGAN on CIFAR-10 for a single step.

    Args:
      architecture: Stored under "architecture" in the model parameters.
      loss_fn: Bound to the gin parameter "loss.fn".
      penalty_fn: Bound to the gin parameter "penalty.fn".
      labeled_dataset: Accepted but not read by this helper.
    """
    # NOTE(review): `labeled_dataset` is never used here; confirm whether it
    # was meant to select the dataset or be forwarded to the model.
    del labeled_dataset  # Unused.
    with gin.unlock_config():
      gin.bind_parameter("penalty.fn", penalty_fn)
      gin.bind_parameter("loss.fn", loss_fn)

    model_dir = self._get_empty_model_dir()
    tpu_config = tf.contrib.tpu.TPUConfig(iterations_per_loop=1)
    run_config = tf.contrib.tpu.RunConfig(model_dir=model_dir,
                                          tpu_config=tpu_config)
    gan = ModularGAN(
        dataset=datasets.get_dataset("cifar10"),
        parameters={"architecture": architecture, "lambda": 1, "z_dim": 120},
        conditional=True,
        model_dir=model_dir)
    gan.as_estimator(run_config, batch_size=2, use_tpu=False).train(
        gan.input_fn, steps=1)
Example #11
Source File: function_utils_test.py    From federated with Apache License 2.0 6 votes vote down vote up
def test_get_defun_argspec_with_typed_non_eager_defun(self):
    # A tf.function with an input signature may not use **kwargs or default
    # values, but *args are allowed, and the signature may extend into *args.
    input_signature = (
        tf.TensorSpec(None, tf.int32),
        tf.TensorSpec(None, tf.bool),
        tf.TensorSpec(None, tf.float32),
        tf.TensorSpec(None, tf.float32),
    )
    fn = tf.function(lambda x, y, *z: None, input_signature)
    signature = function_utils.get_signature(fn)
    actual_params = collections.OrderedDict(signature.parameters)
    expected_params = collections.OrderedDict(
        x=inspect.Parameter('x', inspect.Parameter.POSITIONAL_OR_KEYWORD),
        y=inspect.Parameter('y', inspect.Parameter.POSITIONAL_OR_KEYWORD),
        z=inspect.Parameter('z', inspect.Parameter.VAR_POSITIONAL),
    )
    self.assertEqual(actual_params, expected_params)
Example #12
Source File: graph_convolution_test.py    From graphics with Apache License 2.0 6 votes vote down vote up
def test_dynamic_graph_convolution_keras_layer_exception_not_raised_shapes(
      self, batch_size, num_vertices, in_channels, out_channels, reduction):
    """Check if the convolution parameters and output have correct shapes."""
    # Outside eager mode the test returns immediately without checking.
    if not tf.executing_eagerly():
      return
    vertex_data, neighbor_graph = _dummy_data(batch_size, num_vertices,
                                              in_channels)
    conv_layer = gc_layer.DynamicGraphConvolutionKerasLayer(
        num_output_channels=out_channels, reduction=reduction)

    # Any exception from the layer call is reported as a test failure.
    try:
      output = conv_layer(inputs=[vertex_data, neighbor_graph], sizes=None)
    except Exception as err:  # pylint: disable=broad-except
      self.fail("Exception raised: %s" % str(err))

    expected_shape = (batch_size, num_vertices, out_channels)
    self.assertAllEqual(expected_shape, output.shape)
Example #13
Source File: function_utils_test.py    From federated with Apache License 2.0 6 votes vote down vote up
def test_get_signature_with_class_instance_method(self):

    class C:

      def __init__(self, x):
        self._x = x

      def foo(self, y):
        return self._x * y

    bound_foo = C(5).foo
    # For the bound method only `y` is expected; `self` does not appear.
    expected_parameters = collections.OrderedDict(
        y=inspect.Parameter('y', inspect.Parameter.POSITIONAL_OR_KEYWORD))
    actual_parameters = function_utils.get_signature(bound_foo).parameters
    self.assertEqual(actual_parameters, expected_parameters)
Example #14
Source File: query_test.py    From python-spanner-orm with Apache License 2.0 5 votes vote down vote up
def test_includes(self):
    """Checks the SQL, parameters and types for including 'parent'."""
    query = self.includes('parent')

    # The column order varies between test runs, so column lists are matched
    # with `\S*` wildcards instead of being spelled out.
    sql_pattern = (
        r'SELECT RelationshipTestModel\S* RelationshipTestModel\S* '
        r'ARRAY\(SELECT AS STRUCT SmallTestModel\S* SmallTestModel\S* '
        r'SmallTestModel\S* FROM SmallTestModel WHERE SmallTestModel.key = '
        r'RelationshipTestModel.parent_key\)')
    self.assertRegex(query.sql(), sql_pattern)
    # The include adds no bound query parameters.
    self.assertEmpty(query.parameters())
    self.assertEmpty(query.types())
Example #15
Source File: survey_api_test.py    From loaner with Apache License 2.0 5 votes vote down vote up
def _generate_message_parameters(want_permutations=False):
  """Generates survey question messages for parameterized test cases.

  Three questions share one answer message:
    * question 1: ASSIGNMENT, required;
    * question 2: ASSIGNMENT, disabled, not required;
    * question 3: RETURN, enabled.

  Args:
    want_permutations: bool, whether to yield every ordering of the three
        messages instead of only the original ordering.

  Yields:
    A one-element list wrapping the messages. With `want_permutations` the
    wrapped value is a tuple per permutation (one yield each); otherwise it
    is the list of messages in their original order (a single yield).
  """
  answer_message = survey_messages.Answer(
      text='Left my laptop at home.',
      more_info_enabled=False,
      placeholder_text=None)
  messages = [
      survey_messages.Question(
          question_type=survey_models.QuestionType.ASSIGNMENT,
          question_text=_QUESTION.format(num=1),
          answers=[answer_message],
          rand_weight=1,
          required=True),
      survey_messages.Question(
          question_type=survey_models.QuestionType.ASSIGNMENT,
          question_text=_QUESTION.format(num=2),
          answers=[answer_message],
          rand_weight=1,
          enabled=False,
          required=False),
      survey_messages.Question(
          question_type=survey_models.QuestionType.RETURN,
          question_text=_QUESTION.format(num=3),
          answers=[answer_message],
          rand_weight=1,
          enabled=True),
  ]
  if not want_permutations:
    yield [messages]
    return
  for ordering in itertools.permutations(messages):
    yield [ordering]
Example #16
Source File: query_test.py    From python-spanner-orm with Apache License 2.0 5 votes vote down vote up
def test_query_where_comparison(self, column, value, grpc_type):
    """Checks the WHERE clause, parameters and types for each comparison."""
    comparison_builders = [
        condition.greater_than, condition.not_less_than, condition.less_than,
        condition.not_greater_than, condition.equal_to, condition.not_equal_to
    ]
    # The bound value for `column` is expected under the key '<column>0'.
    column_key = '{}0'.format(column)
    for build_condition in comparison_builders:
      comparison = build_condition(column, value)
      query = self.select(comparison)

      where_clause = ' WHERE table.{} {} @{}'.format(
          column, comparison.operator, column_key)
      self.assertEndsWith(query.sql(), where_clause)
      self.assertEqual(query.parameters(), {column_key: value})
      self.assertEqual(query.types(), {column_key: grpc_type})
Example #17
Source File: query_test.py    From python-spanner-orm with Apache License 2.0 5 votes vote down vote up
def test_query_where_comparison_with_object(self, column, value, grpc_type):
    """Like test_query_where_comparison, but `column` is a column object."""
    comparison_builders = [
        condition.greater_than, condition.not_less_than, condition.less_than,
        condition.not_greater_than, condition.equal_to, condition.not_equal_to
    ]
    # The generated SQL and parameter key use the column's `name`.
    column_key = '{}0'.format(column.name)
    for build_condition in comparison_builders:
      comparison = build_condition(column, value)
      query = self.select(comparison)

      where_clause = ' WHERE table.{} {} @{}'.format(
          column.name, comparison.operator, column_key)
      self.assertEndsWith(query.sql(), where_clause)
      self.assertEqual(query.parameters(), {column_key: value})
      self.assertEqual(query.types(), {column_key: grpc_type})
Example #18
Source File: query_test.py    From python-spanner-orm with Apache License 2.0 5 votes vote down vote up
def test_query_where_list_comparison(self, column, values, grpc_type):
    """Checks IN / NOT IN conditions bind `values` as an ARRAY parameter."""
    column_key = '{}0'.format(column)
    # The whole list is bound as one parameter typed ARRAY<grpc_type>.
    array_type = type_pb2.Type(
        code=type_pb2.ARRAY, array_element_type=grpc_type)
    for build_condition in (condition.in_list, condition.not_in_list):
      membership = build_condition(column, values)
      query = self.select(membership)

      where_clause = ' WHERE table.{} {} UNNEST(@{})'.format(
          column, membership.operator, column_key)
      self.assertEndsWith(query.sql(), where_clause)
      self.assertEqual(query.parameters(), {column_key: values})
      self.assertEqual(query.types(), {column_key: array_type})
Example #19
Source File: center_net_meta_arch_tf2_test.py    From models with Apache License 2.0 5 votes vote down vote up
def test_pad_to_full_instance_dim(self):
    """Checks keypoints land in the given instance slots after padding."""
    batch_size, max_instances, num_keypoints = 4, 8, 6
    instance_inds = [1, 3]
    num_instances = len(instance_inds)

    kpt_coords_np = np.random.randn(batch_size, num_instances, num_keypoints, 2)
    kpt_scores_np = np.random.randn(batch_size, num_instances, num_keypoints)

    def graph_fn():
      coords_padded, scores_padded = cnma._pad_to_full_instance_dim(
          tf.constant(kpt_coords_np), tf.constant(kpt_scores_np),
          instance_inds, max_instances)
      return coords_padded, scores_padded

    kpt_coords_padded, kpt_scores_padded = self.execute(graph_fn, [])

    # The instance dimension grows from num_instances to max_instances.
    self.assertAllEqual([batch_size, max_instances, num_keypoints, 2],
                        kpt_coords_padded.shape)
    self.assertAllEqual([batch_size, max_instances, num_keypoints],
                        kpt_scores_padded.shape)

    # Source instance i is expected at padded slot instance_inds[i].
    for src, dst in enumerate(instance_inds):
      np.testing.assert_allclose(kpt_coords_np[:, src, :, :],
                                 kpt_coords_padded[:, dst, :, :])
      np.testing.assert_allclose(kpt_scores_np[:, src, :],
                                 kpt_scores_padded[:, dst, :])


# Common parameters for setting up testing examples across tests. 
Example #20
Source File: modular_gan_tpu_test.py    From compare_gan with Apache License 2.0 5 votes vote down vote up
def testBatchSizeExperimentalJointGenForDisc(self, disc_iters):
    """Checks generator/discriminator call shapes on TPU with joint gen-for-disc.

    Args:
      disc_iters: Passed to the model as the "disc_iters" parameter.
    """
    z_dim = 128
    batch_size = 16
    parameters = {
        "architecture": c.DUMMY_ARCH,
        "lambda": 1,
        "z_dim": z_dim,
        "disc_iters": disc_iters,
    }
    gan = ModularGAN(
        dataset=datasets.get_dataset("cifar10"),
        parameters=parameters,
        experimental_joint_gen_for_disc=True,
        model_dir=self.model_dir)
    estimator = gan.as_estimator(self.run_config, batch_size=batch_size,
                                 use_tpu=True)
    estimator.train(gan.input_fn, steps=1)

    gen_args = gan.generator.call_arg_list
    disc_args = gan.discriminator.call_arg_list
    # Expected: two generator calls and disc_iters + 1 discriminator calls.
    self.assertLen(gen_args, 2)
    self.assertLen(disc_args, disc_iters + 1)

    # NOTE(review): the per-call sizes 8 and 16 are hard-coded; presumably they
    # follow from batch_size and the TPU sharding in run_config — confirm.
    self.assertAllEqual(gen_args[0]["z"].shape.as_list(),
                        [8 * disc_iters, z_dim])
    self.assertAllEqual(gen_args[1]["z"].shape.as_list(), [8, z_dim])
    for call_args in disc_args:
      self.assertAllEqual(call_args["x"].shape.as_list(), [16, 32, 32, 3])
Example #21
Source File: eval_gan_lib_test.py    From compare_gan with Apache License 2.0 5 votes vote down vote up
def test_end2end_checkpoint(self, architecture):
    """Trains a real GAN for one step, exports it to TF-Hub and evaluates it.

    Args:
      architecture: Architecture name for the GAN. Names containing "biggan"
        are trained as conditional GANs. RESNET_STL_ARCH and RESNET30_ARCH
        make the test return early without checking anything.
    """
    if architecture in {c.RESNET_STL_ARCH, c.RESNET30_ARCH}:
      # RESNET_STL_ARCH and RESNET30_ARCH do not support CIFAR image shape.
      return
    gin.bind_parameter("dataset.name", "cifar10")
    dataset = datasets.get_dataset("cifar10")
    options = {
        "architecture": architecture,
        "z_dim": 120,
        "disc_iters": 1,
        "lambda": 1,
    }
    model_dir = os.path.join(tf.test.get_temp_dir(), self.id())
    # Pass the value as a logging argument so formatting is done lazily.
    tf.logging.info("model_dir: %s", model_dir)
    run_config = tf.contrib.tpu.RunConfig(model_dir=model_dir)
    gan = ModularGAN(dataset=dataset,
                     parameters=options,
                     conditional="biggan" in architecture,
                     model_dir=model_dir)
    estimator = gan.as_estimator(run_config, batch_size=2, use_tpu=False)
    estimator.train(input_fn=gan.input_fn, steps=1)

    # Export the checkpoint written after the single step as a TF-Hub module.
    export_path = os.path.join(model_dir, "tfhub")
    checkpoint_path = os.path.join(model_dir, "model.ckpt-1")
    module_spec = gan.as_module_spec()
    module_spec.export(export_path, checkpoint_path=checkpoint_path)

    eval_tasks = [
        fid_score.FIDScoreTask(),
        fractal_dimension.FractalDimensionTask(),
        inception_score.InceptionScoreTask(),
        ms_ssim_score.MultiscaleSSIMTask()
    ]
    result_dict = eval_gan_lib.evaluate_tfhub_module(
        export_path, eval_tasks, use_tpu=False, num_averaging_runs=1)
    tf.logging.info("result_dict: %s", result_dict)
    # Each score is expected under "<score>_mean", "<score>_std" and
    # "<score>_list".
    for score in ["fid_score", "fractal_dimension", "inception_score",
                  "ms_ssim"]:
      for stats in ["mean", "std", "list"]:
        required_key = "%s_%s" % (score, stats)
        self.assertIn(required_key, result_dict, "Missing: %s." % required_key)
Example #22
Source File: keras_layer_test.py    From hub with Apache License 2.0 5 votes vote down vote up
def testBatchNormRetraining(self, save_from_keras):
    """Tests imported batch norm with trainable=True."""
    export_dir = os.path.join(self.get_temp_dir(), "batch-norm")
    _save_batch_norm_model(export_dir, save_from_keras=save_from_keras)
    estimator = tf.estimator.Estimator(
        model_fn=self._batch_norm_model_fn,
        params=dict(hub_module=export_dir, train_batch_norm=True))

    def check_learned_statistics(preds):
      self.assertAllClose(preds["mean"], np.array([12.0]))
      self.assertAllClose(preds["beta"], np.array([24.0]))

    # Retrain the imported batch norm layer on a fixed batch of inputs,
    # which has mean 12.0 and some variance of a less obvious value.
    # The module learns scale and offset parameters that achieve the
    # mapping x --> 2*x for the observed mean and variance.
    inputs = [[11.], [12.], [13.]]
    targets = [[2 * row[0]] for row in inputs]
    train_input_fn = tf.compat.v1.estimator.inputs.numpy_input_fn(
        np.array(inputs, dtype=np.float32),
        np.array(targets, dtype=np.float32),
        batch_size=len(inputs), num_epochs=None, shuffle=False)
    estimator.train(train_input_fn, steps=100)
    predictions = next(estimator.predict(train_input_fn,
                                         yield_single_examples=False))
    check_learned_statistics(predictions)
    self.assertAllClose(predictions["output"], np.array(targets))

    # Evaluating the model operates batch norm in inference mode:
    # - Batch statistics are ignored in favor of aggregated statistics,
    #   computing x --> 2*x independent of input distribution.
    # - Update ops are not run, so this doesn't change over time.
    predict_input_fn = tf.compat.v1.estimator.inputs.numpy_input_fn(
        np.array([[10.], [20.], [30.]], dtype=np.float32),
        batch_size=3, num_epochs=100, shuffle=False)
    expected_output = np.array([[20.], [40.], [60.]])
    for predictions in estimator.predict(predict_input_fn,
                                         yield_single_examples=False):
      self.assertAllClose(predictions["output"], expected_output)
    # `predictions` is the last batch from the loop above, or the training-time
    # predictions if the loop produced nothing.
    check_learned_statistics(predictions)
Example #23
Source File: keras_layer_test.py    From hub with Apache License 2.0 5 votes vote down vote up
def testBatchNormRetraining(self, save_from_keras):
    """Tests imported batch norm with trainable=True."""
    export_dir = os.path.join(self.get_temp_dir(), "batch-norm")
    _save_batch_norm_model(export_dir, save_from_keras=save_from_keras)
    inp = tf.keras.layers.Input(shape=(1,), dtype=tf.float32)
    imported = hub.KerasLayer(export_dir, trainable=True)
    # Only beta and the mean variable are checked below.
    var_beta, _, var_mean, _ = _get_batch_norm_vars(imported)
    model = tf.keras.Model(inp, imported(inp))

    def check_learned_statistics():
      self.assertAllClose(var_mean.numpy(), np.array([12.0]))
      self.assertAllClose(var_beta.numpy(), np.array([24.0]))

    # Retrain the imported batch norm layer on a fixed batch of inputs,
    # which has mean 12.0 and some variance of a less obvious value.
    # The module learns scale and offset parameters that achieve the
    # mapping x --> 2*x for the observed mean and variance.
    model.compile(tf.keras.optimizers.SGD(0.1),
                  "mean_squared_error", run_eagerly=True)
    inputs = [[11.], [12.], [13.]]
    targets = [[2 * row[0]] for row in inputs]
    model.fit(np.array(inputs), np.array(targets), batch_size=len(inputs),
              epochs=100)
    check_learned_statistics()
    self.assertAllClose(model(np.array(inputs, np.float32)), np.array(targets))

    # Evaluating the model operates batch norm in inference mode:
    # - Batch statistics are ignored in favor of aggregated statistics,
    #   computing x --> 2*x independent of input distribution.
    # - Update ops are not run, so this doesn't change over time.
    eval_inputs = np.array([[10.], [20.], [30.]], np.float32)
    expected_outputs = np.array([[20.], [40.], [60.]])
    for _ in range(100):
      self.assertAllClose(model(eval_inputs), expected_outputs)
    check_learned_statistics()
Example #24
Source File: distribution_ops_test.py    From trfl with Apache License 2.0 5 votes vote down vote up
def testConsistentGradientsFullCovariance(self):
    """Factorised and true KL give matching gradients for full covariances."""
    dist_type = tfp.distributions.MultivariateNormalFullCovariance
    dist1, dist1_mean, dist1_cov = self._create_gaussian(dist_type)
    dist2, dist2_mean, dist2_cov = self._create_gaussian(dist_type)

    true_kl = tfp.distributions.kl_divergence(dist1, dist2)
    kl_mean, kl_cov = distribution_ops.factorised_kl_gaussian(
        dist1_mean, dist1_cov, dist2_mean, dist2_cov, both_diagonal=False)

    # Gradients are taken w.r.t. each distribution's 'covariance_matrix'
    # parameter.
    cov1_param = dist1.parameters['covariance_matrix']
    cov2_param = dist2.parameters['covariance_matrix']
    dist_params = [dist1_mean, dist2_mean, cov1_param, cov2_param]
    true_grads = tf.gradients(true_kl, dist_params)
    factorised_grads = tf.gradients(kl_mean + kl_cov, dist_params)

    # Check that no gradients flow into the mean terms from `kl_cov` and
    # vice-versa.
    self.assertListEqual(tf.gradients(kl_mean, [cov1_param]), [None])
    self.assertListEqual(tf.gradients(kl_cov, [dist1_mean, dist2_mean]),
                         [None, None])

    with self.test_session() as sess:
      true_vals, factorised_vals = sess.run([true_grads, factorised_grads])
      self.assertAllClose(true_vals, factorised_vals)


# Check for diagonal Gaussian distributions. Based on the definition in
# tensorflow_probability/python/distributions/mvn_linear_operator.py 
Example #25
Source File: distribution_ops_test.py    From trfl with Apache License 2.0 5 votes vote down vote up
def testConsistentGradientsBothDiagonal(self):
    """Factorised and true KL give matching gradients for diagonal Gaussians."""
    dist_type = tfp.distributions.MultivariateNormalDiag
    dist1, dist1_mean, _ = self._create_gaussian(dist_type)
    dist2, dist2_mean, _ = self._create_gaussian(dist_type)

    true_kl = tfp.distributions.kl_divergence(dist1, dist2)
    # In the diagonal case the factorisation takes the `scale_diag` params.
    scale1 = dist1.parameters['scale_diag']
    scale2 = dist2.parameters['scale_diag']
    kl_mean, kl_cov = distribution_ops.factorised_kl_gaussian(
        dist1_mean, scale1, dist2_mean, scale2, both_diagonal=True)

    dist_params = [dist1_mean, dist2_mean, scale1, scale2]
    true_grads = tf.gradients(true_kl, dist_params)
    factorised_grads = tf.gradients(kl_mean + kl_cov, dist_params)

    # Check that no gradients flow into the mean terms from `kl_cov` and
    # vice-versa.
    self.assertListEqual(tf.gradients(kl_mean, [scale1]), [None])
    self.assertListEqual(tf.gradients(kl_cov, [dist1_mean, dist2_mean]),
                         [None, None])

    with self.test_session() as sess:
      true_vals, factorised_vals = sess.run([true_grads, factorised_grads])
      self.assertAllClose(true_vals, factorised_vals)
Example #26
Source File: parameterized_test.py    From abseil-py with Apache License 2.0 5 votes vote down vote up
def test_no_test_error_empty_generator(self):
    """Decorating with an empty generator of test cases raises NoTestsError."""
    no_cases = (case for case in [])
    with self.assertRaises(parameterized.NoTestsError):

      @parameterized.parameters(no_cases)
      def test_something():
        pass

      # Only reached if the decorator did not raise.
      del test_something
Example #27
Source File: parameterized_test.py    From abseil-py with Apache License 2.0 5 votes vote down vote up
def test_no_test_error_empty_parameters(self):
    """Calling `parameters()` with no test cases raises NoTestsError."""
    expected_error = parameterized.NoTestsError
    with self.assertRaises(expected_error):

      @parameterized.parameters()
      def test_something():
        pass

      # Only reached if the decorator did not raise.
      del test_something
Example #28
Source File: parameterized_test.py    From abseil-py with Apache License 2.0 5 votes vote down vote up
def tes_double_class_decorations_not_supported(self):
    """Checks that re-decorating a subclass of a parameterized class fails.

    NOTE(review): this name starts with `tes_`, not with
    `unittest.TestLoader.testMethodPrefix` ("test"). The default unittest
    loader therefore never collects or runs this test. Confirm whether that
    is deliberate; if not, rename it to
    `test_double_class_decorations_not_supported`.
    """

    @parameterized.parameters('foo', 'bar')
    class SuperclassWithClassDecorator(parameterized.TestCase):

      def test_name(self, name):
        del name

    # Applying `parameters` to a subclass of an already-decorated class is
    # expected to raise AssertionError.
    with self.assertRaises(AssertionError):

      @parameterized.parameters('foo', 'bar')
      class SubclassWithClassDecorator(SuperclassWithClassDecorator):
        pass

      del SubclassWithClassDecorator
Example #29
Source File: parameterized_test.py    From abseil-py with Apache License 2.0 5 votes vote down vote up
def test_no_duplicate_decorations(self):
    """Checks that parameterizing both a class and its method is rejected.

    Defining a class decorated with `parameters` that contains a method also
    decorated with `parameters` is expected to raise AssertionError.
    """
    with self.assertRaises(AssertionError):

      @parameterized.parameters(1, 2, 3, 4)
      class _(parameterized.TestCase):

        @parameterized.parameters(5, 6, 7, 8)
        def test_something(self, unused_obj):
          pass
Example #30
Source File: parameterized_test.py    From abseil-py with Apache License 2.0 5 votes vote down vote up
def test_parameterized_test_iter_has_testcases_property(self):
    """The decorated function exposes its test cases via `testcases`."""
    cases = [1, 2, 3, 4, 5, 6]

    @parameterized.parameters(*cases)
    def test_something(unused_self, unused_obj):  # pylint: disable=invalid-name
      pass

    self.assertTrue(hasattr(test_something, 'testcases'))
    self.assertItemsEqual(cases, test_something.testcases)