Python tensorflow.python.ops.math_ops.add() Examples

The following are 30 code examples of tensorflow.python.ops.math_ops.add(). You can vote up the ones you like or vote down the ones you don't like, and go to the original project or source file by following the links above each example. You may also want to check out all available functions/classes of the module tensorflow.python.ops.math_ops, or try the search function.
Example #1
Source File: control_flow_ops.py    From auto-alt-text-lambda-api with MIT License 6 votes vote down vote up
def __init__(self, parallel_iterations=10, back_prop=True, swap_memory=False,
               name="while_context", grad_state=None, context_def=None,
               import_scope=None):
    """Creates a `WhileContext`.

    Args:
      parallel_iterations: The number of iterations allowed to run in parallel.
      back_prop: Whether backprop is enabled for this while loop.
      swap_memory: Whether GPU-CPU memory swap is enabled for this loop.
      name: Optional name prefix for the returned tensors.
      grad_state: The gradient loop state.
      context_def: Optional `WhileContextDef` protocol buffer to initialize
        the `WhileContext` python object from.
      import_scope: Optional `string`. Name scope to add. Only used when
        initializing from protocol buffer.
    """
    if context_def:
      # Rebuild from the serialized definition; on this path the loop
      # configuration arguments (parallel_iterations, back_prop, swap_memory,
      # name) are not read.
      self._init_from_proto(context_def, import_scope=import_scope)
    else:
      ControlFlowContext.__init__(self)
      self._init_from_args(parallel_iterations, back_prop, swap_memory,
                           name)
    # The gradient loop state is stored on both construction paths.
    self._grad_state = grad_state
Example #2
Source File: control_flow_ops.py    From lambda-packs with MIT License 6 votes vote down vote up
def AddValue(self, val):
    """Returns the tensor to use for `val` inside this cond context.

    If `val` is already tracked by this context, the replacement recorded for
    it in `_external_values` is returned when one exists (this matters for
    nested conds), otherwise `val` itself. An untracked `val` is first added
    to the outer context (if any). It is then passed through
    `_SwitchRefOrTensor` with this context's predicate and branch outside any
    active control-dependency scope. The resulting tensor is marked
    non-fetchable, attached to this context and recorded as the replacement
    for `val`.
    """
    if val.name in self._values:
      replacement = self._external_values.get(val.name)
      return val if replacement is None else replacement

    self._values.add(val.name)
    source = val
    if self._outer_context:
      source = self._outer_context.AddValue(val)
      self._values.add(source.name)
    with ops.control_dependencies(None):
      switched = _SwitchRefOrTensor(source, self._pred)[self._branch]
    switched_op = switched.op
    switched_op.graph.prevent_fetching(switched_op)
    # pylint: disable=protected-access
    switched_op._set_control_flow_context(self)
    # pylint: enable=protected-access

    self._values.add(switched.name)
    self._external_values[val.name] = switched
    return switched
Example #3
Source File: control_flow_ops.py    From lambda-packs with MIT License 6 votes vote down vote up
def _init_values_from_proto(self, values_def, import_scope=None):
    """Restores `_values` and `_external_values` from a `ValuesDef` proto.

    External value names and internal op names are both prefixed with
    `import_scope` before being resolved in the default graph. Every op behind
    an internal (non-external) value is then attached to this context.

    Args:
      values_def: `ValuesDef` protocol buffer.
      import_scope: Optional `string`. Name scope to add.
    """
    assert isinstance(values_def, control_flow_pb2.ValuesDef)
    self._values = set(values_def.values)
    graph = ops.get_default_graph()

    external = self._external_values = {}
    for value_name, external_name in values_def.external_values.items():
      scoped_name = ops.prepend_name_scope(external_name, import_scope)
      external[value_name] = graph.as_graph_element(scoped_name)

    internal_names = self._values - set(external)
    # Drop the ":<suffix>" part (values presumably are tensor names).
    op_names = {value.split(":")[0] for value in internal_names}
    for op_name in op_names:
      target = graph.as_graph_element(
          ops.prepend_name_scope(op_name, import_scope))
      # pylint: disable=protected-access
      target._set_control_flow_context(self)
      # pylint: enable=protected-access
Example #4
Source File: session_debug_testlib.py    From lambda-packs with MIT License 6 votes vote down vote up
def testDebugCondWatchingWholeGraphWorks(self):
    with session.Session() as sess:
      x = variables.Variable(10.0, name="x")
      y = variables.Variable(20.0, name="y")
      true_branch = lambda: math_ops.add(x, 1)
      false_branch = lambda: math_ops.add(y, 1)
      cond = control_flow_ops.cond(x > y, true_branch, false_branch)

      sess.run(variables.global_variables_initializer())

      # Watch the whole graph and request partition graphs, which are handed
      # to DebugDumpDir below.
      run_options = config_pb2.RunOptions(output_partition_graphs=True)
      debug_utils.watch_graph(
          run_options, sess.graph, debug_urls=self._debug_urls())
      run_metadata = config_pb2.RunMetadata()
      result = sess.run(cond, options=run_options, run_metadata=run_metadata)
      # x (10) is not greater than y (20), so the result is y + 1.
      self.assertEqual(21, result)

      dump = debug_data.DebugDumpDir(
          self._dump_root, partition_graphs=run_metadata.partition_graphs)
      merge_values = dump.get_tensors("cond/Merge", 0, "DebugIdentity")
      self.assertAllClose([21.0], merge_values)
Example #5
Source File: train_crnn.py    From 2019-CCF-BDCI-OCR-MCZJ-OCR-IdentificationIDElement with MIT License 6 votes vote down vote up
def distort_color(image, color_ordering=0, scope=None):
    """Randomly perturbs image brightness and contrast (data augmentation).

    Args:
      image: Input image tensor. NOTE(review): the [0, 255] clipping and the
        float32 brightness offset suggest a float32 image in [0, 255] --
        confirm with callers.
      color_ordering: Mode. 0 adjusts brightness first, then contrast; any
        other value adjusts contrast first, then brightness.
      scope: Optional name scope for the created ops.

    Returns:
      The augmented image, clipped to [0, 255].
    """
    with tf.name_scope(scope, 'distort_color', [image]):
        if color_ordering == 0:  # Mode 0: brightness first, then contrast.
            # Author's note: [-70, 30] for generated images, [-50, 20] for
            # real images. NOTE(review): the range actually used is [-55, 20).
            rand_temp = random_ops.random_uniform([], -55, 20, seed=None)
            image = math_ops.add(image, math_ops.cast(rand_temp, dtypes.float32))
            # Author's note: [0.3, 1.75] for generated images, [0.45, 1.5]
            # for real images.
            image = tf.image.random_contrast(image, lower=0.45, upper=1.5)
        else:  # Other modes: contrast first, then brightness.
            image = tf.image.random_contrast(image, lower=0.45, upper=1.5)
            # NOTE(review): the upper bound (30) differs from mode 0 (20);
            # confirm this asymmetry is intended.
            rand_temp = random_ops.random_uniform([], -55, 30, seed=None)
            image = math_ops.add(image, math_ops.cast(rand_temp, dtypes.float32))

        # The random_* ops do not necessarily clamp.
        # (A leftover debug print of color_ordering was removed here.)
        return tf.clip_by_value(image, 0.0, 255.0)  # Clamp to [0, 255].
########################################################################## 
Example #6
Source File: control_flow_ops.py    From lambda-packs with MIT License 6 votes vote down vote up
def _ProcessOutputTensor(self, val):
    """Maps an output tensor of a conditional branch to the tensor to emit.

    A `val` already tracked by this context is returned unchanged, unless a
    replacement was recorded for it in `_external_values`. An untracked `val`
    (e.g. a branch written as `lambda: x`) is added to the outer context (if
    any). It is then passed through `_SwitchRefOrTensor` with this context's
    predicate and branch, and the result is recorded and returned.
    """
    if val.name in self._values:
      replacement = self._external_values.get(val.name)
      return val if replacement is None else replacement

    # Special case: the branch returned a tensor created outside of it.
    self._values.add(val.name)
    source = val
    if self._outer_context:
      source = self._outer_context.AddValue(val)
      self._values.add(source.name)
    switched = _SwitchRefOrTensor(source, self._pred)[self._branch]
    self._external_values[val.name] = switched
    return switched
Example #7
Source File: control_flow_ops.py    From lambda-packs with MIT License 6 votes vote down vote up
def __init__(self, parallel_iterations=10, back_prop=True, swap_memory=False,
               name="while_context", grad_state=None, context_def=None,
               import_scope=None):
    """Creates a `WhileContext`.

    Args:
      parallel_iterations: The number of iterations allowed to run in parallel.
      back_prop: Whether backprop is enabled for this while loop.
      swap_memory: Whether GPU-CPU memory swap is enabled for this loop.
      name: Optional name prefix for the returned tensors.
      grad_state: The gradient loop state.
      context_def: Optional `WhileContextDef` protocol buffer to initialize
        the `WhileContext` python object from.
      import_scope: Optional `string`. Name scope to add. Only used when
        initializing from protocol buffer.
    """
    if context_def:
      # Rebuild from the serialized definition; on this path the loop
      # configuration arguments (parallel_iterations, back_prop, swap_memory,
      # name) are not read.
      self._init_from_proto(context_def, import_scope=import_scope)
    else:
      ControlFlowContext.__init__(self)
      self._init_from_args(parallel_iterations, back_prop, swap_memory,
                           name)
    # The gradient loop state is stored on both construction paths.
    self._grad_state = grad_state
Example #8
Source File: control_flow_ops.py    From lambda-packs with MIT License 6 votes vote down vote up
def _InitializeValues(self, values):
    """Makes the values known to this context.

    Resets `self._values` and records the name of every tensor in `values`.
    For `IndexedSlices` and `SparseTensor` elements, the names of their
    component tensors are recorded instead: `values`, `indices`, and
    `dense_shape` when it is not None.

    Args:
      values: Iterable of `Tensor`, `IndexedSlices` or `SparseTensor`.

    Raises:
      TypeError: If an element is of any other type.
    """
    self._values = set()
    for x in values:
      if isinstance(x, ops.Tensor):
        self._values.add(x.name)
        continue
      # Validate the type before touching `.values`/`.indices`. An unsupported
      # object then gets the TypeError below instead of an AttributeError, and
      # no partial entries are left in `self._values`.
      if not isinstance(x, (ops.IndexedSlices, sparse_tensor.SparseTensor)):
        raise TypeError("Type %s not supported" % type(x))
      self._values.add(x.values.name)
      self._values.add(x.indices.name)
      dense_shape = x.dense_shape
      if dense_shape is not None:
        self._values.add(dense_shape.name)
Example #9
Source File: feature_column.py    From lambda-packs with MIT License 6 votes vote down vote up
def _get_dense_tensor(self, inputs, weight_collections=None, trainable=None):
    """Returns the dense `Tensor` representation of this column.

    Model-builder functions consume this output. `input_layer`, for example,
    works roughly like this pseudo code:

    ```python
    def input_layer(features, feature_columns, ...):
      outputs = [fc._get_dense_tensor(...) for fc in feature_columns]
      return tf.concat(outputs)
    ```

    Args:
      inputs: A `_LazyBuilder` object giving access to the inputs.
      weight_collections: List of graph collections that Variables created
        by this call (if any) are added to.
      trainable: If `True`, also add variables to the graph collection
        `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`).

    Returns:
      `Tensor` of shape [batch_size] + `_variable_shape`.
    """
    # Interface stub: this implementation has no body and returns None.
Example #10
Source File: analyzer_cli_test.py    From auto-alt-text-lambda-api with MIT License 6 votes vote down vote up
def testListTensorFilterByOpTypeRegex(self):
    # Filter with the long flag name: only Identity tensors are listed.
    identity_out = self._registry.dispatch_command(
        "list_tensors", ["--op_type_filter", "Identity"])
    assert_listed_tensors(
        self, identity_out,
        ["simple_mul_add/u/read:0", "simple_mul_add/v/read:0"],
        ["Identity", "Identity"],
        op_type_regex="Identity")

    # Filter with the short flag and an alternation regex.
    add_or_matmul_out = self._registry.dispatch_command(
        "list_tensors", ["-t", "(Add|MatMul)"])
    assert_listed_tensors(
        self, add_or_matmul_out,
        ["simple_mul_add/add:0", "simple_mul_add/matmul:0"],
        ["Add", "MatMul"],
        op_type_regex="(Add|MatMul)")
    check_main_menu(self, add_or_matmul_out, list_tensors_enabled=False)
Example #11
Source File: analyzer_cli_test.py    From auto-alt-text-lambda-api with MIT License 6 votes vote down vote up
def testNodeInfoShowDumps(self):
    node_name = "simple_mul_add/matmul"
    out = self._registry.dispatch_command("node_info", ["-d", node_name])

    input_nodes = [("Identity", "simple_mul_add/u/read"),
                   ("Identity", "simple_mul_add/v/read")]
    # The add node is expected twice among the recipients.
    recipient_nodes = [("Add", "simple_mul_add/add"),
                       ("Add", "simple_mul_add/add")]
    assert_node_attribute_lines(
        self, out, node_name, "MatMul", self._main_device,
        input_nodes, [], recipient_nodes, [],
        num_dumped_tensors=1)
    check_main_menu(
        self, out,
        list_tensors_enabled=True,
        list_inputs_node_name=node_name,
        print_tensor_node_name=node_name,
        list_outputs_node_name=node_name)

    # Line 16 carries a menu item that runs `pt` on the dumped tensor.
    dump_line = out.lines[16]
    check_menu_item(self, out, 16,
                    len(dump_line) - len(dump_line.strip()),
                    len(dump_line), "pt %s:0 -n 0" % node_name)
Example #12
Source File: analyzer_cli_test.py    From auto-alt-text-lambda-api with MIT License 6 votes vote down vote up
def testNodeInfoShowStackTraceUnavailableIsIndicated(self):
    # Detach the Python graph; node_info should then report the stack trace
    # as unavailable.
    self._debug_dump.set_python_graph(None)

    node_name = "simple_mul_add/matmul"
    out = self._registry.dispatch_command("node_info", ["-t", node_name])

    input_nodes = [("Identity", "simple_mul_add/u/read"),
                   ("Identity", "simple_mul_add/v/read")]
    recipient_nodes = [("Add", "simple_mul_add/add"),
                       ("Add", "simple_mul_add/add")]
    assert_node_attribute_lines(
        self, out, node_name, "MatMul", self._main_device,
        input_nodes, [], recipient_nodes, [],
        show_stack_trace=True, stack_trace_available=False)
    check_main_menu(
        self, out,
        list_tensors_enabled=True,
        list_inputs_node_name=node_name,
        print_tensor_node_name=node_name,
        list_outputs_node_name=node_name)
Example #13
Source File: analyzer_cli_test.py    From auto-alt-text-lambda-api with MIT License 6 votes vote down vote up
def testNodeInfoShowStackTraceAvailableWorks(self):
    # Attach the session's Python graph; node_info should then report the
    # stack trace as available.
    self._debug_dump.set_python_graph(self._sess.graph)

    node_name = "simple_mul_add/matmul"
    out = self._registry.dispatch_command("node_info", ["-t", node_name])

    input_nodes = [("Identity", "simple_mul_add/u/read"),
                   ("Identity", "simple_mul_add/v/read")]
    recipient_nodes = [("Add", "simple_mul_add/add"),
                       ("Add", "simple_mul_add/add")]
    assert_node_attribute_lines(
        self, out, node_name, "MatMul", self._main_device,
        input_nodes, [], recipient_nodes, [],
        show_stack_trace=True, stack_trace_available=True)
    check_main_menu(
        self, out,
        list_tensors_enabled=True,
        list_inputs_node_name=node_name,
        print_tensor_node_name=node_name,
        list_outputs_node_name=node_name)
Example #14
Source File: saved_model_test.py    From auto-alt-text-lambda-api with MIT License 6 votes vote down vote up
def testNoOverwrite(self):
    export_dir = os.path.join(test.get_temp_dir(), "test_no_overwrite")
    builder = saved_model_builder.SavedModelBuilder(export_dir)

    # Register a graph with a single variable "v" (set up by the helper with
    # value 42) under the tag "foo", together with its weights.
    with self.test_session(graph=ops.Graph()) as sess:
      self._init_and_validate_variable(sess, "v", 42)
      builder.add_meta_graph_and_variables(sess, ["foo"])

    # Write the SavedModel to disk in text format.
    builder.save(as_text=True)

    # Loading tag "foo" must restore the saved variable value.
    with self.test_session(graph=ops.Graph()) as sess:
      loader.load(sess, ["foo"], export_dir)
      restored = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)[0]
      self.assertEqual(42, restored.eval())

    # Creating another builder on the same export directory must raise an
    # AssertionError.
    with self.assertRaises(AssertionError):
      saved_model_builder.SavedModelBuilder(export_dir)
Example #15
Source File: control_flow_ops.py    From auto-alt-text-lambda-api with MIT License 6 votes vote down vote up
def _init_values_from_proto(self, values_def, import_scope=None):
    """Initializes values and external_values from `ValuesDef` protocol buffer.

    External value names and internal op names are both prefixed with
    `import_scope` before being resolved in the default graph. This way a
    context imported under a name scope resolves its external values inside
    that scope, just as its ops are.

    Args:
      values_def: `ValuesDef` protocol buffer.
      import_scope: Optional `string`. Name scope to add.
    """
    assert isinstance(values_def, control_flow_pb2.ValuesDef)
    self._values = set(values_def.values)
    g = ops.get_default_graph()
    self._external_values = {}
    for k, v in values_def.external_values.items():
      # Apply import_scope here as well. Without it, external values were
      # looked up under their unscoped names while op names below were scoped.
      self._external_values[k] = g.as_graph_element(
          ops.prepend_name_scope(v, import_scope))
    # Drop the ":<suffix>" part (values presumably are tensor names) to get
    # the op behind every internal (non-external) value.
    op_names = {value.split(":")[0]
                for value in self._values - set(self._external_values)}
    for op in op_names:
      # pylint: disable=protected-access
      g.as_graph_element(ops.prepend_name_scope(
          op, import_scope))._set_control_flow_context(self)
      # pylint: enable=protected-access
Example #16
Source File: analyzer_cli_test.py    From auto-alt-text-lambda-api with MIT License 6 votes vote down vote up
def testNodeInfoByNodeName(self):
    node_name = "simple_mul_add/matmul"
    out = self._registry.dispatch_command("node_info", [node_name])

    input_nodes = [("Identity", "simple_mul_add/u/read"),
                   ("Identity", "simple_mul_add/v/read")]
    recipients = [("Add", "simple_mul_add/add")] * 2

    assert_node_attribute_lines(self, out, node_name, "MatMul",
                                self._main_device, input_nodes, [],
                                recipients, [])
    check_main_menu(
        self, out,
        list_tensors_enabled=True,
        list_inputs_node_name=node_name,
        print_tensor_node_name=node_name,
        list_outputs_node_name=node_name)

    # The node name at the end of the first line must be rendered bold.
    first_line = out.lines[0]
    bold_segment = (len(first_line) - len(node_name), len(first_line), "bold")
    self.assertEqual([bold_segment], out.font_attr_segs[0])
Example #17
Source File: control_flow_ops.py    From auto-alt-text-lambda-api with MIT License 6 votes vote down vote up
def _init_from_proto(self, context_def, import_scope=None):
    """Restores this `CondContext` from a `CondContextDef` protocol buffer.

    The context name, predicate and pivot names stored in the proto are
    prefixed with `import_scope`. The predicate and pivot are then resolved in
    the default graph.

    Args:
      context_def: `CondContextDef` protocol buffer.
      import_scope: Optional `string`. Name scope to add.
    """
    assert isinstance(context_def, control_flow_pb2.CondContextDef)
    graph = ops.get_default_graph()

    def _scoped(proto_name):
      return ops.prepend_name_scope(proto_name, import_scope)

    self._name = _scoped(context_def.context_name)
    self._pred = graph.as_graph_element(_scoped(context_def.pred_name))
    self._pivot = graph.as_graph_element(_scoped(context_def.pivot_name))
    self._branch = context_def.branch
    # The base-class constructor receives the nested ValuesDef.
    super(CondContext, self).__init__(values_def=context_def.values_def,
                                      import_scope=import_scope)
Example #18
Source File: control_flow_ops.py    From auto-alt-text-lambda-api with MIT License 6 votes vote down vote up
def AddValue(self, val):
    """Returns the tensor to use for `val` inside this cond context.

    If `val` is already tracked by this context, the replacement recorded for
    it in `_external_values` is returned when one exists (this matters for
    nested conds), otherwise `val` itself. An untracked `val` is first added
    to the outer context (if any). It is then passed through
    `_SwitchRefOrTensor` with this context's predicate and branch outside any
    active control-dependency scope. The resulting tensor is marked
    non-fetchable, attached to this context and recorded as the replacement
    for `val`.
    """
    if val.name in self._values:
      replacement = self._external_values.get(val.name)
      return val if replacement is None else replacement

    self._values.add(val.name)
    source = val
    if self._outer_context:
      source = self._outer_context.AddValue(val)
      self._values.add(source.name)
    with ops.control_dependencies(None):
      switched = _SwitchRefOrTensor(source, self._pred)[self._branch]
    switched_op = switched.op
    switched_op.graph.prevent_fetching(switched_op)
    # pylint: disable=protected-access
    switched_op._set_control_flow_context(self)
    # pylint: enable=protected-access

    self._values.add(switched.name)
    self._external_values[val.name] = switched
    return switched
Example #19
Source File: session_debug_testlib.py    From lambda-packs with MIT License 6 votes vote down vote up
def testDebugWhileLoopWatchingWholeGraphWorks(self):
    with session.Session() as sess:
      def loop_body(i):
        return math_ops.add(i, 2)

      def loop_cond(i):
        return math_ops.less(i, 16)

      i = constant_op.constant(10, name="i")
      loop = control_flow_ops.while_loop(loop_cond, loop_body, [i])

      # Watch the whole graph and request partition graphs, which are handed
      # to DebugDumpDir below.
      run_options = config_pb2.RunOptions(output_partition_graphs=True)
      debug_utils.watch_graph(
          run_options, sess.graph, debug_urls=self._debug_urls())
      run_metadata = config_pb2.RunMetadata()
      final_value = sess.run(
          loop, options=run_options, run_metadata=run_metadata)
      # 10 -> 12 -> 14 -> 16; the loop stops once i is no longer < 16.
      self.assertEqual(16, final_value)

      dump = debug_data.DebugDumpDir(
          self._dump_root, partition_graphs=run_metadata.partition_graphs)

      enter_dumps = dump.get_tensors("while/Enter", 0, "DebugIdentity")
      self.assertEqual([[10]], enter_dumps)
      next_iteration_dumps = dump.get_tensors(
          "while/NextIteration", 0, "DebugIdentity")
      self.assertEqual([[12], [14], [16]], next_iteration_dumps)
Example #20
Source File: control_flow_ops.py    From auto-alt-text-lambda-api with MIT License 6 votes vote down vote up
def _ProcessOutputTensor(self, val):
    """Maps an output tensor of a conditional branch to the tensor to emit.

    A `val` already tracked by this context is returned unchanged, unless a
    replacement was recorded for it in `_external_values`. An untracked `val`
    (e.g. a branch written as `lambda: x`) is added to the outer context (if
    any). It is then passed through `_SwitchRefOrTensor` with this context's
    predicate and branch, and the result is recorded and returned.
    """
    if val.name in self._values:
      replacement = self._external_values.get(val.name)
      return val if replacement is None else replacement

    # Special case: the branch returned a tensor created outside of it.
    self._values.add(val.name)
    source = val
    if self._outer_context:
      source = self._outer_context.AddValue(val)
      self._values.add(source.name)
    switched = _SwitchRefOrTensor(source, self._pred)[self._branch]
    self._external_values[val.name] = switched
    return switched
Example #21
Source File: analyzer_cli_test.py    From auto-alt-text-lambda-api with MIT License 5 votes vote down vote up
def testListTensorsInOpTypeOrderWorks(self):
    # "lt" is the shorthand alias of the command prefix.
    out = self._registry.dispatch_command("lt", ["-s", "op_type"])
    expected_names = [
        "simple_mul_add/u:0", "simple_mul_add/v:0",
        "simple_mul_add/u/read:0", "simple_mul_add/v/read:0",
        "simple_mul_add/matmul:0", "simple_mul_add/add:0",
    ]
    expected_op_types = [
        "VariableV2", "VariableV2", "Identity", "Identity", "MatMul", "Add"]
    assert_listed_tensors(self, out, expected_names, expected_op_types,
                          sort_by="op_type", reverse=False)
    check_main_menu(self, out, list_tensors_enabled=False)
Example #22
Source File: analyzer_cli_test.py    From auto-alt-text-lambda-api with MIT License 5 votes vote down vote up
def testListTensorsInReverseTimeOrderWorks(self):
    # "lt" is the shorthand alias of the command prefix; -r reverses order.
    out = self._registry.dispatch_command("lt", ["-s", "timestamp", "-r"])
    expected_names = [
        "simple_mul_add/u:0", "simple_mul_add/v:0",
        "simple_mul_add/u/read:0", "simple_mul_add/v/read:0",
        "simple_mul_add/matmul:0", "simple_mul_add/add:0",
    ]
    expected_op_types = [
        "VariableV2", "VariableV2", "Identity", "Identity", "MatMul", "Add"]
    assert_listed_tensors(self, out, expected_names, expected_op_types,
                          sort_by="timestamp", reverse=True)
    check_main_menu(self, out, list_tensors_enabled=False)
Example #23
Source File: analyzer_cli_test.py    From auto-alt-text-lambda-api with MIT License 5 votes vote down vote up
def testListTensorsInTensorNameOrderWorks(self):
    # "lt" is the shorthand alias of the command prefix.
    out = self._registry.dispatch_command("lt", ["-s", "tensor_name"])
    expected_names = [
        "simple_mul_add/u:0", "simple_mul_add/v:0",
        "simple_mul_add/u/read:0", "simple_mul_add/v/read:0",
        "simple_mul_add/matmul:0", "simple_mul_add/add:0",
    ]
    expected_op_types = [
        "VariableV2", "VariableV2", "Identity", "Identity", "MatMul", "Add"]
    assert_listed_tensors(self, out, expected_names, expected_op_types,
                          sort_by="tensor_name", reverse=False)
    check_main_menu(self, out, list_tensors_enabled=False)
Example #24
Source File: analyzer_cli_test.py    From auto-alt-text-lambda-api with MIT License 5 votes vote down vote up
def testListTensorsInDumpSizeOrderWorks(self):
    # "lt" is the shorthand alias of the command prefix.
    out = self._registry.dispatch_command("lt", ["-s", "dump_size"])
    expected_names = [
        "simple_mul_add/u:0", "simple_mul_add/v:0",
        "simple_mul_add/u/read:0", "simple_mul_add/v/read:0",
        "simple_mul_add/matmul:0", "simple_mul_add/add:0",
    ]
    expected_op_types = [
        "VariableV2", "VariableV2", "Identity", "Identity", "MatMul", "Add"]
    assert_listed_tensors(self, out, expected_names, expected_op_types,
                          sort_by="dump_size")
    check_main_menu(self, out, list_tensors_enabled=False)
Example #25
Source File: analyzer_cli_test.py    From auto-alt-text-lambda-api with MIT License 5 votes vote down vote up
def testListTensorsInReverseTensorNameOrderWorks(self):
    # "lt" is the shorthand alias of the command prefix; -r reverses order.
    out = self._registry.dispatch_command("lt", ["-s", "tensor_name", "-r"])
    expected_names = [
        "simple_mul_add/u:0", "simple_mul_add/v:0",
        "simple_mul_add/u/read:0", "simple_mul_add/v/read:0",
        "simple_mul_add/matmul:0", "simple_mul_add/add:0",
    ]
    expected_op_types = [
        "VariableV2", "VariableV2", "Identity", "Identity", "MatMul", "Add"]
    assert_listed_tensors(self, out, expected_names, expected_op_types,
                          sort_by="tensor_name", reverse=True)
    check_main_menu(self, out, list_tensors_enabled=False)
Example #26
Source File: analyzer_cli_test.py    From auto-alt-text-lambda-api with MIT License 5 votes vote down vote up
def testListTensorFilterByNodeNameRegexAndOpTypeRegex(self):
    op_type_regex = "(Add|MatMul)"
    node_name_regex = ".*add$"
    out = self._registry.dispatch_command(
        "list_tensors", ["-t", op_type_regex, "-n", node_name_regex])
    # With both filters applied, only the Add tensor is expected.
    assert_listed_tensors(self, out, ["simple_mul_add/add:0"], ["Add"],
                          node_name_regex=node_name_regex,
                          op_type_regex=op_type_regex)
    check_main_menu(self, out, list_tensors_enabled=False)
Example #27
Source File: image_ops_impl.py    From lambda-packs with MIT License 5 votes vote down vote up
def adjust_brightness(image, delta):
  """Adjust the brightness of RGB or Grayscale images.

  This convenience method converts `image` to float, adds `delta`, and
  converts the result back to the original data type. When several
  adjustments are chained, minimize the number of redundant conversions.

  Both `image` and `delta` are converted to `float` before adding (`image` is
  scaled appropriately if it uses a fixed-point representation), and `delta`
  is added to every component of `image`. For regular images `delta` should
  be in `[0,1)`, because it is added in the floating point representation
  where pixel values lie in `[0,1)`.

  Args:
    image: A tensor.
    delta: A scalar. Amount to add to the pixel values.

  Returns:
    A brightness-adjusted tensor of the same shape and type as `image`.
  """
  with ops.name_scope(None, 'adjust_brightness', [image, delta]) as name:
    image = ops.convert_to_tensor(image, name='image')
    # Keep the incoming dtype so the result can be converted back to it.
    original_dtype = image.dtype
    float_image = convert_image_dtype(image, dtypes.float32)
    float_delta = math_ops.cast(delta, dtypes.float32)
    brightened = math_ops.add(float_image, float_delta, name=name)
    return convert_image_dtype(brightened, original_dtype, saturate=True)
Example #28
Source File: metrics_impl.py    From lambda-packs with MIT License 5 votes vote down vote up
def _select_class_id(ids, selected_id):
  """Keeps only the entries of `ids` equal to `selected_id`.

  Args:
    ids: `int64` `Tensor` or `SparseTensor` of IDs.
    selected_id: Int id to select.

  Returns:
    `SparseTensor` with the same dense shape as `ids`, containing only the
    entries equal to `selected_id`.
  """
  ids = sparse_tensor.convert_to_tensor_or_sparse_tensor(ids)
  if isinstance(ids, sparse_tensor.SparseTensor):
    keep_mask = math_ops.equal(ids.values, selected_id)
    return sparse_ops.sparse_retain(ids, keep_mask)

  # TODO(ptucker): Make this more efficient, maybe add a sparse version of
  # tf.equal and tf.reduce_any?

  # Dense path: build a tensor shaped like `ids` but with the last dimension
  # collapsed to 1, filled with `selected_id`, and intersect it with `ids`.
  ids_shape = array_ops.shape(ids, out_type=dtypes.int64)
  last_axis = array_ops.size(ids_shape) - 1
  collapsed_shape = math_ops.reduced_shape(
      ids_shape, array_ops.reshape(last_axis, [1]))
  selected_fill = array_ops.fill(
      collapsed_shape, math_ops.to_int64(selected_id))
  intersection = sets.set_intersection(selected_fill, ids)
  return sparse_tensor.SparseTensor(
      indices=intersection.indices,
      values=intersection.values,
      dense_shape=ids_shape)
Example #29
Source File: analyzer_cli_test.py    From auto-alt-text-lambda-api with MIT License 5 votes vote down vote up
def testListTensors(self):
    # "lt" is the shorthand alias of the command prefix; no extra flags.
    out = self._registry.dispatch_command("lt", [])

    tensor_names = [
        "simple_mul_add/u:0", "simple_mul_add/v:0", "simple_mul_add/u/read:0",
        "simple_mul_add/v/read:0", "simple_mul_add/matmul:0",
        "simple_mul_add/add:0",
    ]
    op_types = [
        "VariableV2", "VariableV2", "Identity", "Identity", "MatMul", "Add"]
    assert_listed_tensors(self, out, tensor_names, op_types)

    # Verify the main menu.
    check_main_menu(self, out, list_tensors_enabled=False)
Example #30
Source File: metrics_impl.py    From lambda-packs with MIT License 5 votes vote down vote up
def _local_variable(initial_value, validate_shape=True, name=None):
  """Creates a non-trainable variable in `GraphKeys.LOCAL_VARIABLES`.

  Args:
    initial_value: See variables.Variable.__init__.
    validate_shape: See variables.Variable.__init__.
    name: See variables.Variable.__init__.
  Returns:
    The newly created variable.
  """
  local_collections = [ops.GraphKeys.LOCAL_VARIABLES]
  return variable_scope.variable(
      initial_value,
      trainable=False,
      collections=local_collections,
      validate_shape=validate_shape,
      name=name)