1+ from .visitor import ASTVisitor
2+
3+ class ControlFlowVisitor (ASTVisitor ):
4+
5+ def __init__ (self , graph ):
6+ super ().__init__ ()
7+ self .graph = graph
8+ self ._last_stmts = tuple ()
9+
10+ self ._break_from = []
11+ self ._continue_from = []
12+ self ._returns_from = []
13+ self ._yields_from = []
14+
15+
16+ def _add_next (self , stmt_node ):
17+ for last_stmt_node in self ._last_stmts :
18+ self .graph .add_relation (last_stmt_node , stmt_node , "controlflow" )
19+ self ._last_stmts = (stmt_node ,)
20+
21+ def _reset_last_stmts (self , reset_target ):
22+ last_stmts = self ._last_stmts
23+ self ._last_stmts = (reset_target ,)
24+ return last_stmts
25+
26+ def visit_block (self , node ):
27+
28+ for stmt in node .children :
29+ self .walk (stmt )
30+
31+ return False
32+
33+ def visit_function_definition (self , node ):
34+ outside_last , self ._last_stmts = self ._last_stmts , (node ,)
35+ outside_returns , outside_yields = self ._returns_from , self ._yields_from
36+ self ._returns_from , self ._yields_from = [], []
37+
38+ self .walk (
39+ node .child_by_field_name ("body" )
40+ )
41+
42+ for stmt in self ._last_stmts :
43+ self .graph .add_relation (stmt , node , "return_from" )
44+
45+ for stmt in self ._returns_from :
46+ self .graph .add_relation (stmt , node , "return_from" )
47+
48+ for stmt in self ._yields_from :
49+ self .graph .add_relation (stmt , node , "yield_from" )
50+
51+ self ._returns_from , self ._yields_from = outside_returns , outside_yields
52+ self ._last_stmts = outside_last
53+ return False
54+
55+ def visit_if_statement (self , node ):
56+ self ._add_next (node )
57+
58+ # Consequences
59+ self .walk (node .child_by_field_name ("consequence" ))
60+ left_last_stmts = self ._reset_last_stmts (node )
61+
62+ # Alternative
63+ self .walk (node .child_by_field_name ("alternative" ))
64+ right_last_stmts = self ._reset_last_stmts (node )
65+
66+ self ._last_stmts = left_last_stmts + right_last_stmts
67+ return False
68+
69+ def visit_return_statement (self , node ):
70+ self ._add_next (node )
71+ self ._returns_from .append (node )
72+ self ._last_stmts = tuple ()
73+ return False
74+
75+ def visit_yield_statement (self , node ):
76+ self ._add_next (node )
77+ self ._yields_from .append (node )
78+ self ._last_stmts = tuple ()
79+ return False
80+
81+ def visit_break_statement (self , node ):
82+ self ._add_next (node )
83+ self ._break_from .append (node )
84+ self ._last_stmts = tuple ()
85+ return False
86+
87+ def visit_continue_statement (self , node ):
88+ self ._add_next (node )
89+ self ._continue_from .append (node )
90+ self ._last_stmts = tuple ()
91+ return False
92+
93+ def visit_for_statement (self , node ):
94+
95+ prev_break , prev_continue = self ._break_from , self ._continue_from
96+ self ._break_from , self ._continue_from = [], []
97+
98+ self ._add_next (node )
99+ self .walk (node .child_by_field_name ("body" ))
100+ self ._last_stmts += tuple (self ._continue_from )
101+ self ._add_next (node )
102+
103+ self .walk (node .child_by_field_name ("alternative" ))
104+
105+ self ._last_stmts += tuple (self ._break_from )
106+
107+ self ._break_from , self ._continue_from = prev_break , prev_continue
108+ return False
109+
110+ def visit_while_statement (self , node ):
111+
112+ prev_break , prev_continue = self ._break_from , self ._continue_from
113+ self ._break_from , self ._continue_from = [], []
114+
115+ self ._add_next (node )
116+ self .walk (node .child_by_field_name ("body" ))
117+ self ._last_stmts += tuple (self ._continue_from )
118+ self ._add_next (node )
119+
120+ self .walk (node .child_by_field_name ("alternative" ))
121+
122+ self ._last_stmts += tuple (self ._break_from )
123+
124+ self ._break_from , self ._continue_from = prev_break , prev_continue
125+ return False
126+
127+ def visit_try_statement (self , node ):
128+ self ._add_next (node )
129+ starting_stmt = self ._last_stmts
130+
131+ self .walk (node .child_by_field_name ("body" ))
132+ self .walk (node .child_by_field_name ("alternative" ))
133+
134+ exception_starting_stmts = self ._last_stmts + starting_stmt
135+ self ._last_stmts = exception_starting_stmts
136+ out_last_stmts = tuple ()
137+
138+ finally_clauses = []
139+ for possible_exception in node .children :
140+ if possible_exception .type == "except_clause" :
141+ self .walk (possible_exception )
142+ out_last_stmts += self ._last_stmts
143+ self ._last_stmts = exception_starting_stmts
144+ if possible_exception .type == "finally_clause" :
145+ finally_clauses .append (possible_exception )
146+
147+ self ._last_stmts += out_last_stmts
148+
149+ for finally_clause in finally_clauses :
150+ self .walk (finally_clause )
151+
152+ return False
153+
154+ def visit (self , node ):
155+ # All statement node type end with statement
156+ # in tree-sitter
157+ # Therefore, we can savely do this hack.
158+ if node .type .endswith ("statement" ):
159+ self ._add_next (node )
160+ return False
161+ return True
0 commit comments