Python tensorflow_hub.create_module_spec() Examples

The following are 30 code examples of tensorflow_hub.create_module_spec(). You can vote up the ones you like or vote down the ones you don't like, and go to the original project or source file by following the links above each example. You may also want to check out all available functions/classes of the module tensorflow_hub, or try the search function.
Example #1
Source File: native_module_test.py    From hub with Apache License 2.0 6 votes vote down vote up
def testInputsFromMultivaluedOp(self):
    """Tests warning for inputs from multivalued ops in create_module_spec.

    `find_signature_inputs_from_multivalued_ops` is expected to return a
    message listing the inputs that come from ops with several outputs (here
    the two `tf.split` results), and None when no input does.
    """
    # Ideally, one would be able to write
    #    with self.assertLogs("blah"): hub.create_module_spec(module_fn)
    # but in the absence of assertions on logs, we test the underlying helper
    # in the environment seen from within a module_fn.
    with tf.Graph().as_default():
      first, _ = tf.split([[1, 2], [3, 4]], 2, name="split1")
      _, second = tf.split([[5, 6], [7, 8]], 2, name="split2")
      third = tf.constant(105, name="const")
      message = native_module.find_signature_inputs_from_multivalued_ops(
          dict(first=first, second=second, third=third))
    # assertRegex replaces the deprecated assertRegexpMatches alias, which was
    # removed from unittest in Python 3.12.
    self.assertRegex(
        message,
        ".*single output.*\n"
        "Affected inputs: first='split1:0', second='split2:1'$")
    # Also test the case of no errors.
    with tf.Graph().as_default():
      first = tf.constant(101)
      second = tf.constant(102)
      third = tf.constant(103)
      message = native_module.find_signature_inputs_from_multivalued_ops(
          dict(first=first, second=second, third=third))
    self.assertIsNone(message)
Example #2
Source File: native_module_test.py    From hub with Apache License 2.0 6 votes vote down vote up
def testSparseTensors(self):
    """Checks that a module can consume and produce SparseTensors.

    The module built by `sparse_square_module_fn` is applied twice, feeding
    its own output back in. The assertions expect each value raised to the
    4th power, with indices and dense_shape unchanged.
    """
    square_spec = hub.create_module_spec(sparse_square_module_fn)

    with tf.Graph().as_default():
      square = hub.Module(square_spec)
      v = tf_v1.sparse_placeholder(dtype=tf.int64, name="v")
      y = square(v)

      # Using the Session itself as the context manager makes it the default
      # session (needed by y.eval) AND closes it on exit; the previous
      # `Session().as_default()` form never closed the session.
      with tf_v1.Session():
        indices = [[0, 0], [0, 1], [1, 1]]
        values = [10, 2, 1]
        shape = [2, 2]
        v1 = tf_v1.SparseTensorValue(indices, values, shape)
        # Feed the first result back in to apply the module a second time.
        v2 = y.eval(feed_dict={v: v1})
        v4 = y.eval(feed_dict={v: v2})

        self.assertAllEqual(v4.indices, indices)  # Unchanged.
        self.assertAllEqual(v4.values, [t**4 for t in values])  # Squared twice.
        self.assertAllEqual(v4.dense_shape, shape)  # Unchanged.
Example #3
Source File: native_module_test.py    From hub with Apache License 2.0 6 votes vote down vote up
def testDuplicateAssetCopy(self):
    """Exports a module whose two outputs read the same vocabulary file.

    The test makes no explicit assertions. It passes as long as creating,
    applying and exporting the module raise no error.
    """
    export_path = os.path.join(self.get_temp_dir(), "assets-module")

    def module_with_duplicate_asset():
      # The same vocabulary file is passed to both do_table_lookup calls
      # (do_table_lookup is defined elsewhere; presumably it registers the
      # file as a module asset, so the asset is referenced twice).
      vocabulary_file = self.create_vocab_file("tokens2.txt", ["1", "2", "3"])
      indices1 = tf_v1.placeholder(dtype=tf.int64, name="indices1")
      indices2 = tf_v1.placeholder(dtype=tf.int64, name="indices2")
      hub.add_signature(
          inputs={
              "indices_1": indices1,
              "indices_2": indices2,
          },
          outputs={
              "x": do_table_lookup(indices1, vocabulary_file),
              "y": do_table_lookup(indices2, vocabulary_file),
          })

    with tf.Graph().as_default():
      spec = hub.create_module_spec(module_with_duplicate_asset)
      module_a = hub.Module(spec)
      # Apply the module once so its lookup tables exist in this graph.
      module_a({"indices_1": tf.constant([1, 2], dtype=tf.int64),
                "indices_2": tf.constant([1, 2], dtype=tf.int64)}, as_dict=True)
      with tf_v1.Session() as sess:
        sess.run(tf_v1.tables_initializer())
        # Exporting must copy the shared asset without failing.
        module_a.export(export_path, sess)
Example #4
Source File: native_module_test.py    From hub with Apache License 2.0 6 votes vote down vote up
def testNonResourceVariableInWhileLoop(self):
    """Calls a stateful (non-resource-variable) module inside tf.while_loop.

    Each call made from the loop body must keep the colocation constraint
    with the module's variable "module/var123". The loop itself doubles x
    four times (10.0 -> 160.0).
    """
    with tf.Graph().as_default():
      # This test uses non-Resource variables to see an actual colocation
      # constraint propagated to the context Enter op. The long comment on
      # colocation in testResourceVariables explains why they may not offer
      # that.
      spec = hub.create_module_spec(stateful_non_rv_module_fn)
      m = hub.Module(spec)
      cond = lambda i, x: tf.less(i, 4)
      def body(i, x):
        # The module output is only inspected for its colocation groups;
        # the loop state does not depend on it.
        v = m()
        self.assertItemsEqual(v.op.colocation_groups(),
                              [tf.compat.as_bytes("loc:@module/var123")])
        return (tf.add(i, 1), 2*x)
      oi, ox = tf.while_loop(cond, body, [0, 10.0])
      with tf_v1.Session() as sess:
        sess.run(tf_v1.global_variables_initializer())
        self.assertAllEqual(sess.run([oi, ox]), [4, 160.0])
Example #5
Source File: native_module_test.py    From hub with Apache License 2.0 6 votes vote down vote up
def testUseWithinWhileLoop(self):
    """Applies a stateful module repeatedly in a while loop, incl. gradients.

    The loop runs p times (p is fed at run time), each iteration replacing x
    with m(x). The expected values below imply that m(x) == v * x with the
    module variable v == 2.0 (NOTE(review): inferred from the asserted
    numbers, not from double_module_fn, which is defined elsewhere).
    """
    with tf.Graph().as_default():
      spec = hub.create_module_spec(double_module_fn)
      m = hub.Module(spec)
      i = tf.constant(0)
      x = tf.constant(10.0)
      p = tf_v1.placeholder(dtype=tf.int32)
      c = lambda i, x: tf.less(i, p)
      b = lambda i, x: (tf.add(i, 1), m(x))
      oi, ox = tf.while_loop(c, b, [i, x])  # ox = v**p * x
      v = m.variables[0]
      dodv = tf.gradients(ox, v)[0]  # d ox / dv = p*v**(p-1) * x
      dodx = tf.gradients(ox, x)[0]  # d ox / dx = v**p
      with tf_v1.Session() as sess:
        sess.run(tf_v1.global_variables_initializer())
        # One graph, several loop lengths: p selects the trip count.
        self.assertAllEqual(sess.run([oi, ox], feed_dict={p: 1}), [1, 20])
        self.assertAllEqual(sess.run([oi, ox], feed_dict={p: 2}), [2, 40])
        self.assertAllEqual(sess.run([oi, ox], feed_dict={p: 4}), [4, 160])
        # Gradients also use the control flow structures setup earlier.
        # Also check they are working properly.
        self.assertAllEqual(sess.run([dodv, dodx], feed_dict={p: 1}), [10, 2])
        self.assertAllEqual(sess.run([dodv, dodx], feed_dict={p: 2}), [40, 4])
        self.assertAllEqual(sess.run([dodv, dodx], feed_dict={p: 4}), [320, 16])

  # tf.map_fn() is merely a wrapper around tf.while_loop(), but just to be sure...
Example #6
Source File: native_module_test.py    From hub with Apache License 2.0 6 votes vote down vote up
def testNonResourceVariableInCond(self):
    """Calls a stateful (non-resource-variable) module in one tf.cond branch.

    The call inside the true branch must stay colocated with the module's
    variable "module/var123". Running the cond must yield the module output
    (asserted as 10.0) for pred=True and the constant 9.0 for pred=False.
    """
    with tf.Graph().as_default():
      spec = hub.create_module_spec(stateful_non_rv_module_fn)
      m = hub.Module(spec)
      pred = tf_v1.placeholder(tf.bool)
      def true_fn():
        v = m()
        self.assertItemsEqual(v.op.colocation_groups(),
                              [tf.compat.as_bytes("loc:@module/var123")])
        return v
      def false_fn():
        return tf.constant(9.0)
      out = tf.cond(pred, true_fn, false_fn)
      with tf_v1.Session() as sess:
        sess.run(tf_v1.global_variables_initializer())
        self.assertEqual(sess.run(out, feed_dict={pred: True}), 10.0)
        self.assertEqual(sess.run(out, feed_dict={pred: False}), 9.0)
Example #7
Source File: native_module_test.py    From hub with Apache License 2.0 6 votes vote down vote up
def testVariableColocationPropagation(self):
    """Checks that colocation constraints of the input and module merge.

    The input x is created under colocate_with(u1) and colocate_with(u2). The
    module output must then carry the module variable's group plus both of
    those groups.
    """
    with tf.Graph().as_default():
      spec = hub.create_module_spec(stateful_module_fn_with_colocation)
      m = hub.Module(spec)
      u1 = tf.constant(1, name="u1")
      u2 = tf.constant(2, name="u2")
      with tf_v1.colocate_with(u1), tf_v1.colocate_with(u2):
        x = tf.constant(100.0, name="x")
      y = m(x)
      self.assertItemsEqual(y.op.colocation_groups(),
                            [tf.compat.as_bytes("loc:@module/var123"),
                             tf.compat.as_bytes("loc:@u1"),
                             tf.compat.as_bytes("loc:@u2")])
      with tf_v1.Session() as sess:
        sess.run(tf_v1.global_variables_initializer())
        # 100.0 in, 101.0 out (module computation is defined elsewhere).
        self.assertEqual(sess.run(y), 101.0)
Example #8
Source File: native_module_test.py    From hub with Apache License 2.0 6 votes vote down vote up
def testPartitionedVariables(self):
    """Checks variable_map / variables for a module with a partitioned var.

    variable_map maps "normal_variable" to a single variable and
    "partitioned_variable" to its list of 3 parts. Module.variables flattens
    them in a deterministic order.
    """
    with tf.Graph().as_default():
      spec = hub.create_module_spec(
          create_partitioned_variable_module_fn(partitions=3, shape=[7, 3]))
      m = hub.Module(spec, name="test")
      out = m()
      self.assertEqual(len(m.variable_map), 2)
      self.assertEqual(m.variable_map["normal_variable"].name,
                       "test/normal_variable:0")
      self.assertAllEqual([
          variable.name for variable in m.variable_map["partitioned_variable"]
      ], [
          "test/partitioned_variable/part_0:0",
          "test/partitioned_variable/part_1:0",
          "test/partitioned_variable/part_2:0"
      ])
      self.assertAllEqual(  # Check deterministic order (by variable_map key).
          [variable.name for variable in m.variables],
          ["test/normal_variable:0",
           "test/partitioned_variable/part_0:0",
           "test/partitioned_variable/part_1:0",
           "test/partitioned_variable/part_2:0"])
      with tf_v1.Session() as sess:
        sess.run(tf_v1.global_variables_initializer())
        self.assertAllClose(sess.run(out), 2 * np.ones([7, 3]))
Example #9
Source File: native_module_test.py    From hub with Apache License 2.0 6 votes vote down vote up
def testUnsupportedCollections(self):
    """Checks rejection of unsupported collections and the drop_collections fix.

    A variable added to the custom collection "my_scope" must make
    create_module_spec raise ValueError. get_unsupported_collections must
    report that key, and passing it as drop_collections must let spec
    creation succeed.
    """

    def module_fn():
      scale = tf_v1.get_variable("x", (), collections=["my_scope"])
      x = tf_v1.placeholder(tf.float32, shape=[None, 3])
      native_module.add_signature("my_func", {"x": x}, {"y": x*scale})

    with self.assertRaises(ValueError) as cm:
      _ = native_module.create_module_spec(module_fn)
    # Checked after the `with` exits. Inside the block this line ran after the
    # raising call and so never executed. It also must inspect the
    # exception's text, not the context manager object.
    self.assertIn("Unsupported collections in graph", str(cm.exception))

    with tf.Graph().as_default() as tmp_graph:
      module_fn()
      unsupported_collections = native_module.get_unsupported_collections(
          tmp_graph.get_all_collection_keys())
      self.assertEqual(["my_scope"], unsupported_collections)

    _ = native_module.create_module_spec(
        module_fn, drop_collections=unsupported_collections)
Example #10
Source File: native_module_test.py    From hub with Apache License 2.0 6 votes vote down vote up
def testLoadTrainableModuleFromFuncDef(self):
    """Trains a trainable module's variables to fit a fixed target.

    Runs 50 GradientDescent steps (learning rate 0.40) on the mean squared
    error between the module output and [3.1, 3.2, 3.3]. The output must then
    match the target. The ops are built in the current default graph, since
    no explicit tf.Graph is entered.
    """
    with tf_v1.Session() as sess:
      spec = hub.create_module_spec(stateful_module_fn)
      m = hub.Module(spec, trainable=True)
      x = m()
      step = tf.Variable(0, trainable=False, name="global_step")
      train_op = tf_v1.train.GradientDescentOptimizer(0.40).minimize(
          loss=tf_v1.losses.mean_squared_error(x, [3.1, 3.2, 3.3]),
          global_step=step)
      sess.run(tf_v1.global_variables_initializer())
      for _ in range(50):
        sess.run(train_op)
      got = sess.run(x)
      self.assertAllClose(got, [3.1, 3.2, 3.3])

  # TODO(b/112575006): The following tests verify functionality of function call
  # within a TPU context. Work to generalize this for all function calls is
  # ongoing. 
Example #11
Source File: native_module_test.py    From hub with Apache License 2.0 6 votes vote down vote up
def testModuleWithVariablesAndNoCheckpoint(self):
    """Instantiates a spec that has variables but no checkpoint to restore.

    Creating the module directly via the private _create_impl must add the
    expected global variables (a plain one plus a 3-part partitioned one).
    Running their initializers must produce all-zero values.
    """
    with tf.Graph().as_default():
      spec = native_module.create_module_spec(module_with_variables)
      spec._create_impl(name="module", trainable=False, tags=None)
      self.assertAllEqual(
          [x.op.name for x in tf_v1.global_variables()],
          [
              "module/weights",
              "module/partition/part_0",
              "module/partition/part_1",
              "module/partition/part_2",
          ])

      with tf_v1.Session() as session:
        session.run(tf_v1.initializers.global_variables())
        # One expected value per variable, in global_variables() order:
        # "weights" has 3 elements; the partitioned variable's parts have
        # 2, 1 and 1 elements.
        expected_values = [
            [0.0, 0.0, 0.0],
            [0.0, 0.0],
            [0.0],
            [0.0],
        ]
        for a, b in zip(session.run(tf_v1.global_variables()), expected_values):
          self.assertAllEqual(a, b)
Example #12
Source File: native_module_test.py    From hub with Apache License 2.0 6 votes vote down vote up
def testTPUModuleInitializeOnceWithDefun(self):
    """Checks that a module in a TPU replicate context initializes once.

    The module is instantiated inside a Defun function body within a
    TPUReplicateContext and called twice. Both calls must see the same
    (randomly initialized) variable values.
    """
    spec = hub.create_module_spec(stateful_random_rv_module_fn)

    @function.Defun()
    def import_computation():
      # Enter a TPU replicate context so the module is built as it would be
      # inside a TPU computation; the context is deliberately not exited.
      context = TPUReplicateContext()
      context.Enter()
      m = hub.Module(spec, name="module_", trainable=True)
      return [m(), m()]

    with tf_v1.Graph().as_default(), tf_v1.Session() as sess:
      x = import_computation()
      sess.run(tf_v1.global_variables_initializer())
      got = sess.run(x)
      # Check the values are equal. If the initializer ran on each call,
      # the values would be different.
      self.assertEqual(got[0], got[1])
Example #13
Source File: native_module_test.py    From hub with Apache License 2.0 6 votes vote down vote up
def _testNestedControlFlowModule(self):
    """Checks the output and gradient of a module with nested control flow.

    NOTE: the leading underscore keeps unittest's default `test*` discovery
    from running this method, so the test is effectively disabled.
    """
    spec = hub.create_module_spec(nested_control_flow_module_fn)
    with tf.Graph().as_default():
      with tf_v1.Session() as sess:
        elems = tf_v1.placeholder(tf.float32, shape=[None])
        a = tf_v1.placeholder(tf.float32)
        m = hub.Module(spec)
        out = m({"elems": elems, "a": a})
        grad = tf.gradients([out], [elems])
        # Expected values depend on nested_control_flow_module_fn, which is
        # defined elsewhere.
        self.assertAllClose(
            sess.run(out, {
                a: 1.1,
                elems: [10, 0, 0.5, 1.2]
            }), 11.2)

        self.assertAllClose(sess.run(grad, {a: 1, elems: [10, 0, 0.5, 1.2]}),
                            [[1.0, 0.0, 0.0, 1.0]])
Example #14
Source File: native_module_test.py    From hub with Apache License 2.0 6 votes vote down vote up
def testTPUPruneWithUnusedInput(self):
    """Checks a module call in a TPU context when one input is unused.

    The signature takes "x" and "unused". Calling it from a Defun inside a
    TPUReplicateContext must still produce the expected result: 25 for x=5
    (presumably x squared; the module fn is defined elsewhere).
    """
    spec = hub.create_module_spec(unused_input_module_fn)

    @function.Defun()
    def import_computation(x):
      context = TPUReplicateContext()
      context.Enter()
      m = hub.Module(spec, name="module_", trainable=True)
      return m({
          "x": tf.cast(x, dtype=tf.int64),
          "unused": tf.constant(2, dtype=tf.int64)
      })

    with tf_v1.Graph().as_default(), tf_v1.Session() as sess:
      x = import_computation(5)
      got = sess.run(x)
      self.assertEqual(got, 25)
Example #15
Source File: native_module_test.py    From hub with Apache License 2.0 6 votes vote down vote up
def testBrittleColocationWithInputsFromMultivaluedOp(self):
    """Tests handling of ambiguous rewrites during module.__call__.

    Signature "both" takes x and y. Feeding two inputs with equal colocation
    groups works, and the output inherits that group. Feeding inputs with
    different groups must raise a ValueError describing the conflict.
    """
    spec = hub.create_module_spec(brittle_multivalued_colocation_module_fn)
    with tf.Graph().as_default():
      u = tf.constant([1], name="u")
      with tf_v1.colocate_with(u):
        v = tf.constant([2], name="v")
      w = tf.constant([3], name="w")
      m = hub.Module(spec, name="m")
      # It works if both inputs are mapped to ops with equal colocation groups.
      assert u.op.colocation_groups() == v.op.colocation_groups()
      z = m(dict(x=u, y=v), signature="both")
      self.assertItemsEqual(z.op.colocation_groups(),
                            [tf.compat.as_bytes("loc:@u")])
      # It crashes in the general case.
      assert u.op.colocation_groups() != w.op.colocation_groups()
      # assertRaisesRegex replaces the deprecated assertRaisesRegexp alias
      # (removed in Python 3.12).
      with self.assertRaisesRegex(
          ValueError,
          # In Python 3 (but not 2), colocation groups are lists of bytes,
          # which are formatted with a leading "b" just before the quotes.
          # Both pieces are raw strings: "\[" is not a valid Python escape
          # and must reach the regex engine unchanged.
          r"(?s)Failed to rewrite .*b?'loc:@m_apply_both_1/split' .*"
          r"\[b?'loc:@[uw]'\] vs \[b?'loc:@[wu]'\]"):
        m(dict(x=u, y=w), signature="both")
Example #16
Source File: native_module_test.py    From hub with Apache License 2.0 6 votes vote down vote up
def testModuleWithLayers(self):
    """Round-trips a layers-based module through export and reload.

    The output computed in a first graph (before export) must equal the
    output of the module reloaded from export_path in a second graph.
    """
    export_path = os.path.join(self.get_temp_dir(), "layers-module")

    sample_input = [[1.0, 2.0], [3.1, 10.0]]

    spec = hub.create_module_spec(layers_module_fn)
    with tf.Graph().as_default():
      m = hub.Module(spec, trainable=False)
      x = tf_v1.placeholder(dtype=tf.float32)
      y = m(x)
      with tf_v1.Session() as sess:
        sess.run(tf_v1.global_variables_initializer())
        sample_output = sess.run(y, feed_dict={x: sample_input})
        m.export(export_path, sess)

    # Fresh graph: load from disk, so values come from the exported weights.
    with tf.Graph().as_default():
      x = tf_v1.placeholder(dtype=tf.float32)
      y = hub.Module(export_path)(x)

      with tf_v1.Session() as sess:
        sess.run(tf_v1.global_variables_initializer())
        got = sess.run(y, feed_dict={x: sample_input})
        self.assertAllEqual(got, sample_output)
Example #17
Source File: native_module_test.py    From hub with Apache License 2.0 6 votes vote down vote up
def testWrapFunction(self):
    """Uses a hub.Module inside tf.compat.v1.wrap_function (eager only).

    The module is built inside the wrapped function. Its variable
    initializers are captured in `initializers` and run once through a pruned
    function. Then both module calls are checked (9.0 -> 19.0, 6.0 -> 16.0).
    """
    if not tf.executing_eagerly():
      self.skipTest("Test requires eager.")

    spec = hub.create_module_spec(stateful_rv_with_input_module_fn)

    # Filled in when wrap_function traces use_module.
    initializers = []
    def use_module(x, y):
      m = hub.Module(spec, name="module_", trainable=True)
      initializers.append(tf_v1.initializers.global_variables())
      return [m(x), m(y)]

    input_signature = [
        tf.TensorSpec((), tf.float32),
        tf.TensorSpec((), tf.float32),
    ]

    f = tf_v1.wrap_function(use_module, input_signature)
    # Run only the initializer ops (no feeds, no outputs).
    f.prune([], initializers)()
    self.assertAllEqual(
        [x.numpy() for x in f(9.0, 6.0)],
        [19.0, 16.0])
Example #18
Source File: modular_gan.py    From compare_gan with Apache License 2.0 6 votes vote down vote up
def as_module_spec(self):
    """Returns a TF-Hub module spec exposing the generator and discriminator.

    The untagged (default) graph is the generator at the default batch size.
    In addition, one graph is registered for every (model, batch size) pair,
    with model in ("gen", "disc"), tagged `{model, "bs<batch size>"}`. ResNet
    architectures also get a dynamic batch size (tag "bsNone"), which then
    becomes the default.
    """
    dynamic_batch = "resnet" in self._architecture
    batch_sizes = [8, 16, 32, 64]
    if dynamic_batch:
      # Only ResNet architectures support dynamic batch size.
      batch_sizes = batch_sizes + [None]
    default_batch_size = None if dynamic_batch else 64

    tags_and_args = [
        (set(), {"model": "gen", "batch_size": default_batch_size})]
    tags_and_args.extend(
        ({name, "bs{}".format(size)}, {"model": name, "batch_size": size})
        for name, size in itertools.product(["gen", "disc"], batch_sizes))
    return hub.create_module_spec(
        self._module_fn,
        tags_and_args=tags_and_args,
        drop_collections=[tf.GraphKeys.MOVING_AVERAGE_VARIABLES])
Example #19
Source File: estimator_test.py    From hub with Apache License 2.0 6 votes vote down vote up
def _get_model_fn(register_module=False):
  """Builds an Estimator model_fn around a mock TF-Hub text module.

  Args:
    register_module: if True, the module instance is passed to
      hub.register_module_for_export under _EXPORT_MODULE_NAME.

  Returns:
    A `model_fn(features, labels, mode)` whose predictions are the module's
    output for features[_TEXT_FEATURE_NAME]. Its loss is a constant 0.0, and
    its train_op only increments the global step.
  """

  def _model_fn(features, labels, mode):
    """A model_fn that uses a mock TF-Hub module."""
    del labels  # Unused: the loss below is a constant.

    text_module = hub.Module(hub.create_module_spec(text_module_fn))
    if register_module:
      hub.register_module_for_export(text_module, _EXPORT_MODULE_NAME)
    embeddings = text_module(features[_TEXT_FEATURE_NAME])

    zero_loss = tf.constant(0.0)
    step = tf_v1.train.get_global_step()
    increment_step = tf_v1.assign_add(step, 1)

    return tf_v1.estimator.EstimatorSpec(
        mode=mode, predictions=embeddings, loss=zero_loss,
        train_op=increment_step)

  return _model_fn
Example #20
Source File: native_module_test.py    From hub with Apache License 2.0 6 votes vote down vote up
def testTPUModuleWithWrapFunc(self):
    """Builds a module in a TPU replicate context via tf_v1.wrap_function.

    The module is called on two scalar inputs inside the wrapped function.
    Feeding 9.0 and 6.0 must give 19.0 and 16.0.
    """
    spec = hub.create_module_spec(stateful_rv_with_input_module_fn)

    def import_computation(first, second):
      context = TPUReplicateContext()
      context.Enter()
      m = hub.Module(spec, trainable=True)
      return [m(first), m(second)]

    with tf_v1.Graph().as_default(), tf_v1.Session() as sess:
      x = tf_v1.wrap_function(
          import_computation,
          [tf.TensorSpec((), tf.float32),
           tf.TensorSpec((), tf.float32)])
      sess.run(tf_v1.global_variables_initializer())
      got = sess.run(x(9.0, 6.0))
      self.assertEqual(got, [19.0, 16.0])
Example #21
Source File: keras_layer_test.py    From hub with Apache License 2.0 6 votes vote down vote up
def _save_half_plus_one_hub_module_v1(path):
  """Writes TF1.x hub.Module to compute y = wx + 1, with w trainable.

  w starts at 0.5 and has an L2 regularizer (0.01). The "+ 1" comes from a
  second, non-trainable Dense layer with kernel 1.0 and bias 1.0.

  Args:
    path: destination passed to _export_module_spec_with_init_weights.
  """
  def half_plus_one():
    # Input is a column of scalars: shape (batch, 1).
    x = tf.compat.v1.placeholder(shape=(None,1), dtype=tf.float32)
    # Use TF1 native tf.compat.v1.layers instead of tf.keras.layers as they
    # correctly update TF collections, such as REGULARIZATION_LOSS.
    times_w = tf.compat.v1.layers.Dense(
        units=1,
        kernel_initializer=tf.keras.initializers.Constant([[0.5]]),
        kernel_regularizer=tf.keras.regularizers.l2(0.01),
        use_bias=False)
    plus_1 = tf.compat.v1.layers.Dense(
        units=1,
        kernel_initializer=tf.keras.initializers.Constant([[1.0]]),
        bias_initializer=tf.keras.initializers.Constant([1.0]),
        trainable=False)
    y = plus_1(times_w(x))
    hub.add_signature(inputs=x, outputs=y)

  spec = hub.create_module_spec(half_plus_one)
  _export_module_spec_with_init_weights(spec, path)
Example #22
Source File: native_module_test.py    From hub with Apache License 2.0 6 votes vote down vote up
def testTPUModuleDoesntPruneControlDependencies(self):
    """Checks that module ops reached only via control deps are kept.

    Inside a TPU replicate context, the module's "dependency_op" must still
    exist in the graph after the call, and the output must be 5.0.
    """
    spec = hub.create_module_spec(control_dependency_module_fn)

    @function.Defun()
    def import_computation():
      context = TPUReplicateContext()
      context.Enter()
      m = hub.Module(spec, name="module_", trainable=True)
      return m()

    with tf_v1.Graph().as_default(), tf_v1.Session() as sess:
      x = import_computation()
      got = sess.run(x)
      self.assertEqual(got, 5.0)
      # If the op got pruned, the following get_operation_by_name should fail
      # with a dependency error.
      tf_v1.get_default_graph().get_operation_by_name("module_/dependency_op")
Example #23
Source File: native_module_test.py    From hub with Apache License 2.0 6 votes vote down vote up
def testModuleSpec(self):
    """General test of ModuleSpec and native_module._ModuleSpec.

    Messages attached by `attached_messages_module_fn` must be readable with
    the right proto type. An unknown key yields None by default and raises
    KeyError when `required=True`.
    """
    module_spec = hub.create_module_spec(attached_messages_module_fn)
    bytes_list = tf_v1.train.BytesList
    int64_list = tf_v1.train.Int64List

    letters = module_spec.get_attached_message("letters", bytes_list)
    expected_letters = [tf.compat.as_bytes(s) for s in ("abc", "xyz")]
    self.assertSequenceEqual(letters.value, expected_letters)

    numbers = module_spec.get_attached_message("numbers", int64_list)
    self.assertSequenceEqual(numbers.value, [42, 69])

    tagged = module_spec.get_attached_message("tagged", int64_list)
    self.assertSequenceEqual(tagged.value, [0])

    # Unknown key: None by default...
    self.assertIsNone(module_spec.get_attached_message("bad", bytes_list))
    # ...but an error when the caller marks it as required.
    with self.assertRaises(KeyError):
      module_spec.get_attached_message("bad", bytes_list, required=True)
Example #24
Source File: native_module_test.py    From hub with Apache License 2.0 5 votes vote down vote up
def testMultipleOutputs(self):
    """Reads two named outputs ("y" and "z") of one module call via as_dict.

    For input [2.0], the outputs must be [6.0] and [18.0] (the computation
    is in multiple_outputs_module_fn, defined elsewhere).
    """
    with tf_v1.Session() as sess:
      spec = hub.create_module_spec(multiple_outputs_module_fn)
      m = hub.Module(spec)
      output = m(tf.constant([2.0]), as_dict=True)
      output1 = output["y"]
      output2 = output["z"]
      sess.run(tf_v1.global_variables_initializer())
      self.assertAllClose(sess.run(output1), [6.0])
      self.assertAllClose(sess.run(output2), [18.0])
Example #25
Source File: native_module_test.py    From hub with Apache License 2.0 5 votes vote down vote up
def testBadColocationWithPartialInputsFromMultivaluedOp(self):
    """Feeding only some inputs of a multivalued op must fail the rewrite.

    Signature "partial" takes x alone. Its colocation group ("loc:@u")
    conflicts with the module-internal split op, so a ValueError naming both
    groups is expected.
    """
    spec = hub.create_module_spec(brittle_multivalued_colocation_module_fn)
    with tf.Graph().as_default():
      u = tf.constant([1], name="u")
      m = hub.Module(spec, name="m")
      # assertRaisesRegex replaces the deprecated assertRaisesRegexp alias
      # (removed in Python 3.12). Both pattern pieces are raw strings so that
      # "\[" is passed to the regex engine instead of being an invalid
      # Python escape sequence.
      with self.assertRaisesRegex(
          ValueError,
          r"(?s)Failed to rewrite .*b?'loc:@m_apply_partial/split' .*"
          r"\[b?'loc:@u'\] vs \[b?'loc:@m_apply_partial/split'\]"):
        m(dict(x=u), signature="partial")
Example #26
Source File: native_module_test.py    From hub with Apache License 2.0 5 votes vote down vote up
def testMultipleApplicationsInDifferentScopes(self):
    """Applies one stateful module under different name/variable scopes.

    Both applications must share the single variable "moduleA/iamtheoneandonly"
    regardless of the surrounding scopes. The exported module must work again
    after reloading under a new name.
    """
    with tf.Graph().as_default():
      export_path = os.path.join(self.get_temp_dir(), "module-applied-in-scope")

      spec = hub.create_module_spec(another_stateful_module_fn)
      stateful_module = hub.Module(spec, name="moduleA")
      with tf.name_scope("foo"):
        with tf_v1.variable_scope("bar"):
          times2 = stateful_module(tf.constant([2.0]))
      with tf.name_scope("baz"):
        times3 = stateful_module(tf.constant([3.0]))

      with tf_v1.Session() as sess:
        sess.run(tf_v1.global_variables_initializer())
        self.assertAllClose(sess.run(times2), [6.0])
        self.assertAllClose(sess.run(times3), [9.0])
        # The enclosing scopes must not create extra copies of the variable.
        self.assertEqual(len(stateful_module.variable_map), 1)
        self.assertEqual(
            stateful_module.variable_map["iamtheoneandonly"].name,
            "moduleA/iamtheoneandonly:0")
        stateful_module.export(export_path, sess)

    # Check minimal functionality of the exported module.
    with tf.Graph().as_default():
      stateful_module = hub.Module(export_path, name="moduleB")
      times2 = stateful_module(tf.constant([2.0]))
      with tf_v1.Session() as sess:
        sess.run(tf_v1.global_variables_initializer())
        self.assertAllClose(sess.run(times2), [6.0])
Example #27
Source File: native_module_test.py    From hub with Apache License 2.0 5 votes vote down vote up
def testExportModuleSpec_withWrongShape(self):
    """Exporting with a checkpoint whose variable shapes differ must fail.

    The spec is built with dim=20. Its variables are mapped onto the
    checkpoint's "block/" scope, whose shapes are expected to differ, so
    export must raise ValueError.
    """
    checkpoint_path = self.createCheckpoint(scope="block")
    export_path = os.path.join(self.get_temp_dir(), "module2")

    spec = hub.create_module_spec(lambda: self.module_fn(dim=20))
    # assertRaisesRegex replaces the deprecated assertRaisesRegexp alias
    # (removed in Python 3.12).
    with self.assertRaisesRegex(ValueError, "doesn't match with shape of"):
      spec.export(export_path,
                  checkpoint_path=checkpoint_path,
                  name_transform_fn=lambda x: "block/" + x)
Example #28
Source File: native_module_test.py    From hub with Apache License 2.0 5 votes vote down vote up
def testExportModuleSpec_withWrongScope(self):
    """Exporting must fail when mapped variable names are not in the checkpoint.

    The checkpoint is written under scope "block2", but names are mapped to
    "block/...". Export must therefore raise ValueError for the missing
    variable.
    """
    checkpoint_path = self.createCheckpoint("block2")
    export_path = os.path.join(self.get_temp_dir(), "module3")

    spec = hub.create_module_spec(self.module_fn)
    # assertRaisesRegex replaces the deprecated assertRaisesRegexp alias
    # (removed in Python 3.12).
    with self.assertRaisesRegex(ValueError, "bias is not found in"):
      spec.export(export_path,
                  checkpoint_path=checkpoint_path,
                  name_transform_fn=lambda x: "block/" + x)
Example #29
Source File: feature_column_test.py    From hub with Apache License 2.0 5 votes vote down vote up
def setUp(self):
    """Runs the base fixture setup, then builds the text-module spec."""
    # The base tf.test.TestCase.setUp resets per-test state; overriding
    # setUp without calling it skips that reset.
    super().setUp()
    self.spec = hub.create_module_spec(text_module_fn)
Example #30
Source File: native_module_test.py    From hub with Apache License 2.0 5 votes vote down vote up
def testSeparateCopies(self):
    """Mutating returned objects does not affect future returned values.

    get_attached_message() must hand out independent copies: clearing one
    returned message leaves later fetches intact.
    """
    module_spec = hub.create_module_spec(attached_messages_module_fn)

    def fetch_numbers():
      return module_spec.get_attached_message("numbers",
                                              tf_v1.train.Int64List)

    first_copy = fetch_numbers()
    self.assertSequenceEqual(first_copy.value, [42, 69])
    first_copy.Clear()
    self.assertSequenceEqual(first_copy.value, [])
    # A fresh fetch must still see the original attached values.
    self.assertSequenceEqual(fetch_numbers().value, [42, 69])