Python tensorflow.core.framework.graph_pb2.GraphDef() Examples

The following are 30 code examples of tensorflow.core.framework.graph_pb2.GraphDef(). You can vote up the ones you like or vote down the ones you don't like, and go to the original project or source file by following the links above each example. You may also want to check out all available functions/classes of the module tensorflow.core.framework.graph_pb2, or try the search function.
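Before the project-specific examples, here is a minimal sketch of what the GraphDef proto itself looks like from Python: building one by hand, serializing it, and parsing it back. The node name is made up purely for illustration.

from tensorflow.core.framework import graph_pb2

graph_def = graph_pb2.GraphDef()
node = graph_def.node.add()        # append an empty NodeDef message
node.name = "my_const"             # illustrative name
node.op = "Const"

serialized = graph_def.SerializeToString()   # binary wire format (.pb)

restored = graph_pb2.GraphDef()
restored.ParseFromString(serialized)         # round-trip back into a proto
assert restored.node[0].op == "Const"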
Example #1
Source File: utils.py    From speech_separation with MIT License 7 votes
def load_graph(graph_path, tensorboard=False, **kwargs):
    '''
    :param graph_path: the path of the pb file
    :return: tensorflow graph
    '''
    with gfile.FastGFile(graph_path,'rb') as f:
        graph_def = graph_pb2.GraphDef()
        graph_def.ParseFromString(f.read())

    with tf.Graph().as_default() as graph:
        tf.import_graph_def(graph_def,name="")

    if tensorboard:
        writer = tf.summary.FileWriter("log/")
        writer.add_graph(graph)

    return graph 
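A hypothetical call to the helper above, assuming a frozen model saved at model.pb (the path and tensor name are illustrative):

graph = load_graph("model.pb", tensorboard=True)        # writes the graph to log/ for TensorBoard
input_tensor = graph.get_tensor_by_name("input:0")      # tensor names depend on the model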
Example #2
Source File: tf.py    From training_results_v0.6 with Apache License 2.0 6 votes
def ProcessGraphDefParam(graph_def):
    """Type-checks and possibly canonicalizes `graph_def`.

    Parameters
    ----------
    graph_def : Obj
        tensorflow graph definition.

    Returns
    -------
    graph_def : Obj
        tensorflow graph definition.

    """

    if not isinstance(graph_def, graph_pb2.GraphDef):
        # `graph_def` could be a dynamically-created message, so try a duck-typed
        # approach
        try:
            old_graph_def = graph_def
            graph_def = graph_pb2.GraphDef()
            graph_def.MergeFrom(old_graph_def)
        except TypeError:
            raise TypeError('graph_def must be a GraphDef proto.')
    return graph_def 
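As a quick sanity check, a sketch of how this helper behaves when handed a real GraphDef versus something that is not a proto at all (names below are illustrative):

from tensorflow.core.framework import graph_pb2

gd = graph_pb2.GraphDef()
gd.node.add().op = "Const"

canonical = ProcessGraphDefParam(gd)   # already a GraphDef, so it is returned unchanged
assert canonical is gd

# ProcessGraphDefParam("not a proto") would fall into the duck-typed branch,
# fail MergeFrom, and raise TypeError('graph_def must be a GraphDef proto.')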
Example #3
Source File: print_selective_registration_header_test.py    From deep_image_model with Apache License 2.0 6 votes
def testAll(self):
    default_ops = 'all'
    graphs = [
        text_format.Parse(d, graph_pb2.GraphDef())
        for d in [GRAPH_DEF_TXT, GRAPH_DEF_TXT_2]
    ]
    ops_and_kernels = print_selective_registration_header.get_ops_and_kernels(
        'rawproto', self.WriteGraphFiles(graphs), default_ops)

    header = print_selective_registration_header.get_header(ops_and_kernels,
                                                            default_ops)
    self.assertListEqual(
        [
            '#ifndef OPS_TO_REGISTER',  #
            '#define OPS_TO_REGISTER',  #
            '#define SHOULD_REGISTER_OP(op) true',  #
            '#define SHOULD_REGISTER_OP_KERNEL(clz) true',  #
            '#define SHOULD_REGISTER_OP_GRADIENT true',  #
            '#endif'
        ],
        header.split('\n')) 
Example #4
Source File: event_accumulator.py    From lambda-packs with MIT License 6 votes
def Graph(self):
    """Return the graph definition, if there is one.

    If the graph is stored directly, return that.  If no graph is stored
    directly but a metagraph is stored containing a graph, return that.

    Raises:
      ValueError: If there is no graph for this run.

    Returns:
      The `graph_def` proto.
    """
    graph = graph_pb2.GraphDef()
    if self._graph is not None:
      graph.ParseFromString(self._graph)
      return graph
    raise ValueError('There is no graph in this EventAccumulator') 
Example #5
Source File: event_accumulator.py    From deep_image_model with Apache License 2.0 6 votes
def Graph(self):
    """Return the graph definition, if there is one.

    If the graph is stored directly, return that.  If no graph is stored
    directly but a metagraph is stored containing a graph, return that.

    Raises:
      ValueError: If there is no graph for this run.

    Returns:
      The `graph_def` proto.
    """
    graph = graph_pb2.GraphDef()
    if self._graph is not None:
      graph.ParseFromString(self._graph)
      return graph
    raise ValueError('There is no graph in this EventAccumulator') 
Example #6
Source File: meta_graph_test.py    From deep_image_model with Apache License 2.0 6 votes
def testStrippedOpListRecursiveFunctions(self):
    # The function module doesn't support recursive functions, so we build a
    # recursive function situation by ourselves: A calls B calls A and Const.
    graph = graph_pb2.GraphDef()
    a = graph.library.function.add()
    b = graph.library.function.add()
    a.signature.name = "A"
    b.signature.name = "B"
    a.node.add().op = "B"
    b.node.add().op = "Const"
    b.node.add().op = "A"

    # Use A in the graph
    graph.node.add().op = "A"

    # The stripped op list should contain just Const.
    op_list = tf.contrib.util.stripped_op_list_for_graph(graph)
    self.assertEqual(["Const"], [op.name for op in op_list.op]) 
Example #7
Source File: ops.py    From lambda-packs with MIT License 6 votes
def as_graph_def(self, from_version=None, add_shapes=False):
    """Returns a serialized `GraphDef` representation of this graph.

    The serialized `GraphDef` can be imported into another `Graph`
    (using @{tf.import_graph_def}) or used with the
    [C++ Session API](../../api_docs/cc/index.md).

    This method is thread-safe.

    Args:
      from_version: Optional.  If this is set, returns a `GraphDef`
        containing only the nodes that were added to this graph since
        its `version` property had the given value.
      add_shapes: If true, adds an "_output_shapes" list attr to each
        node with the inferred shapes of each of its outputs.

    Returns:
      A [`GraphDef`](https://www.tensorflow.org/code/tensorflow/core/framework/graph.proto)
      protocol buffer.

    Raises:
      ValueError: If the `graph_def` would be too large.
    """
    result, _ = self._as_graph_def(from_version, add_shapes)
    return result 
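A short usage sketch of as_graph_def, assuming a TF 1.x-style graph-mode API (available as tensorflow.compat.v1 on TF 2):

import tensorflow.compat.v1 as tf

g = tf.Graph()
with g.as_default():
    tf.constant(1.0, name="one")

graph_def = g.as_graph_def(add_shapes=True)    # GraphDef proto; nodes carry _output_shapes attrs
print([node.name for node in graph_def.node])  # e.g. ['one']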
Example #8
Source File: ops.py    From deep_image_model with Apache License 2.0 6 votes
def get_stats_for_node_def(graph, node, statistic_type):
  """Looks up the node's statistics function in the registry and calls it.

  This function takes a Graph object and a NodeDef from a GraphDef, and if
  there's an associated statistics method, calls it and returns a result. If no
  function has been registered for the particular node type, it returns an empty
  statistics object.

  Args:
    graph: A Graph object that's been set up with the node's graph.
    node: A NodeDef describing the operator.
    statistic_type: A string identifying the statistic we're interested in.
  Returns:
    An OpStats object containing information about resource usage.
  """

  try:
    stats_func = _stats_registry.lookup(node.op + "," + statistic_type)
    result = stats_func(graph, node)
  except LookupError:
    result = OpStats(statistic_type)
  return result 
Example #9
Source File: ops.py    From lambda-packs with MIT License 6 votes
def get_stats_for_node_def(graph, node, statistic_type):
  """Looks up the node's statistics function in the registry and calls it.

  This function takes a Graph object and a NodeDef from a GraphDef, and if
  there's an associated statistics method, calls it and returns a result. If no
  function has been registered for the particular node type, it returns an empty
  statistics object.

  Args:
    graph: A Graph object that's been set up with the node's graph.
    node: A NodeDef describing the operator.
    statistic_type: A string identifying the statistic we're interested in.
  Returns:
    An OpStats object containing information about resource usage.
  """

  try:
    stats_func = _stats_registry.lookup(node.op + "," + statistic_type)
    result = stats_func(graph, node)
  except LookupError:
    result = OpStats(statistic_type)
  return result 
Example #10
Source File: all_models_to_tensorboard.py    From realtime_object_detection with MIT License 6 votes
def create_tfevent_from_pb(model,optimized=False):
    print("> creating tfevent of model: {}".format(model))

    if optimized:
        model_path=ROOT_DIR+'/models/{}/optimized_inference_graph.pb'.format(model)
        log_dir=ROOT_DIR+'/models/{}/log_opt/'.format(model)
    else:
        model_path=ROOT_DIR+'/models/{}/frozen_inference_graph.pb'.format(model)
        log_dir=ROOT_DIR+'/models/{}/log/'.format(model)

    with session.Session(graph=ops.Graph()) as sess:
        with gfile.FastGFile(model_path, "rb") as f:
          graph_def = graph_pb2.GraphDef()
          graph_def.ParseFromString(f.read())
          importer.import_graph_def(graph_def)
        pb_visual_writer = summary.FileWriter(log_dir)
        pb_visual_writer.add_graph(sess.graph)
    print("> Model {} Imported. \nVisualize by running: \
    tensorboard --logdir={}".format(model_path, log_dir))

# Gather all Model Names in models/ 
Example #11
Source File: tensorflow.py    From blueoil with Apache License 2.0 6 votes
def read(self, pb_path: str) -> Graph:
        """Read TF file and load model.

        Args:
            pb_path (str): Path to TF file

        Returns:
            Model: Loaded model

        """

        # load tensorflow model
        graph_def = graph_pb2.GraphDef()
        try:
            f = open(path.abspath(pb_path), "rb")
            graph_def.ParseFromString(f.read())
            f.close()
        except IOError:
            print("Could not open file. Creating a new one.")

        # import graph
        graph = Importer.make_graph(graph_def)

        return graph 
Example #12
Source File: print_selective_registration_header_test.py    From auto-alt-text-lambda-api with MIT License 6 votes
def testAll(self):
    default_ops = 'all'
    graphs = [
        text_format.Parse(d, graph_pb2.GraphDef())
        for d in [GRAPH_DEF_TXT, GRAPH_DEF_TXT_2]
    ]
    ops_and_kernels = print_selective_registration_header.get_ops_and_kernels(
        'rawproto', self.WriteGraphFiles(graphs), default_ops)

    header = print_selective_registration_header.get_header(ops_and_kernels,
                                                            default_ops)
    self.assertListEqual(
        [
            '#ifndef OPS_TO_REGISTER',  #
            '#define OPS_TO_REGISTER',  #
            '#define SHOULD_REGISTER_OP(op) true',  #
            '#define SHOULD_REGISTER_OP_KERNEL(clz) true',  #
            '#define SHOULD_REGISTER_OP_GRADIENT true',  #
            '#endif'
        ],
        header.split('\n')) 
Example #13
Source File: graph_to_dot.py    From deep_image_model with Apache License 2.0 6 votes
def main(unused_args):
  if not gfile.Exists(FLAGS.graph):
    print("Input graph file '" + FLAGS.graph + "' does not exist!")
    return -1

  graph = graph_pb2.GraphDef()
  with open(FLAGS.graph, "r") as f:
    if FLAGS.input_binary:
      graph.ParseFromString(f.read())
    else:
      text_format.Merge(f.read(), graph)

  with open(FLAGS.dot_output, "wb") as f:
    print("digraph graphname {", file=f)
    for node in graph.node:
      output_name = node.name
      print("  \"" + output_name + "\" [label=\"" + node.op + "\"];", file=f)
      for input_full_name in node.input:
        parts = input_full_name.split(":")
        input_name = re.sub(r"^\^", "", parts[0])
        print("  \"" + input_name + "\" -> \"" + output_name + "\";", file=f)
    print("}", file=f)
  print("Created DOT file '" + FLAGS.dot_output + "'.") 
Example #14
Source File: selective_registration_header_lib.py    From lambda-packs with MIT License 6 votes
def get_header(graphs,
               proto_fileformat='rawproto',
               default_ops='NoOp:NoOp,_Recv:RecvOp,_Send:SendOp'):
  """Computes a header for use with tensorflow SELECTIVE_REGISTRATION.

  Args:
    graphs: a list of paths to GraphDef files to include.
    proto_fileformat: optional format of proto file, either 'textproto' or
      'rawproto' (default).
    default_ops: optional comma-separated string of operator:kernel pairs to
      always include implementation for. Pass 'all' to have all operators and
      kernels included. Default: 'NoOp:NoOp,_Recv:RecvOp,_Send:SendOp'.
  Returns:
    the string of the header that should be written as ops_to_register.h.
  """
  ops_and_kernels = get_ops_and_kernels(proto_fileformat, graphs, default_ops)
  if not ops_and_kernels:
    print('Error reading graph!')
    return 1

  return get_header_from_ops_and_kernels(ops_and_kernels, default_ops == 'all') 
Example #15
Source File: event_accumulator.py    From auto-alt-text-lambda-api with MIT License 6 votes
def Graph(self):
    """Return the graph definition, if there is one.

    If the graph is stored directly, return that.  If no graph is stored
    directly but a metagraph is stored containing a graph, return that.

    Raises:
      ValueError: If there is no graph for this run.

    Returns:
      The `graph_def` proto.
    """
    graph = graph_pb2.GraphDef()
    if self._graph is not None:
      graph.ParseFromString(self._graph)
      return graph
    raise ValueError('There is no graph in this EventAccumulator') 
Example #16
Source File: ops.py    From auto-alt-text-lambda-api with MIT License 6 votes
def as_graph_def(self, from_version=None, add_shapes=False):
    """Returns a serialized `GraphDef` representation of this graph.

    The serialized `GraphDef` can be imported into another `Graph`
    (using [`import_graph_def()`](#import_graph_def)) or used with the
    [C++ Session API](../../api_docs/cc/index.md).

    This method is thread-safe.

    Args:
      from_version: Optional.  If this is set, returns a `GraphDef`
        containing only the nodes that were added to this graph since
        its `version` property had the given value.
      add_shapes: If true, adds an "_output_shapes" list attr to each
        node with the inferred shapes of each of its outputs.

    Returns:
      A [`GraphDef`](https://www.tensorflow.org/code/tensorflow/core/framework/graph.proto)
      protocol buffer.

    Raises:
      ValueError: If the `graph_def` would be too large.
    """
    result, _ = self._as_graph_def(from_version, add_shapes)
    return result 
Example #17
Source File: backend_tf.py    From inference with Apache License 2.0 6 votes
def load(self, model_path, inputs=None, outputs=None):
        # there is no input/output meta data in the graph, so it needs to come from config.
        if not inputs:
            raise ValueError("BackendTensorflow needs inputs")
        if not outputs:
            raise ValueError("BackendTensorflow needs outputs")
        self.outputs = outputs
        self.inputs = inputs

        # TODO: support checkpoint and saved_model formats?
        graph_def = graph_pb2.GraphDef()
        with open(model_path, "rb") as f:
            graph_def.ParseFromString(f.read())
        g = tf.compat.v1.import_graph_def(graph_def, name='')
        self.sess = tf.compat.v1.Session(graph=g)
        return self 
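A hypothetical call to this loader, assuming the surrounding BackendTensorflow class can be constructed directly and that the tensor names below actually exist in the frozen model (both are illustrative):

backend = BackendTensorflow()                  # class defined in the same source file
backend.load("resnet50_v1.pb",                 # illustrative frozen-graph path
             inputs=["input_tensor:0"],
             outputs=["ArgMax:0"])
# backend.sess now holds a tf.compat.v1.Session with the imported graph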
Example #18
Source File: test_util_test.py    From deep_image_model with Apache License 2.0 5 votes
def testAssertProtoEqualsStr(self):

    graph_str = "node { name: 'w1' op: 'params' }"
    graph_def = graph_pb2.GraphDef()
    text_format.Merge(graph_str, graph_def)

    # test string based comparison
    self.assertProtoEquals(graph_str, graph_def)

    # test original comparison
    self.assertProtoEquals(graph_def, graph_def) 
Example #19
Source File: meta_graph.py    From deep_image_model with Apache License 2.0 5 votes
def stripped_op_list_for_graph(graph_def):
  """Collect the stripped OpDefs for ops used by a graph.

  This function computes the `stripped_op_list` field of `MetaGraphDef` and
  similar protos.  The result can be communicated from the producer to the
  consumer, which can then use the C++ function
  `RemoveNewDefaultAttrsFromGraphDef` to improve forwards compatibility.

  Args:
    graph_def: A `GraphDef` proto, as from `graph.as_graph_def()`.

  Returns:
    An `OpList` of ops used by the graph.

  Raises:
    ValueError: If an unregistered op is used.
  """
  # This is the Python equivalent of StrippedOpListForGraph in C++.
  # Unfortunately, since the Python op registry can differ from that in C++, we
  # can't remove the duplication using swig (at least naively).
  # TODO(irving): Support taking graphs directly.

  used_ops = ops_used_by_graph_def(graph_def)

  # Verify that all used ops are registered.
  registered_ops = op_def_registry.get_registered_ops()
  # These internal ops used by functions are not registered, so we need to
  # whitelist them.  # TODO(irving): Do something better here.
  op_whitelist = ("_Arg", "_Retval", "_ListToArray", "_ArrayToList")
  for op in used_ops:
    if op not in registered_ops and op not in op_whitelist:
      raise ValueError("Op %s is used by the graph, but is not registered" % op)

  # Build the stripped op list in sorted order
  return op_def_pb2.OpList(op=[registered_ops[op] for op in sorted(used_ops)
                               if op in registered_ops]) 
Example #20
Source File: quantize_graph.py    From tensorflow-for-poets-2 with Apache License 2.0 5 votes
def apply_final_node_renames(self):
    """Applies node renames in self.final_node_renames to self.output_graph."""
    old_graph = self.output_graph
    self.output_graph = graph_pb2.GraphDef()
    for node in old_graph.node:
      node.name = self.final_node_renames.get(node.name, node.name)
      for index, input_name in enumerate(node.input):
        node_name = node_name_from_input(input_name)
        input_full_name = ensure_tensor_name_has_port(input_name)
        if node_name in self.final_node_renames:
          node.input[index] = "%s%s" % (self.final_node_renames[node_name],
                                        input_full_name[len(node_name):])
      self.add_output_graph_node(node)
    return self.output_graph 
Example #21
Source File: test_tfcoreml.py    From MMdnn with MIT License 5 votes
def _test_coreml_model_image_input(self, tf_model_path, coreml_model,
      input_tensor_name, output_tensor_name, img_size, useCPUOnly = False):
    """Test single image input conversions.
    tf_model_path - the TF model
    coreml_model - converted CoreML model
    input_tensor_name - the input image tensor name
    output_tensor_name - the output tensor name
    img_size - size of the image
    """

    img_np, img = _load_image(TEST_IMAGE ,resize_to=(img_size, img_size))
    img_tf = np.expand_dims(img_np, axis = 0)
    img_tf[:,:,:,0] = self.image_scale * img_tf[:,:,:,0] + self.red_bias
    img_tf[:,:,:,1] = self.image_scale * img_tf[:,:,:,1] + self.green_bias
    img_tf[:,:,:,2] = self.image_scale * img_tf[:,:,:,2] + self.blue_bias

    #evaluate the TF model
    tf.reset_default_graph()
    graph_def = graph_pb2.GraphDef()
    with open(tf_model_path, "rb") as f:
        graph_def.ParseFromString(f.read())
    g = tf.import_graph_def(graph_def)
    with tf.Session(graph=g) as sess:
      image_input_tensor = sess.graph.get_tensor_by_name('import/' + input_tensor_name)
      output = sess.graph.get_tensor_by_name('import/' + output_tensor_name)
      tf_out = sess.run(output,feed_dict={image_input_tensor: img_tf})
    if len(tf_out.shape) == 4:
      tf_out = np.transpose(tf_out, (0,3,1,2))
    tf_out_flatten = tf_out.flatten()

    #evaluate CoreML
    coreml_input_name = input_tensor_name.replace(':', '__').replace('/', '__')
    coreml_output_name = output_tensor_name.replace(':', '__').replace('/', '__')
    coreml_input = {coreml_input_name: img}

    #Test the default CoreML evaluation
    coreml_out = coreml_model.predict(coreml_input, useCPUOnly = useCPUOnly)[coreml_output_name]
    coreml_out_flatten = coreml_out.flatten()
    self._compare_tf_coreml_outputs(tf_out_flatten, coreml_out_flatten) 
Example #22
Source File: test_tfcoreml.py    From MMdnn with MIT License 5 votes
def _test_coreml_model_image_input(tf_model_path, coreml_model,
      input_tensor_name, output_tensor_name, img_size, useCPUOnly = False):
    """Test single image input conversions.
    tf_model_path - the TF model
    coreml_model - converted CoreML model
    input_tensor_name - the input image tensor name
    output_tensor_name - the output tensor name
    img_size - size of the image
    """

    img_np, img = _load_image(TEST_IMAGE ,resize_to=(img_size, img_size))
    img_tf = np.expand_dims(img_np, axis = 0)
    img_tf[:,:,:,0] = 2.0/255 * img_tf[:,:,:,0] - 1
    img_tf[:,:,:,1] = 2.0/255 * img_tf[:,:,:,1] - 1
    img_tf[:,:,:,2] = 2.0/255 * img_tf[:,:,:,2] - 1

    #evaluate the TF model
    tf.reset_default_graph()
    graph_def = graph_pb2.GraphDef()
    with open(tf_model_path, "rb") as f:
        graph_def.ParseFromString(f.read())
    g = tf.import_graph_def(graph_def)
    with tf.Session(graph=g) as sess:
      image_input_tensor = sess.graph.get_tensor_by_name('import/' + input_tensor_name)
      output = sess.graph.get_tensor_by_name('import/' + output_tensor_name)
      tf_out = sess.run(output,feed_dict={image_input_tensor: img_tf})
    if len(tf_out.shape) == 4:
      tf_out = np.transpose(tf_out, (0,3,1,2))
    tf_out_flatten = tf_out.flatten()

    #evaluate CoreML
    coreml_input_name = input_tensor_name.replace(':', '__').replace('/', '__')
    coreml_output_name = output_tensor_name.replace(':', '__').replace('/', '__')
    coreml_input = {coreml_input_name: img}

    #Test the default CoreML evaluation
    coreml_out = coreml_model.predict(coreml_input, useCPUOnly = useCPUOnly)[coreml_output_name]
    coreml_out_flatten = coreml_out.flatten()
    print (coreml_out_flatten)
    # compare_tf_coreml_outputs(tf_out_flatten, coreml_out_flatten) 
Example #23
Source File: print_selective_registration_header.py    From deep_image_model with Apache License 2.0 5 votes
def get_ops_and_kernels(proto_fileformat, proto_files, default_ops_str):
  """Gets the ops and kernels needed from the model files."""
  ops = set()

  for proto_file in proto_files:
    tf.logging.info('Loading proto file %s', proto_file)
    # Load GraphDef.
    file_data = tf.gfile.GFile(proto_file).read()
    if proto_fileformat == 'rawproto':
      graph_def = graph_pb2.GraphDef.FromString(file_data)
    else:
      assert proto_fileformat == 'textproto'
      graph_def = text_format.Parse(file_data, graph_pb2.GraphDef())

    # Find all ops and kernels used by the graph.
    for node_def in graph_def.node:
      if not node_def.device:
        node_def.device = '/cpu:0'
      kernel_class = pywrap_tensorflow.TryFindKernelClass(
          node_def.SerializeToString())
      if kernel_class:
        op_and_kernel = (str(node_def.op), kernel_class.decode('utf-8'))
        if op_and_kernel not in ops:
          ops.add(op_and_kernel)
      else:
        print(
            'Warning: no kernel found for op %s' % node_def.op, file=sys.stderr)

  # Add default ops.
  if default_ops_str != 'all':
    for s in default_ops_str.split(','):
      op, kernel = s.split(':')
      op_and_kernel = (op, kernel)
      if op_and_kernel not in ops:
        ops.add(op_and_kernel)

  return list(sorted(ops)) 
Example #24
Source File: ops.py    From auto-alt-text-lambda-api with MIT License 5 votes
def graph_def_versions(self):
    """The GraphDef version information of this graph.

    For details on the meaning of each version, see
    [`GraphDef`](https://www.tensorflow.org/code/tensorflow/core/framework/graph.proto).

    Returns:
      A `VersionDef`.
    """
    return self._graph_def_versions 
Example #25
Source File: control_flow_ops_test.py    From deep_image_model with Apache License 2.0 5 votes
def _StripGraph(self, gd):
    """Copy gd keeping only, node.name, node.op, node.input, and node.device."""
    return graph_pb2.GraphDef(node=[self._StripNode(nd) for nd in gd.node]) 
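_StripNode is not shown in this snippet; a hypothetical version consistent with the docstring would copy only those four fields into a fresh NodeDef:

from tensorflow.core.framework import node_def_pb2

def _StripNode(self, nd):
  """Hypothetical helper: keep only name, op, input, and device."""
  stripped = node_def_pb2.NodeDef(name=nd.name, op=nd.op, device=nd.device)
  stripped.input.extend(nd.input)
  return stripped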
Example #26
Source File: tf.py    From training_results_v0.6 with Apache License 2.0 5 votes
def get_workload(model_path):
    """ Import workload from frozen protobuf

    Parameters
    ----------
    model_path: str
        model_path on remote repository to download from.

    Returns
    -------
    graph_def: graphdef
        graph_def is the tensorflow workload for mobilenet.

    """

    repo_base = 'https://github.com/dmlc/web-data/raw/master/tensorflow/models/'
    model_name = os.path.basename(model_path)
    model_url = os.path.join(repo_base, model_path)

    from mxnet.gluon.utils import download

    temp = util.tempdir()
    path_model = temp.relpath(model_name)

    download(model_url, path_model)

    # Creates graph from saved graph_def.pb.
    with tf.gfile.FastGFile(path_model, 'rb') as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())
        graph = tf.import_graph_def(graph_def, name='')
        temp.remove()
        return graph_def

#######################################################################
# PTB LSTMBlockCell Model
# ----------------------- 
Example #27
Source File: load_graph_nms_v2.py    From PythonPilot with Apache License 2.0 5 votes
def load_frozen_graph_without_split(self):
        """
        Load frozen_graph.
        """
        model_path = self.cfg['model_path']

        tf.reset_default_graph()

        graph_def = tf.GraphDef()
        with tf.gfile.GFile(model_path, 'rb') as fid:
            serialized_graph = fid.read()
            graph_def.ParseFromString(serialized_graph)
            # force CPU device placement for NMS ops
            for node in graph_def.node:
                if 'BatchMultiClassNonMaxSuppression' in node.name:
                    node.device = '/device:CPU:0'
                else:
                    node.device = '/device:GPU:0'
            tf.import_graph_def(graph_def, name='')

        #self.print_graph_operation_by_name(detection_graph, "Postprocessor/Slice")
        #self.print_graph_operation_by_name(detection_graph, "Postprocessor/ExpandDims_1")
        #self.print_graph_operation_by_name(detection_graph, "Postprocessor/stack_1")
        """
        return
        """
        return tf.get_default_graph() 
Example #28
Source File: meta_graph.py    From auto-alt-text-lambda-api with MIT License 5 votes
def stripped_op_list_for_graph(graph_def):
  """Collect the stripped OpDefs for ops used by a graph.

  This function computes the `stripped_op_list` field of `MetaGraphDef` and
  similar protos.  The result can be communicated from the producer to the
  consumer, which can then use the C++ function
  `RemoveNewDefaultAttrsFromGraphDef` to improve forwards compatibility.

  Args:
    graph_def: A `GraphDef` proto, as from `graph.as_graph_def()`.

  Returns:
    An `OpList` of ops used by the graph.

  Raises:
    ValueError: If an unregistered op is used.
  """
  # This is the Python equivalent of StrippedOpListForGraph in C++.
  # Unfortunately, since the Python op registry can differ from that in C++, we
  # can't remove the duplication using swig (at least naively).
  # TODO(irving): Support taking graphs directly.

  used_ops = ops_used_by_graph_def(graph_def)

  # Verify that all used ops are registered.
  registered_ops = op_def_registry.get_registered_ops()
  # These internal ops used by functions are not registered, so we need to
  # whitelist them.  # TODO(irving): Do something better here.
  op_whitelist = ("_Arg", "_Retval", "_ListToArray", "_ArrayToList")
  for op in used_ops:
    if op not in registered_ops and op not in op_whitelist:
      raise ValueError("Op %s is used by the graph, but is not registered" % op)

  # Build the stripped op list in sorted order
  return op_def_pb2.OpList(op=[registered_ops[op] for op in sorted(used_ops)
                               if op in registered_ops]) 
Example #29
Source File: meta_graph.py    From auto-alt-text-lambda-api with MIT License 5 votes
def _read_file(filename):
  """Reads a file containing `GraphDef` and returns the protocol buffer.

  Args:
    filename: `graph_def` filename including the path.

  Returns:
    A `GraphDef` protocol buffer.

  Raises:
    IOError: If the file doesn't exist, or cannot be successfully parsed.
  """
  graph_def = graph_pb2.GraphDef()
  if not file_io.file_exists(filename):
    raise IOError("File %s does not exist." % filename)
  # First try to read it as a binary file.
  file_content = file_io.read_file_to_string(filename)
  try:
    graph_def.ParseFromString(file_content)
    return graph_def
  except Exception:  # pylint: disable=broad-except
    pass

  # Next try to read it as a text file.
  try:
    text_format.Merge(file_content.decode("utf-8"), graph_def)
  except text_format.ParseError as e:
    raise IOError("Cannot parse file %s: %s." % (filename, str(e)))

  return graph_def 
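A usage sketch of the reader above; the path is illustrative (for example, a graph previously written with tf.train.write_graph):

graph_def = _read_file("/tmp/example_graph.pbtxt")   # hypothetical path; a binary .pb works too
print(len(graph_def.node), "nodes loaded")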
Example #30
Source File: freeze_graph.py    From deepnlp with MIT License 5 votes
def _parse_input_graph_proto(input_graph, input_binary):
  """Parser input tensorflow graph into GraphDef proto."""
  if not gfile.Exists(input_graph):
    print("Input graph file '" + input_graph + "' does not exist!")
    return -1
  input_graph_def = graph_pb2.GraphDef()
  mode = "rb" if input_binary else "r"
  with gfile.FastGFile(input_graph, mode) as f:
    if input_binary:
      input_graph_def.ParseFromString(f.read())
    else:
      text_format.Merge(f.read(), input_graph_def)
  return input_graph_def