Python absl.testing.parameterized.parameters() Examples
The following are 30 code examples of absl.testing.parameterized.parameters().
You can go to the original project or source file by following the links above each example.
You may also want to check out all available functions/classes of the module absl.testing.parameterized.
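For orientation before the examples: the decorator turns one test method into one generated test per parameter set. Below is a minimal, self-contained sketch of the basic pattern (the test class and values are illustrative, not taken from any of the projects listed):

from absl.testing import absltest
from absl.testing import parameterized


class AdditionTest(parameterized.TestCase):

  # Each tuple becomes one generated test case; its items are passed as
  # positional arguments to the decorated method.
  @parameterized.parameters(
      (1, 2, 3),
      (4, 5, 9),
      (-1, 1, 0),
  )
  def test_addition(self, op1, op2, expected):
    self.assertEqual(expected, op1 + op2)


if __name__ == '__main__':
  absltest.main()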
Example #1
Source File: ssgan_test.py From compare_gan with Apache License 2.0
def _runSingleTrainingStep(self, architecture, loss_fn, penalty_fn):
  parameters = {
      "architecture": architecture,
      "lambda": 1,
      "z_dim": 128,
  }
  with gin.unlock_config():
    gin.bind_parameter("penalty.fn", penalty_fn)
    gin.bind_parameter("loss.fn", loss_fn)
  model_dir = self._get_empty_model_dir()
  run_config = tf.contrib.tpu.RunConfig(
      model_dir=model_dir,
      tpu_config=tf.contrib.tpu.TPUConfig(iterations_per_loop=1))
  dataset = datasets.get_dataset("cifar10")
  gan = SSGAN(
      dataset=dataset,
      parameters=parameters,
      model_dir=model_dir,
      g_optimizer_fn=tf.train.AdamOptimizer,
      g_lr=0.0002,
      rotated_batch_size=4)
  estimator = gan.as_estimator(run_config, batch_size=2, use_tpu=False)
  estimator.train(gan.input_fn, steps=1)
Example #2
Source File: parameterized.py From abseil-py with Apache License 2.0
def parameters(*testcases):
  """A decorator for creating parameterized tests.

  See the module docstring for a usage example.

  Args:
    *testcases: Parameters for the decorated method, either a single iterable,
        or a list of tuples/dicts/objects (for tests with only one argument).

  Raises:
    NoTestsError: Raised when the decorator generates no tests.

  Returns:
    A test generator to be handled by TestGeneratorMetaclass.
  """
  return _parameter_decorator(_ARGUMENT_REPR, testcases)
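As the docstring says, the testcases can be tuples, dicts, or bare values. A short sketch of the three accepted forms (class and method names are illustrative, not from abseil-py's own tests):

from absl.testing import parameterized


class AcceptedFormsTest(parameterized.TestCase):

  # Tuples are unpacked into positional arguments.
  @parameterized.parameters((2, 4), (3, 9))
  def test_square_positional(self, x, expected):
    self.assertEqual(expected, x * x)

  # Dicts are unpacked into keyword arguments.
  @parameterized.parameters({'x': 2, 'expected': 4},
                            {'x': 3, 'expected': 9})
  def test_square_keyword(self, x, expected):
    self.assertEqual(expected, x * x)

  # Bare values work for tests that take a single argument.
  @parameterized.parameters(0, 1, 2)
  def test_non_negative(self, x):
    self.assertGreaterEqual(x, 0)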
Example #3
Source File: test_utils_test.py From model-optimization with Apache License 2.0
def test_basic_encode_decode_tf_constructor_parameters(self):
  """Tests the core functionality with `tf.Variable` constructor parameters."""
  a_var = tf.compat.v1.get_variable('a_var', initializer=self._DEFAULT_A)
  b_var = tf.compat.v1.get_variable('b_var', initializer=self._DEFAULT_B)
  stage = test_utils.SimpleLinearEncodingStage(a_var, b_var)

  with self.cached_session() as sess:
    sess.run(tf.compat.v1.global_variables_initializer())

  x = self.default_input()
  encode_params, decode_params = stage.get_params()
  encoded_x, decoded_x = self.encode_decode_x(stage, x, encode_params,
                                              decode_params)
  test_data = self.evaluate_test_data(
      test_utils.TestData(x, encoded_x, decoded_x))
  self.common_asserts_for_test_data(test_data)

  # Change the variables and verify the behavior of stage changes.
  self.evaluate(
      [tf.compat.v1.assign(a_var, 5.0),
       tf.compat.v1.assign(b_var, 6.0)])
  test_data = self.evaluate_test_data(
      test_utils.TestData(x, encoded_x, decoded_x))
  self.assertAllClose(test_data.x * 5.0 + 6.0,
                      test_data.encoded_x[self._ENCODED_VALUES_KEY])
Example #4
Source File: modular_gan_test.py From compare_gan with Apache License 2.0
def _runSingleTrainingStep(self, architecture, loss_fn, penalty_fn):
  parameters = {
      "architecture": architecture,
      "lambda": 1,
      "z_dim": 128,
  }
  with gin.unlock_config():
    gin.bind_parameter("penalty.fn", penalty_fn)
    gin.bind_parameter("loss.fn", loss_fn)
  dataset = datasets.get_dataset("cifar10")
  gan = ModularGAN(
      dataset=dataset,
      parameters=parameters,
      model_dir=self.model_dir,
      conditional="biggan" in architecture)
  estimator = gan.as_estimator(self.run_config, batch_size=2, use_tpu=False)
  estimator.train(gan.input_fn, steps=1)
Example #5
Source File: parameterized.py From abseil-py with Apache License 2.0
def _modify_class(class_object, testcases, naming_type):
  assert not getattr(class_object, '_test_method_ids', None), (
      'Cannot add parameters to %s. Either it already has parameterized '
      'methods, or its super class is also a parameterized class.' % (
          class_object,))
  class_object._test_method_ids = test_method_ids = {}
  for name, obj in six.iteritems(class_object.__dict__.copy()):
    if (name.startswith(unittest.TestLoader.testMethodPrefix)
        and isinstance(obj, types.FunctionType)):
      delattr(class_object, name)
      methods = {}
      _update_class_dict_for_param_test_case(
          class_object.__name__, methods, test_method_ids, name,
          _ParameterizedTestIter(obj, testcases, naming_type, name))
      for name, meth in six.iteritems(methods):
        setattr(class_object, name, meth)
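_modify_class is the machinery behind decorating a whole class: every test method of the decorated class is removed and re-added once per testcase. A hedged sketch of the class-level usage this enables (the class below is illustrative, not from abseil-py):

from absl.testing import parameterized


# Decorating the class parameterizes every test method in it at once.
@parameterized.parameters(('ab', 2), ('abcd', 4))
class WholeClassTest(parameterized.TestCase):

  def test_length(self, text, expected_length):
    self.assertLen(text, expected_length)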
Example #6
Source File: modular_gan_test.py From compare_gan with Apache License 2.0
def testSingleTrainingStepWithJointGenForDisc(self):
  parameters = {
      "architecture": c.DUMMY_ARCH,
      "lambda": 1,
      "z_dim": 120,
      "disc_iters": 2,
  }
  dataset = datasets.get_dataset("cifar10")
  gan = ModularGAN(
      dataset=dataset,
      parameters=parameters,
      model_dir=self.model_dir,
      experimental_joint_gen_for_disc=True,
      experimental_force_graph_unroll=True,
      conditional=True)
  estimator = gan.as_estimator(self.run_config, batch_size=2, use_tpu=False)
  estimator.train(gan.input_fn, steps=1)
Example #7
Source File: modular_gan_test.py From compare_gan with Apache License 2.0
def testSingleTrainingStepDiscItersWithEma(self, disc_iters):
  parameters = {
      "architecture": c.DUMMY_ARCH,
      "lambda": 1,
      "z_dim": 128,
      "disc_iters": disc_iters,
  }
  gin.bind_parameter("ModularGAN.g_use_ema", True)
  dataset = datasets.get_dataset("cifar10")
  gan = ModularGAN(
      dataset=dataset,
      parameters=parameters,
      model_dir=self.model_dir)
  estimator = gan.as_estimator(self.run_config, batch_size=2, use_tpu=False)
  estimator.train(gan.input_fn, steps=1)
  # Check for moving average variables in checkpoint.
  checkpoint_path = tf.train.latest_checkpoint(self.model_dir)
  ema_vars = sorted([v[0] for v in tf.train.list_variables(checkpoint_path)
                     if v[0].endswith("ExponentialMovingAverage")])
  tf.logging.info("ema_vars=%s", ema_vars)
  expected_ema_vars = sorted([
      "generator/fc_noise/kernel/ExponentialMovingAverage",
      "generator/fc_noise/bias/ExponentialMovingAverage",
  ])
  self.assertAllEqual(ema_vars, expected_ema_vars)
Example #8
Source File: core_test.py From dm_control with Apache License 2.0
def _get_attributes_test_params():
  model = core.MjModel.from_xml_path(HUMANOID_XML_PATH)
  data = core.MjData(model)
  # Get the names of the non-private attributes of model and data through
  # introspection. These are passed as parameters to each of the test methods
  # in AttributesTest.
  array_args = []
  scalar_args = []
  skipped_args = []
  for parent_name, parent_obj in zip(("model", "data"), (model, data)):
    for attr_name in dir(parent_obj):
      if not attr_name.startswith("_"):  # Skip 'private' attributes
        args = (parent_name, attr_name)
        attr = getattr(parent_obj, attr_name)
        if isinstance(attr, ARRAY_TYPES):
          array_args.append(args)
        elif isinstance(attr, SCALAR_TYPES):
          scalar_args.append(args)
        elif callable(attr):
          # Methods etc. should be covered specifically in CoreTest.
          continue
        else:
          skipped_args.append(args)
  return array_args, scalar_args, skipped_args
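Parameter lists built by introspection like this are typically unpacked into the decorator at class-definition time. A hedged sketch of how AttributesTest presumably consumes them (the test body and fixture access are illustrative, not copied from dm_control; it relies on the definitions above plus the absl import):

from absl.testing import parameterized

_ARRAY_ARGS, _SCALAR_ARGS, _SKIPPED_ARGS = _get_attributes_test_params()


class AttributesTest(parameterized.TestCase):

  # One generated test per (parent_name, attr_name) pair.
  @parameterized.parameters(*_ARRAY_ARGS)
  def test_array_attribute(self, parent_name, attr_name):
    attr = getattr(getattr(self, parent_name), attr_name)
    self.assertIsInstance(attr, ARRAY_TYPES)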
Example #9
Source File: distribution_ops_test.py From trfl with Apache License 2.0
def testFactorisedKLGaussian(self, dist1_type, dist2_type):
  """Tests that the factorised KL terms sum up to the true KL."""
  dist1, dist1_mean, dist1_cov = self._create_gaussian(dist1_type)
  dist2, dist2_mean, dist2_cov = self._create_gaussian(dist2_type)
  both_diagonal = _is_diagonal(dist1.scale) and _is_diagonal(dist2.scale)
  if both_diagonal:
    dist1_cov = dist1.parameters['scale_diag']
    dist2_cov = dist2.parameters['scale_diag']

  kl = tfp.distributions.kl_divergence(dist1, dist2)
  kl_mean, kl_cov = distribution_ops.factorised_kl_gaussian(
      dist1_mean,
      dist1_cov,
      dist2_mean,
      dist2_cov,
      both_diagonal=both_diagonal)

  with self.test_session() as sess:
    sess.run(tf.global_variables_initializer())
    actual_kl, kl_mean_np, kl_cov_np = sess.run([kl, kl_mean, kl_cov])
    self.assertAllClose(actual_kl, kl_mean_np + kl_cov_np, rtol=1e-4)
Example #10
Source File: modular_gan_conditional_test.py From compare_gan with Apache License 2.0
def _runSingleTrainingStep(self, architecture, loss_fn, penalty_fn,
                           labeled_dataset):
  parameters = {
      "architecture": architecture,
      "lambda": 1,
      "z_dim": 120,
  }
  with gin.unlock_config():
    gin.bind_parameter("penalty.fn", penalty_fn)
    gin.bind_parameter("loss.fn", loss_fn)
  model_dir = self._get_empty_model_dir()
  run_config = tf.contrib.tpu.RunConfig(
      model_dir=model_dir,
      tpu_config=tf.contrib.tpu.TPUConfig(iterations_per_loop=1))
  dataset = datasets.get_dataset("cifar10")
  gan = ModularGAN(
      dataset=dataset,
      parameters=parameters,
      conditional=True,
      model_dir=model_dir)
  estimator = gan.as_estimator(run_config, batch_size=2, use_tpu=False)
  estimator.train(gan.input_fn, steps=1)
Example #11
Source File: function_utils_test.py From federated with Apache License 2.0
def test_get_defun_argspec_with_typed_non_eager_defun(self):
  # In a tf.function with a defined input signature, **kwargs or default
  # values are not allowed, but *args are, and the input signature may
  # overlap with *args.
  fn = tf.function(lambda x, y, *z: None, (
      tf.TensorSpec(None, tf.int32),
      tf.TensorSpec(None, tf.bool),
      tf.TensorSpec(None, tf.float32),
      tf.TensorSpec(None, tf.float32),
  ))
  self.assertEqual(
      collections.OrderedDict(function_utils.get_signature(fn).parameters),
      collections.OrderedDict(
          x=inspect.Parameter('x', inspect.Parameter.POSITIONAL_OR_KEYWORD),
          y=inspect.Parameter('y', inspect.Parameter.POSITIONAL_OR_KEYWORD),
          z=inspect.Parameter('z', inspect.Parameter.VAR_POSITIONAL),
      ))
Example #12
Source File: graph_convolution_test.py From graphics with Apache License 2.0
def test_dynamic_graph_convolution_keras_layer_exception_not_raised_shapes(
    self, batch_size, num_vertices, in_channels, out_channels, reduction):
  """Check if the convolution parameters and output have correct shapes."""
  if not tf.executing_eagerly():
    return
  data, neighbors = _dummy_data(batch_size, num_vertices, in_channels)
  layer = gc_layer.DynamicGraphConvolutionKerasLayer(
      num_output_channels=out_channels,
      reduction=reduction)

  try:
    output = layer(inputs=[data, neighbors], sizes=None)
  except Exception as e:  # pylint: disable=broad-except
    self.fail("Exception raised: %s" % str(e))

  self.assertAllEqual((batch_size, num_vertices, out_channels), output.shape)
Example #13
Source File: function_utils_test.py From federated with Apache License 2.0
def test_get_signature_with_class_instance_method(self):

  class C:

    def __init__(self, x):
      self._x = x

    def foo(self, y):
      return self._x * y

  c = C(5)
  signature = function_utils.get_signature(c.foo)
  self.assertEqual(
      signature.parameters,
      collections.OrderedDict(
          y=inspect.Parameter('y', inspect.Parameter.POSITIONAL_OR_KEYWORD)))
Example #14
Source File: query_test.py From python-spanner-orm with Apache License 2.0
def test_includes(self):
  select_query = self.includes('parent')

  # The column order varies between test runs
  expected_sql = (
      r'SELECT RelationshipTestModel\S* RelationshipTestModel\S* '
      r'ARRAY\(SELECT AS STRUCT SmallTestModel\S* SmallTestModel\S* '
      r'SmallTestModel\S* FROM SmallTestModel WHERE SmallTestModel.key = '
      r'RelationshipTestModel.parent_key\)')
  self.assertRegex(select_query.sql(), expected_sql)
  self.assertEmpty(select_query.parameters())
  self.assertEmpty(select_query.types())
Example #15
Source File: survey_api_test.py From loaner with Apache License 2.0
def _generate_message_parameters(want_permutations=False):
  """Generate message parameters for test cases.

  Args:
    want_permutations: bool, whether or not to run the messages through
        various permutations.

  Yields:
    A list containing the list of messages.
  """
  answer_message = survey_messages.Answer(
      text='Left my laptop at home.',
      more_info_enabled=False,
      placeholder_text=None)
  survey_messages_1 = survey_messages.Question(
      question_type=survey_models.QuestionType.ASSIGNMENT,
      question_text=_QUESTION.format(num=1),
      answers=[answer_message],
      rand_weight=1,
      required=True)
  survey_messages_2 = survey_messages.Question(
      question_type=survey_models.QuestionType.ASSIGNMENT,
      question_text=_QUESTION.format(num=2),
      answers=[answer_message],
      rand_weight=1,
      enabled=False,
      required=False)
  survey_messages_3 = survey_messages.Question(
      question_type=survey_models.QuestionType.RETURN,
      question_text=_QUESTION.format(num=3),
      answers=[answer_message],
      rand_weight=1,
      enabled=True)
  messages = [
      survey_messages_1, survey_messages_2, survey_messages_3]
  if want_permutations:
    for p in itertools.permutations(messages):
      yield [p]
  else:
    yield [messages]
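Because parameters() accepts a single iterable, a generator like the one above can be handed to the decorator directly and is exhausted at class-definition time. A hedged sketch of that pattern (the consuming test method is hypothetical):

# Hypothetical consumer: each yielded one-element list becomes one testcase,
# whose single item (one permutation of the three Question messages) is
# passed as the test's argument.
@parameterized.parameters(_generate_message_parameters(want_permutations=True))
def test_survey_questions(self, messages):
  self.assertLen(messages, 3)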
Example #16
Source File: query_test.py From python-spanner-orm with Apache License 2.0
def test_query_where_comparison(self, column, value, grpc_type):
  condition_generators = [
      condition.greater_than, condition.not_less_than, condition.less_than,
      condition.not_greater_than, condition.equal_to, condition.not_equal_to
  ]
  for condition_generator in condition_generators:
    current_condition = condition_generator(column, value)
    select_query = self.select(current_condition)
    column_key = '{}0'.format(column)
    expected_where = ' WHERE table.{} {} @{}'.format(
        column, current_condition.operator, column_key)
    self.assertEndsWith(select_query.sql(), expected_where)
    self.assertEqual(select_query.parameters(), {column_key: value})
    self.assertEqual(select_query.types(), {column_key: grpc_type})
Example #17
Source File: query_test.py From python-spanner-orm with Apache License 2.0
def test_query_where_comparison_with_object(self, column, value, grpc_type):
  condition_generators = [
      condition.greater_than, condition.not_less_than, condition.less_than,
      condition.not_greater_than, condition.equal_to, condition.not_equal_to
  ]
  for condition_generator in condition_generators:
    current_condition = condition_generator(column, value)
    select_query = self.select(current_condition)
    column_key = '{}0'.format(column.name)
    expected_where = ' WHERE table.{} {} @{}'.format(
        column.name, current_condition.operator, column_key)
    self.assertEndsWith(select_query.sql(), expected_where)
    self.assertEqual(select_query.parameters(), {column_key: value})
    self.assertEqual(select_query.types(), {column_key: grpc_type})
Example #18
Source File: query_test.py From python-spanner-orm with Apache License 2.0
def test_query_where_list_comparison(self, column, values, grpc_type):
  condition_generators = [condition.in_list, condition.not_in_list]
  for condition_generator in condition_generators:
    current_condition = condition_generator(column, values)
    select_query = self.select(current_condition)
    column_key = '{}0'.format(column)
    expected_sql = ' WHERE table.{} {} UNNEST(@{})'.format(
        column, current_condition.operator, column_key)
    list_type = type_pb2.Type(
        code=type_pb2.ARRAY, array_element_type=grpc_type)
    self.assertEndsWith(select_query.sql(), expected_sql)
    self.assertEqual(select_query.parameters(), {column_key: values})
    self.assertEqual(select_query.types(), {column_key: list_type})
Example #19
Source File: center_net_meta_arch_tf2_test.py From models with Apache License 2.0
def test_pad_to_full_instance_dim(self):
  batch_size = 4
  max_instances = 8
  num_keypoints = 6
  num_instances = 2
  instance_inds = [1, 3]
  kpt_coords_np = np.random.randn(batch_size, num_instances, num_keypoints, 2)
  kpt_scores_np = np.random.randn(batch_size, num_instances, num_keypoints)

  def graph_fn():
    kpt_coords = tf.constant(kpt_coords_np)
    kpt_scores = tf.constant(kpt_scores_np)
    kpt_coords_padded, kpt_scores_padded = (
        cnma._pad_to_full_instance_dim(
            kpt_coords, kpt_scores, instance_inds, max_instances))
    return kpt_coords_padded, kpt_scores_padded

  kpt_coords_padded, kpt_scores_padded = self.execute(graph_fn, [])
  self.assertAllEqual([batch_size, max_instances, num_keypoints, 2],
                      kpt_coords_padded.shape)
  self.assertAllEqual([batch_size, max_instances, num_keypoints],
                      kpt_scores_padded.shape)
  for i, inst_ind in enumerate(instance_inds):
    np.testing.assert_allclose(kpt_coords_np[:, i, :, :],
                               kpt_coords_padded[:, inst_ind, :, :])
    np.testing.assert_allclose(kpt_scores_np[:, i, :],
                               kpt_scores_padded[:, inst_ind, :])

# Common parameters for setting up testing examples across tests.
Example #20
Source File: modular_gan_tpu_test.py From compare_gan with Apache License 2.0
def testBatchSizeExperimentalJointGenForDisc(self, disc_iters):
  parameters = {
      "architecture": c.DUMMY_ARCH,
      "lambda": 1,
      "z_dim": 128,
      "disc_iters": disc_iters,
  }
  batch_size = 16
  dataset = datasets.get_dataset("cifar10")
  gan = ModularGAN(
      dataset=dataset,
      parameters=parameters,
      experimental_joint_gen_for_disc=True,
      model_dir=self.model_dir)
  estimator = gan.as_estimator(self.run_config, batch_size=batch_size,
                               use_tpu=True)
  estimator.train(gan.input_fn, steps=1)

  gen_args = gan.generator.call_arg_list
  disc_args = gan.discriminator.call_arg_list
  self.assertLen(gen_args, 2)
  self.assertLen(disc_args, disc_iters + 1)

  self.assertAllEqual(gen_args[0]["z"].shape.as_list(), [8 * disc_iters, 128])
  self.assertAllEqual(gen_args[1]["z"].shape.as_list(), [8, 128])
  for args in disc_args:
    self.assertAllEqual(args["x"].shape.as_list(), [16, 32, 32, 3])
Example #21
Source File: eval_gan_lib_test.py From compare_gan with Apache License 2.0
def test_end2end_checkpoint(self, architecture):
  """Takes a real GAN (trained for 1 step) and evaluates it."""
  if architecture in {c.RESNET_STL_ARCH, c.RESNET30_ARCH}:
    # RESNET_STL_ARCH and RESNET107_ARCH do not support CIFAR image shape.
    return
  gin.bind_parameter("dataset.name", "cifar10")
  dataset = datasets.get_dataset("cifar10")
  options = {
      "architecture": architecture,
      "z_dim": 120,
      "disc_iters": 1,
      "lambda": 1,
  }
  model_dir = os.path.join(tf.test.get_temp_dir(), self.id())
  tf.logging.info("model_dir: %s" % model_dir)
  run_config = tf.contrib.tpu.RunConfig(model_dir=model_dir)

  gan = ModularGAN(dataset=dataset,
                   parameters=options,
                   conditional="biggan" in architecture,
                   model_dir=model_dir)
  estimator = gan.as_estimator(run_config, batch_size=2, use_tpu=False)
  estimator.train(input_fn=gan.input_fn, steps=1)
  export_path = os.path.join(model_dir, "tfhub")
  checkpoint_path = os.path.join(model_dir, "model.ckpt-1")
  module_spec = gan.as_module_spec()
  module_spec.export(export_path, checkpoint_path=checkpoint_path)

  eval_tasks = [
      fid_score.FIDScoreTask(),
      fractal_dimension.FractalDimensionTask(),
      inception_score.InceptionScoreTask(),
      ms_ssim_score.MultiscaleSSIMTask()
  ]
  result_dict = eval_gan_lib.evaluate_tfhub_module(
      export_path, eval_tasks, use_tpu=False, num_averaging_runs=1)
  tf.logging.info("result_dict: %s", result_dict)
  for score in ["fid_score", "fractal_dimension", "inception_score",
                "ms_ssim"]:
    for stats in ["mean", "std", "list"]:
      required_key = "%s_%s" % (score, stats)
      self.assertIn(required_key, result_dict, "Missing: %s." % required_key)
Example #22
Source File: keras_layer_test.py From hub with Apache License 2.0
def testBatchNormRetraining(self, save_from_keras):
  """Tests imported batch norm with trainable=True."""
  export_dir = os.path.join(self.get_temp_dir(), "batch-norm")
  _save_batch_norm_model(export_dir, save_from_keras=save_from_keras)
  estimator = tf.estimator.Estimator(
      model_fn=self._batch_norm_model_fn,
      params=dict(hub_module=export_dir, train_batch_norm=True))
  # Retrain the imported batch norm layer on a fixed batch of inputs,
  # which has mean 12.0 and some variance of a less obvious value.
  # The module learns scale and offset parameters that achieve the
  # mapping x --> 2*x for the observed mean and variance.
  x = [[11.], [12.], [13.]]
  y = [[2*xi[0]] for xi in x]
  train_input_fn = tf.compat.v1.estimator.inputs.numpy_input_fn(
      np.array(x, dtype=np.float32),
      np.array(y, dtype=np.float32),
      batch_size=len(x), num_epochs=None, shuffle=False)
  estimator.train(train_input_fn, steps=100)
  predictions = next(estimator.predict(train_input_fn,
                                       yield_single_examples=False))
  self.assertAllClose(predictions["mean"], np.array([12.0]))
  self.assertAllClose(predictions["beta"], np.array([24.0]))
  self.assertAllClose(predictions["output"], np.array(y))
  # Evaluating the model operates batch norm in inference mode:
  # - Batch statistics are ignored in favor of aggregated statistics,
  #   computing x --> 2*x independent of input distribution.
  # - Update ops are not run, so this doesn't change over time.
  predict_input_fn = tf.compat.v1.estimator.inputs.numpy_input_fn(
      np.array([[10.], [20.], [30.]], dtype=np.float32),
      batch_size=3, num_epochs=100, shuffle=False)
  for predictions in estimator.predict(predict_input_fn,
                                       yield_single_examples=False):
    self.assertAllClose(predictions["output"],
                        np.array([[20.], [40.], [60.]]))
    self.assertAllClose(predictions["mean"], np.array([12.0]))
    self.assertAllClose(predictions["beta"], np.array([24.0]))
Example #23
Source File: keras_layer_test.py From hub with Apache License 2.0
def testBatchNormRetraining(self, save_from_keras):
  """Tests imported batch norm with trainable=True."""
  export_dir = os.path.join(self.get_temp_dir(), "batch-norm")
  _save_batch_norm_model(export_dir, save_from_keras=save_from_keras)
  inp = tf.keras.layers.Input(shape=(1,), dtype=tf.float32)
  imported = hub.KerasLayer(export_dir, trainable=True)
  var_beta, var_gamma, var_mean, var_variance = _get_batch_norm_vars(imported)
  outp = imported(inp)
  model = tf.keras.Model(inp, outp)
  # Retrain the imported batch norm layer on a fixed batch of inputs,
  # which has mean 12.0 and some variance of a less obvious value.
  # The module learns scale and offset parameters that achieve the
  # mapping x --> 2*x for the observed mean and variance.
  model.compile(tf.keras.optimizers.SGD(0.1),
                "mean_squared_error", run_eagerly=True)
  x = [[11.], [12.], [13.]]
  y = [[2*xi[0]] for xi in x]
  model.fit(np.array(x), np.array(y), batch_size=len(x), epochs=100)
  self.assertAllClose(var_mean.numpy(), np.array([12.0]))
  self.assertAllClose(var_beta.numpy(), np.array([24.0]))
  self.assertAllClose(model(np.array(x, np.float32)), np.array(y))
  # Evaluating the model operates batch norm in inference mode:
  # - Batch statistics are ignored in favor of aggregated statistics,
  #   computing x --> 2*x independent of input distribution.
  # - Update ops are not run, so this doesn't change over time.
  for _ in range(100):
    self.assertAllClose(model(np.array([[10.], [20.], [30.]], np.float32)),
                        np.array([[20.], [40.], [60.]]))
    self.assertAllClose(var_mean.numpy(), np.array([12.0]))
    self.assertAllClose(var_beta.numpy(), np.array([24.0]))
Example #24
Source File: distribution_ops_test.py From trfl with Apache License 2.0
def testConsistentGradientsFullCovariance(self):
  dist_type = tfp.distributions.MultivariateNormalFullCovariance
  dist1, dist1_mean, dist1_cov = self._create_gaussian(dist_type)
  dist2, dist2_mean, dist2_cov = self._create_gaussian(dist_type)

  kl = tfp.distributions.kl_divergence(dist1, dist2)
  kl_mean, kl_cov = distribution_ops.factorised_kl_gaussian(
      dist1_mean, dist1_cov, dist2_mean, dist2_cov, both_diagonal=False)

  dist1_cov = dist1.parameters['covariance_matrix']
  dist2_cov = dist2.parameters['covariance_matrix']

  dist_params = [
      dist1_mean,
      dist2_mean,
      dist1_cov,
      dist2_cov,
  ]
  actual_kl_gradients = tf.gradients(kl, dist_params)
  factorised_kl_gradients = tf.gradients(kl_mean + kl_cov, dist_params)

  # Check that no gradients flow into the mean terms from `kl_cov` and
  # vice-versa.
  gradients = tf.gradients(kl_mean, [dist1_cov])
  self.assertListEqual(gradients, [None])
  gradients = tf.gradients(kl_cov, [dist1_mean, dist2_mean])
  self.assertListEqual(gradients, [None, None])

  with self.test_session() as sess:
    np_actual_kl, np_factorised_kl = sess.run(
        [actual_kl_gradients, factorised_kl_gradients])
    self.assertAllClose(np_actual_kl, np_factorised_kl)

# Check for diagonal Gaussian distributions. Based on the definition in
# tensorflow_probability/python/distributions/mvn_linear_operator.py
Example #25
Source File: distribution_ops_test.py From trfl with Apache License 2.0
def testConsistentGradientsBothDiagonal(self):
  dist_type = tfp.distributions.MultivariateNormalDiag
  dist1, dist1_mean, _ = self._create_gaussian(dist_type)
  dist2, dist2_mean, _ = self._create_gaussian(dist_type)

  kl = tfp.distributions.kl_divergence(dist1, dist2)
  dist1_scale = dist1.parameters['scale_diag']
  dist2_scale = dist2.parameters['scale_diag']
  kl_mean, kl_cov = distribution_ops.factorised_kl_gaussian(
      dist1_mean, dist1_scale, dist2_mean, dist2_scale, both_diagonal=True)

  dist_params = [dist1_mean, dist2_mean, dist1_scale, dist2_scale]
  actual_kl_gradients = tf.gradients(kl, dist_params)
  factorised_kl_gradients = tf.gradients(kl_mean + kl_cov, dist_params)

  # Check that no gradients flow into the mean terms from `kl_cov` and
  # vice-versa.
  gradients = tf.gradients(kl_mean, [dist1_scale])
  self.assertListEqual(gradients, [None])
  gradients = tf.gradients(kl_cov, [dist1_mean, dist2_mean])
  self.assertListEqual(gradients, [None, None])

  with self.test_session() as sess:
    np_actual_kl, np_factorised_kl = sess.run(
        [actual_kl_gradients, factorised_kl_gradients])
    self.assertAllClose(np_actual_kl, np_factorised_kl)
Example #26
Source File: parameterized_test.py From abseil-py with Apache License 2.0
def test_no_test_error_empty_generator(self):
  with self.assertRaises(parameterized.NoTestsError):

    @parameterized.parameters((i for i in []))
    def test_something():
      pass

    del test_something
Example #27
Source File: parameterized_test.py From abseil-py with Apache License 2.0
def test_no_test_error_empty_parameters(self):
  with self.assertRaises(parameterized.NoTestsError):

    @parameterized.parameters()
    def test_something():
      pass

    del test_something
Example #28
Source File: parameterized_test.py From abseil-py with Apache License 2.0
def test_double_class_decorations_not_supported(self):

  @parameterized.parameters('foo', 'bar')
  class SuperclassWithClassDecorator(parameterized.TestCase):

    def test_name(self, name):
      del name

  with self.assertRaises(AssertionError):

    @parameterized.parameters('foo', 'bar')
    class SubclassWithClassDecorator(SuperclassWithClassDecorator):
      pass

    del SubclassWithClassDecorator
Example #29
Source File: parameterized_test.py From abseil-py with Apache License 2.0
def test_no_duplicate_decorations(self):
  with self.assertRaises(AssertionError):

    @parameterized.parameters(1, 2, 3, 4)
    class _(parameterized.TestCase):

      @parameterized.parameters(5, 6, 7, 8)
      def test_something(self, unused_obj):
        pass
Example #30
Source File: parameterized_test.py From abseil-py with Apache License 2.0
def test_parameterized_test_iter_has_testcases_property(self):
  @parameterized.parameters(1, 2, 3, 4, 5, 6)
  def test_something(unused_self, unused_obj):  # pylint: disable=invalid-name
    pass

  expected_testcases = [1, 2, 3, 4, 5, 6]
  self.assertTrue(hasattr(test_something, 'testcases'))
  self.assertItemsEqual(expected_testcases, test_something.testcases)