Python graphviz.Digraph() Examples

The following are 30 code examples of graphviz.Digraph(). You can vote up the ones you like or vote down the ones you don't like, and go to the original project or source file by following the links above each example. You may also want to check out all available functions/classes of the module graphviz, or try the search function.
Example #1
Source File: graph.py    From CHAID with Apache License 2.0 7 votes vote down vote up
def render(self, path, view):
        """Render the tree to a PNG with graphviz.

        Every tree node is drawn as the image returned by ``self.bar_chart``;
        each non-root node gets an edge from its parent, labelled with the
        node's ``choices``.

        Parameters
        ----------
        path : str or None
            Path for the DOT source (graphviz writes the image next to it with
            a ``.png`` suffix). When None, a timestamped file under
            ``trees/`` is used.
        view : bool
            Whether to open the rendered image with the system viewer.
        """
        if path is None:
            # Time uses '-' rather than ':' because ':' is not a legal
            # file-name character on Windows.
            path = os.path.join("trees", "{:%Y-%m-%d %H-%M-%S}.gv".format(datetime.now()))
        # The temporary directory is exposed as self.tempdir for the duration
        # of the render; presumably bar_chart writes its images there.
        with TemporaryDirectory() as self.tempdir:
            g = Digraph(
                format="png",
                graph_attr={"splines": "ortho"},
                node_attr={"shape": "plaintext", "labelloc": "b"},
            )
            for node in self.tree:
                image = self.bar_chart(node)
                g.node(str(node.node_id), image=image)
                if node.parent is not None:
                    edge_label = "     ({})     \n ".format(', '.join(map(str, node.choices)))
                    g.edge(str(node.parent), str(node.node_id), xlabel=edge_label)
            # Must render inside the with-block: the node images live in the
            # temporary directory.
            g.render(path, view=view)
Example #2
Source File: output.py    From tributary with Apache License 2.0 7 votes vote down vote up
def GraphViz(node):
    """Build a graphviz ``Digraph`` of the graph reachable from *node*.

    ``Graph(node)`` appears to return a mapping of node -> children, where each
    child is either a node or a nested mapping of the same shape (inferred
    from the traversal below; confirm against ``Graph``). Every node is drawn
    with its ``_name`` as the graphviz id and ``_graphvizshape`` as the shape;
    edges point from child to parent.

    Returns the unrendered ``Digraph`` (format 'png').
    """
    tree = Graph(node)

    from graphviz import Digraph
    dot = Digraph("Graph", strict=False)
    dot.format = 'png'

    def rec(children, parent):
        for child in children:
            if not isinstance(child, dict):
                # Leaf node: identify it by name, exactly like the dict branch
                # does. The node object itself is not a valid graphviz id.
                dot.node(child._name, shape=child._graphvizshape)
                dot.edge(child._name, parent._name)
            else:
                for k in child:
                    dot.node(k._name, shape=k._graphvizshape)
                    rec(child[k], k)
                    dot.edge(k._name, parent._name)

    for k in tree:
        dot.node(k._name, shape=k._graphvizshape)
        rec(tree[k], k)

    return dot
Example #3
Source File: tree.py    From pyth with MIT License 6 votes vote down vote up
def disp_tree(trees):
    """Render a forest of nested-sequence trees to 'tree-rep.gv' and open it.

    Each tree is a sequence whose first item is the node label and whose
    remaining items are child subtrees; falsy children are skipped. Node ids
    are consecutive integers shared across the whole forest.
    """
    graph = Digraph()

    def add(tree, next_id):
        # Draw *tree* rooted at id *next_id*; return the largest id used
        # (or *next_id* unchanged when the tree is empty).
        if not tree:
            return next_id
        parent_id = next_id
        graph.node(str(parent_id), label=tree[0])
        for child in tree[1:]:
            if not child:
                continue
            next_id += 1
            graph.edge(str(parent_id), str(next_id))
            next_id = add(child, next_id)
        return next_id

    current = 0
    for tree in trees:
        current = add(tree, current) + 1
    graph.render('tree-rep.gv', view=True)
Example #4
Source File: api.py    From eval-nas with MIT License 6 votes vote down vote up
def visualize(self):
        """Creates a dot graph. Can be visualized in colab directly.

        Vertex 0 is labelled 'input', the last vertex 'output', and each
        vertex in between is labelled with ``self.ops[v]``. An edge
        ``src -> dst`` is added for every truthy ``self.matrix[src, dst]``
        with ``src < dst`` (only the strict upper triangle is read).

        Returns:
          A ``graphviz.Digraph``, or ``None`` when the graphviz package is not
          installed (the ImportError is printed instead of raised).
        """
        num_vertices = np.shape(self.matrix)[0]
        # Only the import is optional, so only it sits inside the try.
        try:
            import graphviz
        except ImportError as e:
            print(e)
            return None

        g = graphviz.Digraph()
        g.node(str(0), 'input')
        for v in range(1, num_vertices - 1):
            g.node(str(v), self.ops[v])
        g.node(str(num_vertices - 1), 'output')

        for src in range(num_vertices - 1):
            for dst in range(src + 1, num_vertices):
                if self.matrix[src, dst]:
                    g.edge(str(src), str(dst))

        return g
Example #5
Source File: many_to_one.py    From matchpy with MIT License 6 votes vote down vote up
def _as_graph(self, finals: Optional[List[str]]) -> Digraph:  # pragma: no cover
        """Build a graphviz Digraph of this matcher.

        When *finals* is None, the graph also gets two HTML-labelled box
        nodes listing the patterns ('patterns') and constraints
        ('constraints'). The states and transitions come from
        ``_make_graph_nodes`` / ``_make_graph_edges``.

        Raises:
            ImportError: if ``Digraph`` is None (graphviz not available).
        """
        # Digraph is a module-level name that is None when graphviz could not
        # be imported (presumably via a guarded import elsewhere in the file).
        if Digraph is None:
            raise ImportError('The graphviz package is required to draw the graph.')
        graph = Digraph()
        show_legend = finals is None

        if show_legend:
            pattern_lines = []
            for index, (pattern, _, constraint_set) in enumerate(self.patterns):
                pattern_lines.append('{}: {} with {}'.format(
                    self._colored_pattern(index),
                    html.escape(str(pattern.expression)),
                    self._format_constraint_set(constraint_set),
                ))
            pattern_label = '<<b>Patterns:</b><br/>\n{}>'.format('<br/>\n'.join(pattern_lines))
            graph.node('patterns', pattern_label, {'shape': 'box'})

        self._make_graph_nodes(graph, finals)

        if show_legend:
            constraint_lines = []
            for index, (constraint, pattern_set) in enumerate(self.constraints):
                constraint_lines.append('{}: {} for {}'.format(
                    self._colored_constraint(index),
                    html.escape(str(constraint)),
                    self._format_pattern_set(pattern_set),
                ))
            constraint_label = '<<b>Constraints:</b><br/>\n{}>'.format('<br/>\n'.join(constraint_lines))
            graph.node('constraints', constraint_label, {'shape': 'box'})

        self._make_graph_edges(graph)
        return graph
Example #6
Source File: graph_init.py    From hydra-python-agent with MIT License 6 votes vote down vote up
def main(self,new_url,api_doc,check_commit):
        """Load the API graph described by *api_doc* into Redis.

        Records *new_url* on ``self.url``, creates ``self.redis_graph`` (named
        "apigraph") on a RedisProxy connection, fills it via
        ``get_endpoints`` and commits it when *check_commit* is true.
        """
        con = RedisProxy().get_connection()
        self.url = new_url
        self.redis_graph = Graph("apigraph", con)
        print("loading... of graph")
        self.get_endpoints(api_doc, con)
        # Debugging aid (disabled): print each node.alias from
        # self.redis_graph.nodes.values(), or visualise the stored graph with
        # graphviz via Digraph('redis_graph', filename='hydra_graph.gv'),
        # adding g.edge(edge.src_node.alias, edge.dest_node.alias) for every
        # edge in self.redis_graph.edges and calling g.view().
        if not check_commit:
            return
        print("commiting")
        # Writes the whole graph to Redis.
        self.redis_graph.commit()
        print("done!!!!")
Example #7
Source File: utilities.py    From Ithemal with MIT License 6 votes vote down vote up
def draw(self, to_file=False, file_name=None, view=True):
        """Build a graphviz digraph of the instruction dependency graph.

        One node per instruction in ``self.instrs`` (keyed by ``id(instr)``,
        labelled ``str(instr)``) and an edge from each instruction to each of
        its ``children``.

        Args:
          to_file: if true, render the graph to disk.
          file_name: DOT source path used when rendering. If omitted, a file
            inside a new private temporary directory is used.
          view: when rendering, open the result in the system viewer.

        Returns:
          ``dot``, or ``(dot, file_name)`` when *to_file* is true.
        """
        if to_file and not file_name:
            import os
            import tempfile
            # Taking NamedTemporaryFile(...).name and discarding the object
            # deletes the file as soon as it is collected (immediately in
            # CPython) and merely reuses the name. That is the same race as the
            # deprecated tempfile.mktemp(). A private mkdtemp() directory
            # cannot be claimed by anyone else.
            file_name = os.path.join(tempfile.mkdtemp(), 'graph.gv')

        from graphviz import Digraph

        dot = Digraph()
        for instr in self.instrs:
            dot.node(str(id(instr)), str(instr))
            for child in instr.children:
                dot.edge(str(id(instr)), str(id(child)))

        if to_file:
            dot.render(file_name, view=view)
            return dot, file_name
        return dot
Example #8
Source File: state_machine.py    From vivarium with GNU General Public License v3.0 6 votes vote down vote up
def to_dot(self):
        """Produces a ball and stick graph of this state machine.

        Transient states are outlined with a dashed style. Every transition
        becomes an edge labelled with the transition's name.

        Returns
        -------
        `graphviz.Digraph`
            A ball and stick visualization of this state machine.
        """
        from graphviz import Digraph
        dot = Digraph(format='png')
        for state in self.states:
            extra = {'style': 'dashed'} if isinstance(state, TransientState) else {}
            dot.node(state.state_id, **extra)
            for transition in state.transition_set:
                target_id = transition.output_state.state_id
                dot.edge(state.state_id, target_id, transition.name)
        return dot
Example #9
Source File: draw_cfg.py    From pyta with GNU General Public License v3.0 6 votes vote down vote up
def display(cfgs: Dict[NodeNG, ControlFlowGraph],
            filename: str, view: bool = True) -> None:
    """Render control-flow graphs to *filename* with graphviz.

    Each module or function CFG in *cfgs* becomes a cluster subgraph labelled
    '__main__' (module) or the function name. Other node kinds are skipped.
    Blocks reachable from the CFG's start are drawn first, then the
    unreachable ones.
    """
    graph = graphviz.Digraph(name=filename, **GRAPH_OPTIONS)
    for node, cfg in cfgs.items():
        if isinstance(node, astroid.Module):
            label = '__main__'
        elif isinstance(node, astroid.FunctionDef):
            label = node.name
        else:
            continue
        with graph.subgraph(name=f'cluster_{id(node)}') as cluster:
            seen = set()
            _visit(cfg.start, cluster, seen)
            for orphan in cfg.unreachable_blocks:
                _visit(orphan, cluster, seen)
            cluster.attr(label=label)

    graph.render(filename, view=view)
Example #10
Source File: turbinia_job_graph.py    From turbinia with Apache License 2.0 6 votes vote down vote up
def create_graph():
  """Create graph of relationships between Turbinia jobs and evidence.

  Jobs are plain nodes and evidence types are box nodes. Edges run from each
  input evidence type to its job, and from the job to each output evidence
  type.

  Returns:
    Instance of graphviz.dot.Digraph
  """
  dot = graphviz.Digraph(comment='Turbinia Evidence graph', format='png')
  for _, job in jobs_manager.JobsManager.GetJobs():
    job_name = job.NAME
    dot.node(job_name)
    for evidence_name in (evidence.__name__ for evidence in job.evidence_input):
      dot.node(evidence_name, shape='box')
      dot.edge(evidence_name, job_name)
    for evidence_name in (evidence.__name__ for evidence in job.evidence_output):
      dot.node(evidence_name, shape='box')
      dot.edge(job_name, evidence_name)
  return dot
Example #11
Source File: regex_common.py    From acsploit with BSD 3-Clause "New" or "Revised" License 6 votes vote down vote up
def show_graph(self):
        """Render the automaton to PNG and open it (the DOT source is removed).

        Node shapes encode the state kind: 'octagon' for the initial state,
        'doublecircle' for other accepting states, 'doubleoctagon' when the
        initial state is also accepting, and 'circle' otherwise. Edge labels
        are ``str(t)`` with backslashes doubled so graphviz prints them
        literally.
        """
        import os
        import tempfile
        # tempfile.mktemp() is deprecated and racy (the name can be claimed
        # between generation and use); a fresh private directory is safe.
        out_path = os.path.join(tempfile.mkdtemp(), 'automata.gv')
        g = Digraph("Automata Graph", filename=out_path, format='png')
        g.attr('node', label="")

        # (is_accepting, is_initial) -> node shape
        shapes = {
            (False, False): 'circle',
            (False, True): 'octagon',
            (True, False): 'doublecircle',
            (True, True): 'doubleoctagon',
        }
        for n in self.states:
            string = Automata.stringify_node(n)
            kind = (n in self.accepting_states, bool(n == self.initial_state))
            g.node(string, string, shape=shapes[kind])

        for s, transitions in self.transitions.items():
            for d, t in transitions.items():
                g.edge(Automata.stringify_node(s),
                       Automata.stringify_node(d),
                       str(t).replace('\\', '\\\\'))
        # cleanup=True deletes the DOT source after rendering.
        g.view(cleanup=True)
Example #12
Source File: draw_cfg.py    From pyta with GNU General Public License v3.0 6 votes vote down vote up
def _visit(block: CFGBlock,
           graph: graphviz.Digraph, visited: Set[str]) -> None:
    """Draw *block* and, depth-first, every block reachable from it.

    Node ids have the form ``'<graph.name>_<block.id>'``. *visited* collects
    the (string) ids already drawn, so each block is emitted only once even
    when control flow loops back. The label lists the block's statements,
    one left-aligned line each. Unreachable blocks are filled 'grey93',
    reachable ones 'white'.
    """
    node_id = f'{graph.name}_{block.id}'
    if node_id in visited:
        return

    label = '\n'.join(s.as_string() for s in block.statements) + '\n'
    # Need to escape backslashes explicitly.
    label = label.replace('\\', '\\\\')
    # \l is used for left alignment.
    label = label.replace('\n', '\\l')

    fill_color = 'grey93' if not block.reachable else 'white'

    graph.node(node_id, label=label, fillcolor=fill_color, style='filled')
    visited.add(node_id)

    # NOTE(review): recursion depth grows with the longest CFG path; very
    # long straight-line code could approach Python's recursion limit.
    for edge in block.successors:
        graph.edge(node_id, f'{graph.name}_{edge.target.id}')
        _visit(edge.target, graph, visited)
Example #13
Source File: exnviz.py    From ngraph-python with Apache License 2.0 6 votes vote down vote up
def begin_pass(self, filename=None, **kwargs):
        """Start the pass: reset node bookkeeping and create ``self.graph``.

        ``self.graph`` is a graphviz Digraph named *filename* (falling back
        to ``self.filename``) with ``nodesep`` and ``ranksep`` set to '.5'.

        Raises:
            ImportError: if the python graphviz library is not installed.
        """
        super(ExVizPass, self).begin_pass(**kwargs)
        try:
            import graphviz
        except ImportError:
            raise ImportError("You tried to use the ShowGraph transformer pass but did "
                              "not have the python graphviz library installed")
        graph_name = self.filename if filename is None else filename

        # Fresh tracking sets for exops/tensors that have (or still lack) a node.
        self.exops_with_nodes = set()
        self.exops_without_nodes = set()
        self.tensors_with_nodes = set()
        self.tensors_without_nodes = set()

        spacing = {'nodesep': '.5', 'ranksep': '.5'}
        self.graph = graphviz.Digraph(name=graph_name, graph_attr=spacing)
Example #14
Source File: exnviz.py    From ngraph-python with Apache License 2.0 6 votes vote down vote up
def __init__(self, **kwargs):
        """Configure the visualisation pass from keyword options.

        Recognised options and defaults: show_axes=True,
        show_all_metadata=True, subgraph_attr=None, filename='Digraph',
        view=True, cleanup=True, show_tensors=False, output_directory='.'.
        If view is enabled and output_directory is explicitly None, a new
        temporary directory is used. Unrecognised keys are ignored.
        """
        super(ExVizPass, self).__init__()
        take = kwargs.pop
        self.show_axes = take('show_axes', True)
        self.show_all_metadata = take('show_all_metadata', True)
        self.subgraph_attr = take('subgraph_attr', None)
        self.exops_with_nodes = set()
        self.exops_without_nodes = set()
        self.filename = take('filename', 'Digraph')
        self.view = take('view', True)
        self.cleanup = take('cleanup', True)
        self.show_tensors = take('show_tensors', False)
        out_dir = take('output_directory', '.')
        if self.view and out_dir is None:
            out_dir = tempfile.mkdtemp()
        self.output_directory = out_dir
Example #15
Source File: graph.py    From pipedream with MIT License 6 votes vote down vote up
def to_dot(self, arch):
        """Render this profiled graph with graphviz to the path *arch*.

        Each node's label is its description plus its profiling numbers.
        Nodes assigned to a stage are filled with that stage's colour,
        cycling through ``self._colors``. Edges follow ``self.edges``.
        """
        dot = graphviz.Digraph()
        template = ("%s\n[forward_compute_time=%.3f,backward_compute_time=%.3f,"
                    "activation_size=%s,parameter_size=%.1f]")
        for node in self.nodes.values():
            label = template % (node.node_desc,
                                node.forward_compute_time,
                                node.backward_compute_time,
                                node.activation_size,
                                node.parameter_size)
            if node.stage_id is None:
                dot.node(node.node_id, label)
            else:
                stage_color = self._colors[node.stage_id % len(self._colors)]
                dot.node(node.node_id, label, color=stage_color, style='filled')
        for node in self.nodes.values():
            if node.node_id in self.edges:
                for out_node in self.edges[node.node_id]:
                    dot.edge(node.node_id, out_node.node_id)
        dot.render(arch)
Example #16
Source File: notifications.py    From RAFCON with Eclipse Public License 1.0 6 votes vote down vote up
def enable_debugging():
    """Turn on notification debugging and start an empty graphviz graph.

    Resets the module-level bookkeeping: node colours, the DOT node sequence
    counter, and the node and edge registries. Also creates a fresh Digraph
    named 'notification_graph_to_render'.
    """
    from graphviz import Digraph
    global debugging_enabled, notification_graph_to_render, dot_node_sequence_number, existing_dot_nodes_to_colors
    global nodes, edges
    debugging_enabled = True
    existing_dot_nodes_to_colors = {}
    # Graph options that were tried and rejected:
    #   graph_attr={"concentrate": "true"}: all edges with the same source and
    #     endpoint get merged.
    #   graph_attr={"labelfloat": "true"}: no visible change.
    #   graph_attr={"splines": "compound"} (with or without "overlap": "false").
    #   "ortho" splines do not work for every engine and do nothing for others.
    notification_graph_to_render = Digraph(name='notification_graph_to_render')
    dot_node_sequence_number = 0
    nodes = {}
    edges = OrderedDict()
Example #17
Source File: model.py    From staticfg with Apache License 2.0 5 votes vote down vote up
def _build_visual(self, format='pdf', calls=True):
        """Return a graphviz Digraph for this CFG and its function CFGs.

        The graph is named ``'cluster' + self.name`` and labelled with
        ``self.name``. Every CFG in ``self.functioncfgs`` is built the same way
        (recursively) and attached as a subgraph.
        """
        graph = gv.Digraph(name='cluster' + self.name, format=format,
                           graph_attr={'label': self.name})
        self._visit_blocks(graph, self.entryblock, visited=[], calls=calls)

        # One subgraph per function definition found in this CFG.
        for function_cfg in self.functioncfgs.values():
            graph.subgraph(function_cfg._build_visual(format=format, calls=calls))

        return graph
Example #18
Source File: utils.py    From ProxImaL with MIT License 5 votes vote down vote up
def graph_visualize(prox_fns, filename=None):
    """Draw the linear-operator graphs behind *prox_fns* with graphviz.

    Each proximal function and each operator reachable from its ``lin_op``
    (through ``input_nodes``) becomes a node. Edges run from every operator
    input to the operator, and from each ``lin_op`` to its proximal
    function.

    Args:
        prox_fns: iterable of proximal functions.
        filename: if None, show the graph inline via IPython; otherwise
            render it to this path.
    """
    import graphviz
    from collections import deque
    from IPython.display import display
    dot = graphviz.Digraph()
    nodes = {}

    def node(obj):
        # Assign ids 'N0', 'N1', ... in first-seen order.
        if obj not in nodes:
            nodes[obj] = 'N%d' % len(nodes)
        return nodes[obj]

    from proximal.prox_fns.prox_fn import ProxFn  # noqa: F401 (kept from original)

    # Operators whose input edges are already drawn. This set is shared
    # across prox_fns, so a subgraph used by several functions does not get
    # duplicate edges (previously it was re-traversed once per function).
    wired = set()
    for pfn in prox_fns:
        dot.node(node(pfn), str(pfn))
        # Pass 1: declare every operator reachable from pfn.lin_op.
        queue = deque([pfn.lin_op])
        while queue:
            n = queue.popleft()
            if n not in nodes:
                dot.node(node(n), str(type(n)))
                queue.extend(n.input_nodes)
        dot.edge(nodes[pfn.lin_op], nodes[pfn])
        # Pass 2: connect each operator to its inputs.
        queue = deque([pfn.lin_op])
        while queue:
            n = queue.popleft()
            if n not in wired:
                wired.add(n)
                queue.extend(n.input_nodes)
                for inn in n.input_nodes:
                    dot.edge(nodes[inn], nodes[n])
    if filename is None:
        display(dot)
    else:
        dot.render(filename)
Example #19
Source File: graph_monitor.py    From AMS with Apache License 2.0 5 votes vote down vote up
def get_dots(self):
        """Return DOT sources for the relations as ``(graph_dot, class_graph_dot)``.

        Both are left-to-right digraphs named 'finite_state_machine' with
        circular nodes keyed by class name.

        * ``class_graph_dot``: one edge per related (from_class, to_class)
          pair, labelled ``"from_class -> to_class"``.
        * ``graph_dot``: the same class-level edges, but labelled per instance
          (``"from_id -> to_id"``). Relations from "traffic_signal" keep the
          class-level label.
        """
        def build(edges):
            # Render the (from_class, to_class, label) keys of *edges* to DOT.
            f = Digraph('finite_state_machine')
            f.attr(rankdir='LR', size='8,5')
            f.attr('node', shape='circle')
            # Iterate the dict directly: its keys are already unique and
            # insertion-ordered. Going through a set made the edge order (and
            # therefore the DOT text) depend on per-process string hashing.
            for from_class, to_class, label in edges:
                f.edge(from_class, to_class, label=label)
            return f.source

        class_edges = {}
        instance_edges = {}
        for from_class, by_from_id in self.__relations.items():
            for from_id, by_to_class in by_from_id.items():
                for to_class, to_ids in by_to_class.items():
                    class_label = from_class + " -> " + to_class
                    class_edges[(from_class, to_class, class_label)] = None
                    if from_class == "traffic_signal":
                        instance_edges[(from_class, to_class, class_label)] = None
                    else:
                        for to_id in to_ids:
                            instance_edges[(from_class, to_class, from_id + " -> " + to_id)] = None

        return build(instance_edges), build(class_edges)
Example #20
Source File: cuda_codegen.py    From ProxImaL with MIT License 5 votes vote down vote up
def visualize(self, dot = None):
        """Draw this subgraph, after its dependent subgraphs, with graphviz.

        Nodes are the operators reachable backwards from ``self.end`` through
        ``self.input_nodes``, labelled with their type. Edges run from each
        input to its consumer.

        Args:
            dot: graphviz Digraph to draw into. When omitted, a new one is
                created and shown with IPython once everything is drawn.
        """
        import graphviz
        from collections import deque
        from IPython.display import display
        root = False
        if not dot:
            root = True
            dot = graphviz.Digraph()
        for csg in self.dependent_subgraphs:
            csg.visualize(dot)

        # Dependent subgraphs draw into the same ``dot``. Without a
        # per-subgraph prefix every subgraph would reuse 'N0', 'N1', ... and
        # graphviz would merge unrelated operators into one node.
        prefix = 'G%x_' % id(self)
        nodes = {}
        queue = deque([self.end])
        while queue:
            n = queue.popleft()
            if n in nodes:
                continue
            nodes[n] = '%sN%d' % (prefix, len(nodes))
            dot.node(nodes[n], str(type(n)))
            try:
                queue.extend(self.input_nodes(n))
            except KeyError:
                # presumably n has no recorded inputs in this subgraph
                pass

        visited = set()
        queue = deque([self.end])
        while queue:
            n = queue.popleft()
            if n in visited:
                continue
            visited.add(n)
            try:
                innodes = self.input_nodes(n)
            except KeyError:
                continue
            for inn in innodes:
                queue.append(inn)
                dot.edge(nodes[inn], nodes[n])
        if root:
            display(dot)
Example #21
Source File: modeling_framework.py    From dmipy with MIT License 5 votes vote down vote up
def visualize_model_setup(
            self, view=True, cleanup=True, with_parameters=False,
            im_format='png'):
        """
        Visualizes MultiCompartmentModel setup using graphviz module. It uses
        the uuid module to create a unique identifier for each model in the
        MultiCompartmentModel to make sure each node is referenced in a unique
        way.

        The graph is rendered into the current working directory. The DOT
        source goes to a file named 'Model Setup' and the image to
        'Model Setup.<im_format>'. With cleanup=True only the DOT source is
        deleted after rendering; the image file is kept either way.

        If with_parameters is set to true, it will include all the parameters
        of each model in the graph. Note the graph will ignore any parameter
        links that may have already been imposed (e.g. parameter equality or
        fixed parameters).

        Parameters
        ----------
        view: boolean,
            Whether or not to open the rendered image in the system viewer.
        cleanup: boolean,
            Whether or not to delete the DOT source file after rendering.
        with_parameters: boolean,
            Whether or not to also visualize the parameters of each model.
        im_format: string,
            graphviz output format of the rendered image (default 'png').

        Raises
        ------
        ImportError
            If the graphviz package is not installed.
        """
        if not have_graphviz:
            raise ImportError('graphviz package not installed.')
        dot = Digraph('Model Setup', format=im_format)
        base_model = self.__class__.__name__
        base_uuid = str(uuid4())
        dot.node(base_uuid, base_model)
        self._add_recursive_graph_node(dot, base_uuid, self, with_parameters)
        dot.render('Model Setup', view=view, cleanup=cleanup)
Example #22
Source File: dot_renderer.py    From airflow with Apache License 2.0 5 votes vote down vote up
def render_dag(dag: DAG, tis: Optional[List[TaskInstance]] = None) -> graphviz.Digraph:
    """
    Renders the DAG object to the DOT object.

    When a list of task instances is given, each task node is coloured by the
    state of its task instance (``State.NONE`` if it has none). Otherwise the
    task's own UI colours are used.

    :param dag: DAG that will be rendered.
    :type dag: airflow.models.dag.DAG
    :param tis: List of task instances
    :type tis: Optional[List[TaskInstance]]
    :return: Graphviz object
    :rtype: graphviz.Digraph
    """
    rankdir = dag.orientation if dag.orientation else "LR"
    dot = graphviz.Digraph(
        dag.dag_id,
        graph_attr={"rankdir": rankdir, "labelloc": "t", "label": dag.dag_id},
    )
    states_by_task_id = None if tis is None else {ti.task_id: ti.state for ti in tis}
    for task in dag.tasks:
        if states_by_task_id is None:
            fg = _refine_color(task.ui_fgcolor)
            bg = _refine_color(task.ui_color)
        else:
            state = states_by_task_id.get(task.task_id, State.NONE)
            fg = State.color_fg(state)
            bg = State.color(state)
        node_attrs = {
            "shape": "rectangle",
            "style": "filled,rounded",
            "color": fg,
            "fillcolor": bg,
        }
        dot.node(task.task_id, _attributes=node_attrs)
        for downstream_task_id in task.downstream_task_ids:
            dot.edge(task.task_id, downstream_task_id)
    return dot
Example #23
Source File: SimpleSchedulePass.py    From pymtl3 with BSD 3-Clause "New" or "Revised" License 5 votes vote down vote up
def dump_dag( top, V, E ):
  """Render the update-block DAG (vertices V, edges E) and open the result.

  The graph is written to /tmp/upblk-dag.gv. Each vertex is a box whose id
  and label is "<name>\\n@<host>". The name is the block's ``__name__`` (or
  ``repr`` for a CalleePort), with "_FF" appended for blocks in
  ``top._dsl.all_update_ff``. The host is the repr of the owning component,
  or "" if it cannot be determined.
  """
  from graphviz import Digraph
  from pymtl3.dsl import CalleePort

  def label( v ):
    # Unique node id/label for update block or CalleePort v.
    is_port = isinstance( v, CalleePort )
    name = repr(v) if is_port else v.__name__
    if v in top._dsl.all_update_ff:
      name += "_FF"
    try:
      host = repr(v.get_parent_object() if is_port
                  else top.get_update_block_host_component(v))
    except Exception:
      # Best effort: leave the host blank. Unlike the previous bare except,
      # this does not swallow KeyboardInterrupt/SystemExit.
      host = ""
    return name + "\\n@" + host

  dot = Digraph()
  dot.graph_attr["rank"] = "same"
  dot.graph_attr["ratio"] = "compress"
  dot.graph_attr["margin"] = "0.1"

  for x in V:
    dot.node( label(x), shape="box" )

  for (x, y) in E:
    dot.edge( label(x), label(y) )
  dot.render( "/tmp/upblk-dag.gv", view=True )
Example #24
Source File: visualize.py    From ibis with Apache License 2.0 5 votes vote down vote up
def to_graph(expr, node_attr=None, edge_attr=None):
    """Build a bottom-to-top graphviz Digraph of an ibis expression tree.

    Each (expression, argument name) pair becomes a node whose id is
    ``str(hash((expr._key, name)))``. An edge runs from every
    expression-valued argument of an operation to the expression using it.
    *node_attr* defaults to DEFAULT_NODE_ATTRS and *edge_attr* to no
    attributes.
    """
    pending = [(expr, expr._safe_name)]
    seen = set()
    g = gv.Digraph(
        node_attr=node_attr or DEFAULT_NODE_ATTRS, edge_attr=edge_attr or {}
    )
    g.attr(rankdir='BT')

    while pending:
        current, current_name = pending.pop()
        current_key = current._key, current_name
        if current_key in seen:
            continue
        seen.add(current_key)

        current_label = get_label(current, argname=current_name)
        current_id = str(hash(current_key))
        g.node(current_id, label=current_label)

        op = current.op()
        for arg, arg_name in zip(op.args, op.signature.names()):
            if not isinstance(arg, ir.Expr):
                continue
            arg_id = str(hash((arg._key, arg_name)))
            g.node(arg_id, label=get_label(arg, argname=arg_name))
            g.edge(arg_id, current_id)
            pending.append((arg, arg_name))
    return g
Example #25
Source File: tiny_gp_plus.py    From tiny_gp with GNU General Public License v3.0 5 votes vote down vote up
def draw_tree(self, fname, footer):
        """Render this tree to '<fname>.gv' / '<fname>.gv.png' and show the PNG.

        *footer* becomes the graph label. ``self.draw`` fills the graph; it
        receives the graph and a node counter wrapped in one-element lists
        (presumably so it can update them in place while recursing).
        """
        dot = [Digraph()]
        dot[0].attr(kw='graph', label = footer)
        count = [0]
        self.draw(dot, count)
        # graphviz.Source expects DOT text. Pass .source explicitly instead of
        # the Digraph object, which newer graphviz releases reject.
        Source(dot[0].source, filename = fname + ".gv", format="png").render()
        display(Image(filename = fname + ".gv.png"))
Example #26
Source File: nist800_53viz.py    From compliancelib-python with GNU General Public License v3.0 5 votes vote down vote up
def __init__(self, id, vizformat='svg'):
        """Load the dependency data and resolve the precursors of control *id*.

        Args:
            id: identifier of the control whose precursors are resolved.
            vizformat: graphviz output format for the ``graph``/``digraph``
                factories (default 'svg').
        """
        self.id = id

        # Paths: base_path is this module's directory with a trailing
        # separator, so the relative sub-paths below can be appended.
        self.base_path = os.path.join(os.path.dirname(__file__), '')
        self.dep_dir = "data/dependencies/"
        self.out_dir = ""
        self.log_dir = "./"

        # Rendering options.
        self.vizformat = vizformat
        self.width = 2.5
        self.height = 2.5
        # Graph factories pre-bound to the chosen output format.
        self.graph = functools.partial(gv.Graph, format=vizformat)
        self.digraph = functools.partial(gv.Digraph, format=vizformat)

        self.input_path = self.base_path + self.dep_dir
        self.output_path = self.base_path + self.out_dir

        self.dep_dict = self._load_graph_from_dependency_files()

        self.resolved = []
        self.nodes = []
        self.edges = []

        # precursor_list presumably fills self.nodes; precursor_controls is an
        # alias of that same list, not a copy.
        self.precursor_list(self.dep_dict, self.id, self.nodes)
        self.precursor_controls = self.nodes
Example #27
Source File: pytorch_visualize.py    From MaximumMarginGANs with MIT License 5 votes vote down vote up
def save_visualization(name, format='svg'):
    """Render the recorded autograd trace to *name* in the given *format*.

    Reads the module-level trace tables ``vars`` (id -> Parameter/Variable),
    ``funcs`` (id -> creator function), ``func_trace`` (creator id -> input
    ids) and ``var_trace`` (output id -> creator ids). Parameters are drawn
    as red ellipses, other Variables as light-blue ellipses, and creators as
    orange rectangles labelled with their class name.

    Raises:
        AssertionError: if ``vars`` holds something that is not a Variable.
    """
    g = graphviz.Digraph(format=format)

    def sizestr(var):
        return str([int(i) for i in list(var.size())])

    # add variable nodes (.items() works on Python 2 and 3; .iteritems() is
    # Python 2 only)
    for vid, var in vars.items():
        if isinstance(var, nn.Parameter):
            g.node(str(vid), label=sizestr(var), shape='ellipse', style='filled', fillcolor='red')
        elif isinstance(var, Variable):
            g.node(str(vid), label=sizestr(var), shape='ellipse', style='filled', fillcolor='lightblue')
        else:
            # Explicit raise (same exception type as the old assert) so the
            # check is not stripped under ``python -O``.
            raise AssertionError(var.__class__)
    # add creator nodes
    for cid in func_trace:
        creator = funcs[cid]
        g.node(str(cid), label=str(creator.__class__.__name__), shape='rectangle', style='filled', fillcolor='orange')
    # add edges between creator and inputs
    for cid in func_trace:
        for iid in func_trace[cid]:
            g.edge(str(iid), str(cid))
    # add edges between outputs and creators
    for oid in var_trace:
        for cid in var_trace[oid]:
            g.edge(str(cid), str(oid))
    g.render(name)
Example #28
Source File: graph.py    From ray-legacy with BSD 3-Clause "New" or "Revised" License 5 votes vote down vote up
def graph_to_graphviz(computation_graph):
  """
  Convert the computation graph to graphviz format.

  Every operation i becomes a box node "op<i>" labelled with its index and
  kind. Dotted, non-ranking edges link each operation to its creator
  operation (or to "op-root").

  Args:
    computation_graph [graph_pb2.CompGraph]: protocol buffer description of
      the computation graph

  Returns:
    Graphviz description of the computation graph
  """
  # Creator id that is drawn as hanging off the "op-root" node.
  ROOT_CREATOR = 2 ** 64 - 1
  dot = graphviz.Digraph(format="pdf")
  dot.node("op-root", shape="box")
  for i, op in enumerate(computation_graph.operation):
    op_name = "op" + str(i)
    if op.HasField("task"):
      short_name = op.task.name.split(".")[-1]
      dot.node(op_name, shape="box", label=str(i) + "\n" + short_name)
      for res in op.task.result:
        dot.edge(op_name, str(res))
    elif op.HasField("put"):
      dot.node(op_name, shape="box", label=str(i) + "\n" + "put")
      dot.edge(op_name, str(op.put.objectid))
    elif op.HasField("get"):
      dot.node(op_name, shape="box", label=str(i) + "\n" + "get")

    if op.creator_operationid == ROOT_CREATOR:
      creator_name = "op-root"
    else:
      creator_name = "op" + str(op.creator_operationid)
    dot.edge(creator_name, op_name, style="dotted", constraint="false")

    # Object-id arguments (those without an inline serialized value) get
    # their own node feeding into this operation.
    for arg in op.task.arg:
      if len(arg.serialized_arg) == 0:
        dot.node(str(arg.objectid))
        dot.edge(str(arg.objectid), op_name)
  return dot
Example #29
Source File: viz.py    From scene-graph-TF-release with MIT License 5 votes vote down vote up
def draw_graph(labels, rels, cfg):
    """Draw a scene graph to 'sg.gv', open it, and return its data as a dict.

    Objects that appear in at least one relation become light-blue nodes
    named after their class, with '_1', '_2', ... suffixes to keep names
    unique. Each relation becomes a red predicate node wired
    subject -> predicate -> object.

    Args:
        labels: per-object class indices (``.tolist()`` is called on it, so
            presumably a numpy array).
        rels: sequence of ``(subject_index, object_index, predicate_index)``
            triples.
        cfg: provides the ``ind_to_class`` and ``ind_to_predicate`` lookups.

    Returns:
        dict with both lookups, ``labels`` as a list and the original *rels*.
    """
    u = Digraph('sg', filename='sg.gv')
    # Use the public attr() API rather than appending raw lines to u.body,
    # whose line format is an implementation detail of graphviz.
    u.attr(size='6,6', rankdir='LR')
    u.node_attr.update(style='filled')

    out_dict = {'ind_to_class': cfg.ind_to_class, 'ind_to_predicate': cfg.ind_to_predicate}
    out_dict['labels'] = labels.tolist()
    out_dict['relations'] = rels

    rels = np.array(rels)
    # Sets give O(1) membership tests. The size guard avoids 2-D slicing of an
    # empty (1-D) array when there are no relations.
    rel_inds = set(rels[:, :2].ravel().tolist()) if rels.size else set()
    used_names = set()
    for i, l in enumerate(labels):
        if i in rel_inds:
            name = cfg.ind_to_class[l]
            name_suffix = 1
            obj_name = name
            while obj_name in used_names:
                obj_name = name + '_' + str(name_suffix)
                name_suffix += 1
            used_names.add(obj_name)
            u.node(str(i), label=obj_name, color='lightblue2')

    for rel in rels:
        edge_key = '%s_%s' % (rel[0], rel[1])
        u.node(edge_key, label=cfg.ind_to_predicate[rel[2]], color='red')

        u.edge(str(rel[0]), edge_key)
        u.edge(edge_key, str(rel[1]))

    u.view()

    return out_dict
Example #30
Source File: dot.py    From Global-Second-order-Pooling-Convolutional-Networks with MIT License 5 votes vote down vote up
def make_dot_from_trace(trace):
    """ Produces graphs of torch.jit.trace outputs

    Each parsed trace node becomes a filled box labelled with its name, with
    '/' replaced by line breaks. Edges run from each input to the node.

    Example:
    >>> trace, = torch.jit.trace(model, args=(x,))
    >>> dot = make_dot_from_trace(trace)
    """
    torch.onnx._optimize_trace(trace, False)
    list_of_nodes = parse(trace.graph())

    # NOTE(review): 'ranksep' is a graph attribute and 'align' is not a
    # graphviz node attribute, so both are presumably ignored here.
    node_attr = {
        'style': 'filled',
        'shape': 'box',
        'align': 'left',
        'fontsize': '12',
        'ranksep': '0.1',
        'height': '0.2',
    }
    dot = Digraph(node_attr=node_attr, graph_attr={'size': "12,12"})

    for node in list_of_nodes:
        dot.node(node.name, label=node.name.replace('/', '\n'))
        for inp in node.inputs or ():
            dot.edge(inp, node.name)

    resize_graph(dot)
    return dot