Python graphviz.Digraph() Examples
The following are 30 code examples of graphviz.Digraph().
You can vote up the ones you like or vote down the ones you don't like,
and go to the original project or source file by following the links above each example.
You may also want to check out all available functions/classes of the module graphviz, or try the search function.
Example #1
Source File: graph.py From CHAID with Apache License 2.0 | 7 votes |
def render(self, path, view):
    """Draw the decision tree as a PNG graph, one bar-chart image per node.

    Args:
        path: output path for the DOT source; when ``None`` a timestamped
            name under ``trees/`` is used.
        view: forwarded to ``Digraph.render`` to open the result.

    The temporary directory is bound to ``self.tempdir`` and rendering
    happens inside the ``with`` block, before the directory is removed.
    NOTE(review): presumably ``self.bar_chart`` writes its images into
    ``self.tempdir`` - confirm.
    """
    if path is None:
        stamp = "{:%Y-%m-%d %H:%M:%S}.gv".format(datetime.now())
        path = os.path.join("trees", stamp)

    with TemporaryDirectory() as self.tempdir:
        digraph = Digraph(
            format="png",
            graph_attr={"splines": "ortho"},
            node_attr={"shape": "plaintext", "labelloc": "b"},
        )
        for tree_node in self.tree:
            child_id = str(tree_node.node_id)
            digraph.node(child_id, image=self.bar_chart(tree_node))
            if tree_node.parent is None:
                continue
            choices = ', '.join(str(choice) for choice in tree_node.choices)
            digraph.edge(str(tree_node.parent), child_id, xlabel=" ({}) \n ".format(choices))
        digraph.render(path, view=view)
Example #2
Source File: output.py From tributary with Apache License 2.0 | 7 votes |
def GraphViz(node):
    """Build a PNG ``Digraph`` of the structure returned by ``Graph(node)``.

    From the traversal below, ``Graph(node)`` maps each top-level node to a
    list whose items are either leaf nodes or nested ``{node: [...]}``
    dicts. Every node is drawn under its ``_name`` with its
    ``_graphvizshape``, and edges point from a child to its parent.

    Returns:
        The ``graphviz.Digraph`` (not rendered).
    """
    tree = Graph(node)
    from graphviz import Digraph
    dot = Digraph("Graph", strict=False)
    dot.format = 'png'

    def rec(children, parent):
        # Draw ``children`` and connect each of them to ``parent``.
        for child in children:
            if isinstance(child, dict):
                for key in child:
                    dot.node(key._name, shape=key._graphvizshape)
                    rec(child[key], key)
                    dot.edge(key._name, parent._name)
            else:
                # Leaves are identified by ``_name`` as well. Graphviz needs
                # string ids, and this keeps leaf ids identical to the ids of
                # the same node drawn via the dict branch.
                dot.node(child._name, shape=child._graphvizshape)
                dot.edge(child._name, parent._name)

    for root in tree:
        dot.node(root._name, shape=root._graphvizshape)
        rec(tree[root], root)
    return dot
Example #3
Source File: tree.py From pyth with MIT License | 6 votes |
def disp_tree(trees):
    """Render a forest of nested-sequence trees to ``tree-rep.gv`` and open it.

    Each tree is ``[label, subtree, subtree, ...]``; falsy trees and
    subtrees are skipped. Nodes get increasing integer ids in pre-order
    across the whole forest.
    """
    graph = Digraph()

    def add(tree, next_id):
        # Draw ``tree`` rooted at id ``next_id``; return the last id consumed.
        if not tree:
            return next_id
        root_id = next_id
        graph.node(str(root_id), label=tree[0])
        for child in tree[1:]:
            if not child:
                continue
            next_id += 1
            graph.edge(str(root_id), str(next_id))
            next_id = add(child, next_id)
        return next_id

    counter = 0
    for tree in trees:
        counter = add(tree, counter) + 1
    graph.render('tree-rep.gv', view=True)
Example #4
Source File: api.py From eval-nas with MIT License | 6 votes |
def visualize(self):
    """Creates a dot graph. Can be visualized in colab directly.

    Vertex 0 is labelled ``input``, the last vertex ``output``, and each
    vertex in between gets its entry in ``self.ops``. An edge
    ``src -> dst`` is drawn for every truthy ``self.matrix[src, dst]`` with
    ``src < dst`` (only the upper triangle is read).

    Returns:
        A ``graphviz.Digraph``, or ``None`` (after printing the error) when
        the graphviz package is not installed.
    """
    num_vertices = np.shape(self.matrix)[0]
    # Only the import is guarded: an ImportError is the one failure this
    # method deliberately tolerates.
    try:
        import graphviz
    except ImportError as e:
        print(e)
        return None

    g = graphviz.Digraph()
    g.node(str(0), 'input')
    for v in range(1, num_vertices - 1):
        g.node(str(v), self.ops[v])
    g.node(str(num_vertices - 1), 'output')
    for src in range(num_vertices - 1):
        for dst in range(src + 1, num_vertices):
            if self.matrix[src, dst]:
                g.edge(str(src), str(dst))
    return g
Example #5
Source File: many_to_one.py From matchpy with MIT License | 6 votes |
def _as_graph(self, finals: Optional[List[str]]) -> Digraph:  # pragma: no cover
    """Build a ``Digraph`` of the matcher's automaton.

    When ``finals`` is ``None``, two extra box nodes (HTML-like labels)
    list the registered patterns and constraints. Node and edge
    construction is delegated to ``_make_graph_nodes`` and
    ``_make_graph_edges``.

    Raises:
        ImportError: if the graphviz package is unavailable (``Digraph`` is
            ``None``).
    """
    if Digraph is None:
        raise ImportError('The graphviz package is required to draw the graph.')
    graph = Digraph()
    show_legend = finals is None

    if show_legend:
        pattern_lines = []
        for index, (pattern, _, constraint_set) in enumerate(self.patterns):
            pattern_lines.append('{}: {} with {}'.format(
                self._colored_pattern(index),
                html.escape(str(pattern.expression)),
                self._format_constraint_set(constraint_set),
            ))
        pattern_label = '<<b>Patterns:</b><br/>\n{}>'.format('<br/>\n'.join(pattern_lines))
        graph.node('patterns', pattern_label, {'shape': 'box'})

    self._make_graph_nodes(graph, finals)

    if show_legend:
        constraint_lines = []
        for index, (constraint, pattern_set) in enumerate(self.constraints):
            constraint_lines.append('{}: {} for {}'.format(
                self._colored_constraint(index),
                html.escape(str(constraint)),
                self._format_pattern_set(pattern_set),
            ))
        constraint_label = '<<b>Constraints:</b><br/>\n{}>'.format('<br/>\n'.join(constraint_lines))
        graph.node('constraints', constraint_label, {'shape': 'box'})

    self._make_graph_edges(graph)
    return graph
Example #6
Source File: graph_init.py From hydra-python-agent with MIT License | 6 votes |
def main(self, new_url, api_doc, check_commit):
    """Load the API documentation into the RedisGraph named ``apigraph``.

    Stores ``new_url`` on ``self.url``, adds the endpoints via
    ``self.get_endpoints`` and commits the graph to Redis when
    ``check_commit`` is truthy.
    """
    redis_con = RedisProxy().get_connection()
    self.url = new_url
    self.redis_graph = Graph("apigraph", redis_con)
    print("loading... of graph")
    self.get_endpoints(api_doc, redis_con)
    if check_commit:
        print("commiting")
        # Create the whole graph in Redis.
        self.redis_graph.commit()
    print("done!!!!")
    # Debugging aids (disabled):
    #   * print every node alias:
    #       for node in self.redis_graph.nodes.values(): print("\n", node.alias)
    #   * visualise the stored graph with graphviz:
    #       g = Digraph('redis_graph', filename='hydra_graph.gv')
    #       for edge in self.redis_graph.edges:
    #           g.edge(edge.src_node.alias, edge.dest_node.alias)
    #       g.view()
Example #7
Source File: utilities.py From Ithemal with MIT License | 6 votes |
def draw(self, to_file=False, file_name=None, view=True):
    """Build a ``Digraph`` with one node per instruction and edges to its children.

    Nodes are keyed by ``id(instr)`` and labelled with ``str(instr)``.

    Args:
        to_file: if true, also render the graph to ``file_name``.
        file_name: output path; generated as a temporary ``.gv`` name when
            ``to_file`` is set and no name is given.
        view: forwarded to ``render`` to open the result.

    Returns:
        ``(dot, file_name)`` when ``to_file`` is true, otherwise ``dot``.
    """
    if to_file and not file_name:
        file_name = tempfile.NamedTemporaryFile(suffix='.gv').name

    from graphviz import Digraph
    dot = Digraph()
    for instr in self.instrs:
        parent_key = str(id(instr))
        dot.node(parent_key, str(instr))
        for child in instr.children:
            dot.edge(parent_key, str(id(child)))

    if not to_file:
        return dot
    dot.render(file_name, view=view)
    return dot, file_name
Example #8
Source File: state_machine.py From vivarium with GNU General Public License v3.0 | 6 votes |
def to_dot(self):
    """Produces a ball and stick graph of this state machine.

    Transient states are drawn with a dashed outline; every transition
    becomes an edge labelled with the transition's name.

    Returns
    -------
    `graphviz.Digraph`
        A ball and stick visualization of this state machine.
    """
    from graphviz import Digraph

    dot = Digraph(format='png')
    for state in self.states:
        extra = {'style': 'dashed'} if isinstance(state, TransientState) else {}
        dot.node(state.state_id, **extra)
        for transition in state.transition_set:
            dot.edge(state.state_id, transition.output_state.state_id, transition.name)
    return dot
Example #9
Source File: draw_cfg.py From pyta with GNU General Public License v3.0 | 6 votes |
def display(cfgs: Dict[NodeNG, ControlFlowGraph], filename: str, view: bool = True) -> None:
    """Render each module/function CFG as its own cluster and write ``filename``.

    Modules are labelled ``__main__`` and functions by their name; CFGs
    keyed by any other node type are skipped. Unreachable blocks are drawn
    as well, sharing the cluster's visited set.
    """
    graph = graphviz.Digraph(name=filename, **GRAPH_OPTIONS)
    for node, cfg in cfgs.items():
        if isinstance(node, astroid.Module):
            cluster_label = '__main__'
        elif isinstance(node, astroid.FunctionDef):
            cluster_label = node.name
        else:
            continue

        with graph.subgraph(name=f'cluster_{id(node)}') as cluster:
            seen = set()
            _visit(cfg.start, cluster, seen)
            for orphan in cfg.unreachable_blocks:
                _visit(orphan, cluster, seen)
            cluster.attr(label=cluster_label)

    graph.render(filename, view=view)
Example #10
Source File: turbinia_job_graph.py From turbinia with Apache License 2.0 | 6 votes |
def create_graph():
    """Create graph of relationships between Turbinia jobs and evidence.

    Jobs are plain nodes and evidence types are boxes. Edges run from each
    input evidence type to the job, and from the job to each output type.

    Returns:
        Instance of graphviz.dot.Digraph
    """
    dot = graphviz.Digraph(comment='Turbinia Evidence graph', format='png')
    for _, job in jobs_manager.JobsManager.GetJobs():
        job_name = job.NAME
        dot.node(job_name)
        for consumed in job.evidence_input:
            consumed_name = consumed.__name__
            dot.node(consumed_name, shape='box')
            dot.edge(consumed_name, job_name)
        for produced in job.evidence_output:
            produced_name = produced.__name__
            dot.node(produced_name, shape='box')
            dot.edge(job_name, produced_name)
    return dot
Example #11
Source File: regex_common.py From acsploit with BSD 3-Clause "New" or "Revised" License | 6 votes |
def show_graph(self):
    """Render the automaton to PNG and open it in the default viewer.

    Node shape encodes each state's role:

    * ``circle``        - ordinary state
    * ``doublecircle``  - accepting state
    * ``octagon``       - initial state
    * ``doubleoctagon`` - state that is both initial and accepting

    Edges are labelled with ``str(t)`` of the transition, with backslashes
    doubled so Graphviz does not read them as escape sequences. The DOT
    source file is deleted after rendering (``cleanup=True``).
    """
    import os

    # tempfile.mktemp() is deprecated and race-prone: it only returns a
    # name that another process may claim first. mkstemp() creates the file
    # atomically; close our handle and let graphviz overwrite it.
    fd, gv_path = tempfile.mkstemp()
    os.close(fd)

    g = Digraph("Automata Graph", filename=gv_path, format='png')
    g.attr('node', label="")
    # (is_accepting, is_initial) -> node shape
    shapes = {
        (False, False): 'circle',
        (True, False): 'doublecircle',
        (False, True): 'octagon',
        (True, True): 'doubleoctagon',
    }
    for n in self.states:
        string = Automata.stringify_node(n)
        role = (bool(n in self.accepting_states), bool(n == self.initial_state))
        g.node(string, string, shape=shapes[role])
    for s, transitions in self.transitions.items():
        for d, t in transitions.items():
            g.edge(Automata.stringify_node(s), Automata.stringify_node(d),
                   str(t).replace('\\', '\\\\'))
    g.view(cleanup=True)
Example #12
Source File: draw_cfg.py From pyta with GNU General Public License v3.0 | 6 votes |
def _visit(block: CFGBlock, graph: graphviz.Digraph, visited: Set[str]) -> None:
    """Add ``block`` and every block reachable from it to ``graph``, depth-first.

    Each block becomes a filled node whose id is ``<graph.name>_<block.id>``
    and whose label lists the block's statements, one left-aligned line
    each. Unreachable blocks are filled ``grey93``, reachable ones
    ``white``.

    Args:
        block: block to start from.
        graph: graph (or cluster subgraph) to draw into.
        visited: node ids already drawn. These are the ``str`` ids built
            here, not ``int`` block ids, hence ``Set[str]``. The set is
            updated in place, so shared successors and loops are drawn once.
    """
    node_id = f'{graph.name}_{block.id}'
    if node_id in visited:
        return

    label = '\n'.join([s.as_string() for s in block.statements]) + '\n'
    # Need to escape backslashes explicitly.
    label = label.replace('\\', '\\\\')
    # \l is used for left alignment.
    label = label.replace('\n', '\\l')

    fill_color = 'grey93' if not block.reachable else 'white'
    graph.node(node_id, label=label, fillcolor=fill_color, style='filled')
    visited.add(node_id)

    for edge in block.successors:
        graph.edge(node_id, f'{graph.name}_{edge.target.id}')
        _visit(edge.target, graph, visited)
Example #13
Source File: exnviz.py From ngraph-python with Apache License 2.0 | 6 votes |
def begin_pass(self, filename=None, **kwargs):
    """Start a visualization pass: reset bookkeeping sets and create ``self.graph``.

    Args:
        filename: name given to the Digraph; defaults to ``self.filename``.
        **kwargs: forwarded to the parent ``begin_pass``.

    Raises:
        ImportError: if the python graphviz package is not installed.
    """
    super(ExVizPass, self).begin_pass(**kwargs)
    try:
        import graphviz
    except ImportError:
        raise ImportError("You tried to use the ShowGraph transformer pass but did "
                          "not have the python graphviz library installed")
    graph_name = self.filename if filename is None else filename

    # Fresh tracking sets for this pass.
    self.exops_with_nodes = set()
    self.exops_without_nodes = set()
    self.tensors_with_nodes = set()
    self.tensors_without_nodes = set()

    # Node attributes are left at graphviz defaults (a rounded-box node
    # style was considered and left disabled).
    spacing = {'nodesep': '.5', 'ranksep': '.5'}
    self.graph = graphviz.Digraph(name=graph_name, graph_attr=spacing)
Example #14
Source File: exnviz.py From ngraph-python with Apache License 2.0 | 6 votes |
def __init__(self, **kwargs):
    """Configure the pass from keyword options; unrecognised keys are ignored.

    Recognised keys (default): ``show_axes`` (True), ``show_all_metadata``
    (True), ``subgraph_attr`` (None), ``filename`` ('Digraph'), ``view``
    (True), ``cleanup`` (True), ``show_tensors`` (False) and
    ``output_directory`` ('.'). When ``view`` is true and
    ``output_directory`` is explicitly ``None``, a fresh temporary directory
    is created and used instead.
    """
    super(ExVizPass, self).__init__()
    option = kwargs.pop
    self.show_axes = option('show_axes', True)
    self.show_all_metadata = option('show_all_metadata', True)
    self.subgraph_attr = option('subgraph_attr', None)
    self.exops_with_nodes = set()
    self.exops_without_nodes = set()
    self.filename = option('filename', 'Digraph')
    self.view = option('view', True)
    self.cleanup = option('cleanup', True)
    self.show_tensors = option('show_tensors', False)

    out_dir = option('output_directory', '.')
    if self.view and out_dir is None:
        out_dir = tempfile.mkdtemp()
    self.output_directory = out_dir
Example #15
Source File: graph.py From pipedream with MIT License | 6 votes |
def to_dot(self, arch):
    """Build a Digraph of the model graph and render it via ``dot.render(arch)``.

    Each node is labelled with its description and profiling stats. Nodes
    assigned to a pipeline stage are filled with a per-stage color taken
    cyclically from ``self._colors``.
    """
    dot = graphviz.Digraph()
    template = ("%s\n[forward_compute_time=%.3f,backward_compute_time=%.3f,"
                "activation_size=%s,parameter_size=%.1f]")
    for node in self.nodes.values():
        node_desc = template % (
            node.node_desc, node.forward_compute_time, node.backward_compute_time,
            node.activation_size, node.parameter_size)
        if node.stage_id is None:
            dot.node(node.node_id, node_desc)
        else:
            stage_color = self._colors[node.stage_id % len(self._colors)]
            dot.node(node.node_id, node_desc, color=stage_color, style='filled')

    for node in self.nodes.values():
        if node.node_id in self.edges:
            for out_node in self.edges[node.node_id]:
                dot.edge(node.node_id, out_node.node_id)
    dot.render(arch)
Example #16
Source File: notifications.py From RAFCON with Eclipse Public License 1.0 | 6 votes |
def enable_debugging():
    """Turn on notification debugging and reset the module-level graph state.

    Creates a fresh ``notification_graph_to_render`` Digraph and resets the
    node sequence counter, the node-to-color map and the ``nodes``/``edges``
    registries.
    """
    from graphviz import Digraph
    global debugging_enabled, notification_graph_to_render, dot_node_sequence_number, existing_dot_nodes_to_colors
    global nodes, edges

    debugging_enabled = True
    existing_dot_nodes_to_colors = dict()
    # Graph attributes that were tried and left disabled:
    #   concentrate=true - merges all edges sharing source and endpoint
    #   labelfloat=true  - had no visible effect
    #   splines=ortho    - unsupported by some layout engines, ignored by others
    #   splines=compound (optionally with overlap=false)
    notification_graph_to_render = Digraph(name='notification_graph_to_render')
    dot_node_sequence_number = 0
    nodes = {}
    edges = OrderedDict()
Example #17
Source File: model.py From staticfg with Apache License 2.0 | 5 votes |
def _build_visual(self, format='pdf', calls=True):
    """Build a Digraph for this CFG and nest each function's CFG inside it.

    The graph is named ``'cluster' + self.name`` and labelled with
    ``self.name``. Graphviz draws subgraphs whose name starts with
    ``cluster`` as boxed clusters.
    """
    graph = gv.Digraph(name='cluster' + self.name, format=format,
                       graph_attr={'label': self.name})
    self._visit_blocks(graph, self.entryblock, visited=[], calls=calls)

    # Recursively build and attach the CFG of every function definition.
    for function_cfg in self.functioncfgs.values():
        graph.subgraph(function_cfg._build_visual(format=format, calls=calls))
    return graph
Example #18
Source File: utils.py From ProxImaL with MIT License | 5 votes |
def graph_visualize(prox_fns, filename = None):
    """Draw the prox functions and the linear-operator graphs they act on.

    Every prox fn, and every linear op reachable through ``input_nodes``,
    gets one node with id ``N0``, ``N1``, ... in discovery order. Edges run
    from each op's inputs to the op, and from each prox fn's ``lin_op`` to
    the prox fn.

    Args:
        prox_fns: iterable of prox fns, each with a ``lin_op`` attribute.
        filename: if given, render to this path; otherwise display the
            graph inline via IPython.
    """
    from collections import deque
    import graphviz
    from IPython.display import display

    dot = graphviz.Digraph()
    nodes = {}

    def node(obj):
        # Return the dot id of ``obj``, allocating the next free one if new.
        if obj not in nodes:
            nodes[obj] = 'N%d' % len(nodes)
        return nodes[obj]

    # An op's input edges depend only on the op itself. When several prox
    # fns share part of the operator graph, that part must be connected only
    # once, so one visited set is shared across all prox fns; a per-fn set
    # produced duplicate parallel edges.
    visited = set()
    for pfn in prox_fns:
        dot.node(node(pfn), str(pfn))

        # Pass 1: create a node for every op reachable from this lin_op.
        pending = deque([pfn.lin_op])
        while pending:
            n = pending.popleft()
            if n not in nodes:
                dot.node(node(n), str(type(n)))
                pending.extend(n.input_nodes)
        dot.edge(nodes[pfn.lin_op], nodes[pfn])

        # Pass 2: connect each op to its inputs.
        pending = deque([pfn.lin_op])
        while pending:
            n = pending.popleft()
            if n not in visited:
                visited.add(n)
                pending.extend(n.input_nodes)
                for inn in n.input_nodes:
                    dot.edge(nodes[inn], nodes[n])

    if filename is None:
        display(dot)
    else:
        dot.render(filename)
Example #19
Source File: graph_monitor.py From AMS with Apache License 2.0 | 5 votes |
def get_dots(self):
    """Return DOT sources for the instance-level and class-level relation graphs.

    ``self.__relations`` is read as
    ``{from_class: {from_id: {to_class: to_ids}}}``. Both graphs draw one
    edge between class nodes per distinct (from_class, to_class, label):

    * class graph: label ``"<from_class> -> <to_class>"``;
    * instance graph: label ``"<from_id> -> <to_id>"`` for every target id,
      except that ``traffic_signal`` sources use the class-level label.

    Returns:
        ``(graph_dot, class_graph_dot)`` as DOT source strings.
    """
    def to_source(edges):
        # Render an edge collection with the shared left-to-right layout.
        # ``edges`` is a dict used as an insertion-ordered set. Iterating it
        # directly, instead of through ``set(...)`` whose order depends on
        # per-process string hash randomisation, makes the DOT output
        # reproducible across runs.
        f = Digraph('finite_state_machine')
        f.attr(rankdir='LR', size='8,5')
        f.attr('node', shape='circle')
        for from_class, to_class, label in edges:
            f.edge(from_class, to_class, label=label)
        return f.source

    class_edges = {}
    instance_edges = {}
    for from_class, from_id_to_class_id_relations in self.__relations.items():
        for from_id, to_class_id_relations in from_id_to_class_id_relations.items():
            for to_class, to_ids in to_class_id_relations.items():
                class_label = from_class + " -> " + to_class
                class_edges[(from_class, to_class, class_label)] = None
                if from_class == "traffic_signal":
                    instance_edges[(from_class, to_class, class_label)] = None
                else:
                    for to_id in to_ids:
                        instance_edges[(from_class, to_class, from_id + " -> " + to_id)] = None

    class_graph_dot = to_source(class_edges)
    graph_dot = to_source(instance_edges)
    return graph_dot, class_graph_dot
Example #20
Source File: cuda_codegen.py From ProxImaL with MIT License | 5 votes |
def visualize(self, dot = None):
    """Draw this subgraph, after its dependent subgraphs, into a Digraph.

    Walks backwards from ``self.end`` through ``self.input_nodes``. Nodes
    are named ``N0``, ``N1``, ... in breadth-first order and labelled with
    their type; edges run from each input to its consumer.

    Args:
        dot: ``graphviz.Digraph`` to draw into. When omitted, a new one is
            created and displayed via IPython once complete.
    """
    from collections import deque
    import graphviz
    from IPython.display import display

    root = not dot
    if root:
        dot = graphviz.Digraph()
    for csg in self.dependent_subgraphs:
        csg.visualize(dot)

    def inputs_of(n):
        # ``input_nodes`` raises KeyError for nodes without recorded inputs.
        try:
            return self.input_nodes(n)
        except KeyError:
            return ()

    # Pass 1: name every reachable node. A node is expanded only the first
    # time it is seen; re-expanding repeats did redundant work on shared
    # sub-expressions (exponential on deep DAGs) and never terminated on a
    # cycle. Breadth-first numbering is unchanged.
    nodes = {}
    pending = deque([self.end])
    while pending:
        n = pending.popleft()
        if n not in nodes:
            nodes[n] = 'N%d' % len(nodes)
            dot.node(nodes[n], str(type(n)))
            pending.extend(inputs_of(n))

    # Pass 2: one edge per (input, consumer) pair of each reachable node.
    visited = set()
    pending = deque([self.end])
    while pending:
        n = pending.popleft()
        if n not in visited:
            visited.add(n)
            for inn in inputs_of(n):
                pending.append(inn)
                dot.edge(nodes[inn], nodes[n])

    if root:
        display(dot)
Example #21
Source File: modeling_framework.py From dmipy with MIT License | 5 votes |
def visualize_model_setup(
        self, view=True, cleanup=True, with_parameters=False,
        im_format='png'):
    """
    Visualizes MultiCompartmentModel setup using graphviz module.

    The root node is labelled with this model's class name; the rest of the
    tree is added by ``_add_recursive_graph_node``. Every node is keyed by
    a fresh ``uuid4`` string, so repeated sub-model types stay distinct.
    The graph is rendered to a file named 'Model Setup' (plus the image
    extension) in the current working directory.

    Parameters
    ----------
    view: boolean,
        Whether or not to open the rendered graph in a viewer.
    cleanup: boolean,
        Whether or not to delete the intermediate DOT source file after
        rendering; the rendered image itself is kept.
    with_parameters: boolean,
        Whether or not to also visualize the parameters of each model.
        Parameter links already imposed (e.g. parameter equality or fixed
        parameters) are not reflected in the graph.
    im_format: str,
        Image format passed to graphviz (default 'png').
    """
    if not have_graphviz:
        raise ImportError('graphviz package not installed.')
    dot = Digraph('Model Setup', format=im_format)
    root_uuid = str(uuid4())
    dot.node(root_uuid, self.__class__.__name__)
    self._add_recursive_graph_node(dot, root_uuid, self, with_parameters)
    dot.render('Model Setup', view=view, cleanup=cleanup)
Example #22
Source File: dot_renderer.py From airflow with Apache License 2.0 | 5 votes |
def render_dag(dag: DAG, tis: Optional[List[TaskInstance]] = None) -> graphviz.Digraph:
    """
    Renders the DAG object to the DOT object.

    If a task instance list is passed, nodes are painted according to the
    task states; otherwise they use each task's own UI colors.

    :param dag: DAG that will be rendered.
    :type dag: airflow.models.dag.DAG
    :param tis: List of task instances
    :type tis: Optional[List[TaskInstance]]
    :return: Graphviz object
    :rtype: graphviz.Digraph
    """
    dot = graphviz.Digraph(
        dag.dag_id,
        graph_attr={
            "rankdir": dag.orientation or "LR",
            "labelloc": "t",
            "label": dag.dag_id,
        },
    )
    states_by_task_id = None if tis is None else {ti.task_id: ti.state for ti in tis}

    for task in dag.tasks:
        if states_by_task_id is None:
            colors = {
                "color": _refine_color(task.ui_fgcolor),
                "fillcolor": _refine_color(task.ui_color),
            }
        else:
            state = states_by_task_id.get(task.task_id, State.NONE)
            colors = {
                "color": State.color_fg(state),
                "fillcolor": State.color(state),
            }
        node_attrs = {"shape": "rectangle", "style": "filled,rounded", **colors}
        dot.node(task.task_id, _attributes=node_attrs)
        for downstream_task_id in task.downstream_task_ids:
            dot.edge(task.task_id, downstream_task_id)
    return dot
Example #23
Source File: SimpleSchedulePass.py From pymtl3 with BSD 3-Clause "New" or "Revised" License | 5 votes |
def dump_dag( top, V, E ):
    """Render the update-block DAG to ``/tmp/upblk-dag.gv`` and open it.

    Each vertex is labelled ``<name>\\n@<host>``. ``name`` is ``repr(x)``
    for a ``CalleePort`` and ``x.__name__`` otherwise, with a ``_FF`` suffix
    for blocks in ``top._dsl.all_update_ff``. ``host`` is the repr of the
    owning component, or empty when it cannot be determined.

    Args:
        top: the top-level component.
        V: iterable of vertices (update blocks / callee ports).
        E: iterable of ``(src, dst)`` vertex pairs.
    """
    from graphviz import Digraph
    from pymtl3.dsl import CalleePort

    def label( x ):
        # Node id shared by the vertex statement and every edge touching x.
        is_port = isinstance( x, CalleePort )
        name = repr(x) if is_port else x.__name__
        if x in top._dsl.all_update_ff:
            name += "_FF"
        try:
            host = repr(x.get_parent_object() if is_port
                        else top.get_update_block_host_component(x))
        except Exception:
            # Best effort: an unknown host only loses the "@host" detail.
            # (A bare except here also swallowed KeyboardInterrupt/SystemExit.)
            host = ""
        return name + "\\n@" + host

    dot = Digraph()
    dot.graph_attr["rank"] = "same"
    dot.graph_attr["ratio"] = "compress"
    dot.graph_attr["margin"] = "0.1"

    for x in V:
        dot.node( label(x), shape="box" )
    for (x, y) in E:
        dot.edge( label(x), label(y) )

    dot.render( "/tmp/upblk-dag.gv", view=True )
Example #24
Source File: visualize.py From ibis with Apache License 2.0 | 5 votes |
def to_graph(expr, node_attr=None, edge_attr=None):
    """Build a bottom-to-top Digraph of an ibis expression tree.

    Nodes are keyed by ``str(hash((expr._key, argname)))``, so the same
    sub-expression reached under the same argument name is expanded once.
    Edges run from each expression argument to the expression consuming it.
    """
    stack = [(expr, expr._safe_name)]
    seen = set()
    g = gv.Digraph(node_attr=node_attr or DEFAULT_NODE_ATTRS, edge_attr=edge_attr or {})
    g.attr(rankdir='BT')

    while stack:
        current, current_name = stack.pop()
        current_key = current._key, current_name
        if current_key in seen:
            continue
        seen.add(current_key)

        current_label = get_label(current, argname=current_name)
        current_id = str(hash(current_key))
        g.node(current_id, label=current_label)

        op = current.op()
        for arg, name in zip(op.args, op.signature.names()):
            if not isinstance(arg, ir.Expr):
                continue
            arg_id = str(hash((arg._key, name)))
            g.node(arg_id, label=get_label(arg, argname=name))
            g.edge(arg_id, current_id)
            stack.append((arg, name))
    return g
Example #25
Source File: tiny_gp_plus.py From tiny_gp with GNU General Public License v3.0 | 5 votes |
def draw_tree(self, fname, footer):
    """Render this tree to ``<fname>.gv.png`` and display the image inline.

    Args:
        fname: output path without extension; ``.gv`` (DOT source) and
            ``.gv.png`` (image) are written next to it.
        footer: graph-level label.

    NOTE(review): the Digraph and the counter are wrapped in one-element
    lists, presumably so the recursive ``self.draw`` can share and update
    them - confirm.
    """
    dot = [Digraph()]
    dot[0].attr(kw='graph', label = footer)
    count = [0]
    self.draw(dot, count)
    # ``Source`` expects DOT text. Pass the generated source string rather
    # than the Digraph object, which current graphviz releases reject.
    Source(dot[0].source, filename = fname + ".gv", format="png").render()
    display(Image(filename = fname + ".gv.png"))
Example #26
Source File: nist800_53viz.py From compliancelib-python with GNU General Public License v3.0 | 5 votes |
def __init__(self, id, vizformat='svg'):
    """Load the dependency data and compute the precursor controls for ``id``.

    Args:
        id: identifier passed to ``precursor_list``.
        vizformat: graphviz output format bound into the ``self.graph`` /
            ``self.digraph`` factories (default ``'svg'``).
    """
    self.id = id

    # Locations: dependency data lives beside this module.
    self.base_path = os.path.join(os.path.dirname(__file__), '')
    self.dep_dir = "data/dependencies/"
    self.out_dir = ""
    self.log_dir = "./"

    # Graphviz settings; graph/digraph are constructors pre-bound to the format.
    self.vizformat = vizformat
    self.width = self.height = 2.5
    self.graph = functools.partial(gv.Graph, format=self.vizformat)
    self.digraph = functools.partial(gv.Digraph, format=self.vizformat)

    # Per-run paths derived from the settings above.
    self.input_path = f"{self.base_path}{self.dep_dir}"
    self.output_path = f"{self.base_path}{self.out_dir}"

    self.dep_dict = self._load_graph_from_dependency_files()

    self.resolved, self.nodes, self.edges = [], [], []

    # precursor_list receives self.nodes, presumably to fill it in place;
    # precursor_controls is the very same list object.
    self.precursor_list(self.dep_dict, self.id, self.nodes)
    self.precursor_controls = self.nodes
Example #27
Source File: pytorch_visualize.py From MaximumMarginGANs with MIT License | 5 votes |
def save_visualization(name, format='svg'):
    """Render the recorded autograd trace to ``name`` in the given format.

    Reads module-level trace tables:

    * ``vars``: id -> ``nn.Parameter`` / ``Variable``
    * ``funcs``: id -> creator function
    * ``func_trace``: creator id -> input ids
    * ``var_trace``: output id -> creator ids

    Parameters are drawn as red ellipses, other Variables as light-blue
    ellipses labelled with their size, and creators as orange boxes.

    Raises:
        AssertionError: if ``vars`` holds something other than a Parameter
            or Variable.
    """
    g = graphviz.Digraph(format=format)

    def sizestr(var):
        # Size as a plain list of ints, e.g. "[64, 3, 3, 3]".
        return str([int(i) for i in list(var.size())])

    # Variable nodes. ``items()`` works on Python 2 and 3; ``iteritems()``
    # does not exist on Python 3 dicts.
    for vid, var in vars.items():
        if isinstance(var, nn.Parameter):
            fillcolor = 'red'
        elif isinstance(var, Variable):
            fillcolor = 'lightblue'
        else:
            # Explicit raise so the check survives ``python -O``.
            raise AssertionError(var.__class__)
        g.node(str(vid), label=sizestr(var), shape='ellipse', style='filled', fillcolor=fillcolor)

    # Creator nodes.
    for cid in func_trace:
        creator = funcs[cid]
        g.node(str(cid), label=str(creator.__class__.__name__), shape='rectangle', style='filled', fillcolor='orange')

    # Edges from inputs to their creator.
    for cid in func_trace:
        for iid in func_trace[cid]:
            g.edge(str(iid), str(cid))

    # Edges from creators to their outputs.
    for oid in var_trace:
        for cid in var_trace[oid]:
            g.edge(str(cid), str(oid))

    g.render(name)
Example #28
Source File: graph.py From ray-legacy with BSD 3-Clause "New" or "Revised" License | 5 votes |
def graph_to_graphviz(computation_graph):
    """
    Convert the computation graph to graphviz format.

    Task, put and get operations become boxes named ``op<i>``. Task results
    and put object ids are linked from their op; each op gets a dotted,
    non-constraining edge from its creator (``op-root`` for the sentinel id
    ``2 ** 64 - 1``). Task arguments without a serialized value are drawn as
    object nodes feeding the op.

    Args:
        computation_graph [graph_pb2.CompGraph]: protocol buffer description of the computation graph

    Returns:
        Graphviz description of the computation graph
    """
    root_creator_id = 2 ** 64 - 1

    dot = graphviz.Digraph(format="pdf")
    dot.node("op-root", shape="box")
    for i, op in enumerate(computation_graph.operation):
        op_name = "op" + str(i)
        if op.HasField("task"):
            dot.node(op_name, shape="box", label=str(i) + "\n" + op.task.name.split(".")[-1])
            for res in op.task.result:
                dot.edge(op_name, str(res))
        elif op.HasField("put"):
            dot.node(op_name, shape="box", label=str(i) + "\n" + "put")
            dot.edge(op_name, str(op.put.objectid))
        elif op.HasField("get"):
            dot.node(op_name, shape="box", label=str(i) + "\n" + "get")

        if op.creator_operationid == root_creator_id:
            creator = "-root"
        else:
            creator = op.creator_operationid
        dot.edge("op" + str(creator), op_name, style="dotted", constraint="false")

        for arg in op.task.arg:
            if len(arg.serialized_arg) == 0:
                dot.node(str(arg.objectid))
                dot.edge(str(arg.objectid), op_name)
    return dot
Example #29
Source File: viz.py From scene-graph-TF-release with MIT License | 5 votes |
def draw_graph(labels, rels, cfg):
    """Draw a scene graph to ``sg.gv`` and open the rendered result.

    Objects that take part in at least one relation become light-blue nodes
    named after their class; ``_1``, ``_2``, ... suffixes disambiguate
    repeated class names. Each relation becomes a red predicate node wired
    ``subject -> predicate -> object``.

    Args:
        labels: per-object class indices (array-like with ``tolist()``).
        rels: rows of ``(subject_idx, object_idx, predicate_idx, ...)``.
        cfg: provides the ``ind_to_class`` / ``ind_to_predicate`` lookups.

    Returns:
        dict with both vocabularies, ``labels`` as a list and the original
        ``rels`` under ``relations``.
    """
    u = Digraph('sg', filename='sg.gv')
    # Set graph attributes through the attribute API instead of appending
    # raw lines to ``u.body``. The body expects complete, newline-terminated
    # DOT statements, which bare strings are not.
    u.graph_attr.update(size='6,6', rankdir='LR')
    u.node_attr.update(style='filled')

    out_dict = {'ind_to_class': cfg.ind_to_class, 'ind_to_predicate': cfg.ind_to_predicate}
    out_dict['labels'] = labels.tolist()
    out_dict['relations'] = rels

    rels = np.array(rels)
    # Indices appearing as subject or object of some relation, as a set for
    # O(1) membership tests. An empty ``rels`` gives a 1-D array that
    # ``[:, :2]`` cannot slice, so it yields no indices instead.
    rel_inds = set(rels[:, :2].ravel().tolist()) if rels.size else set()

    used_names = set()
    for i, l in enumerate(labels):
        if i not in rel_inds:
            continue
        name = cfg.ind_to_class[l]
        name_suffix = 1
        obj_name = name
        while obj_name in used_names:
            obj_name = name + '_' + str(name_suffix)
            name_suffix += 1
        used_names.add(obj_name)
        u.node(str(i), label=obj_name, color='lightblue2')

    for rel in rels:
        edge_key = '%s_%s' % (rel[0], rel[1])
        u.node(edge_key, label=cfg.ind_to_predicate[rel[2]], color='red')
        u.edge(str(rel[0]), edge_key)
        u.edge(edge_key, str(rel[1]))

    u.view()
    return out_dict
Example #30
Source File: dot.py From Global-Second-order-Pooling-Convolutional-Networks with MIT License | 5 votes |
def make_dot_from_trace(trace):
    """Produce a graphviz graph of a ``torch.jit.trace`` output.

    Each parsed node is a filled box whose label is its name with ``/``
    replaced by line breaks; edges run from each input to the node. The
    graph is resized by ``resize_graph`` before being returned.

    Example::

        trace, = torch.jit.trace(model, args=(x,))
        dot = make_dot_from_trace(trace)
    """
    # NOTE(review): torch.onnx._optimize_trace is a private API and may not
    # exist in every PyTorch version - verify against the installed torch.
    torch.onnx._optimize_trace(trace, False)
    list_of_nodes = parse(trace.graph())

    dot = Digraph(
        node_attr=dict(style='filled', shape='box', align='left',
                       fontsize='12', ranksep='0.1', height='0.2'),
        graph_attr=dict(size="12,12"),
    )
    for node in list_of_nodes:
        dot.node(node.name, label=node.name.replace('/', '\n'))
        for inp in node.inputs or ():
            dot.edge(inp, node.name)

    resize_graph(dot)
    return dot