1+ from ..visitor import ASTVisitor
2+
3+ from collections import defaultdict
4+
5+ class ControlFlowVisitor (ASTVisitor ):
6+
7+ def __init__ (self , graph ):
8+ super ().__init__ ()
9+ self .graph = graph
10+ self ._last_stmts = tuple ()
11+
12+ self ._returns_from = []
13+ self ._continue_from = defaultdict (list )
14+ self ._break_from = defaultdict (list )
15+
16+ def _add_next (self , stmt_node ):
17+ for last_stmt_node in self ._last_stmts :
18+ self .graph .add_relation (last_stmt_node , stmt_node , "controlflow" )
19+ self ._last_stmts = (stmt_node ,)
20+
21+ def _reset_last_stmts (self , reset_target ):
22+ last_stmts = self ._last_stmts
23+ self ._last_stmts = (reset_target ,)
24+ return last_stmts
25+
26+ def visit_block (self , node ):
27+
28+ for stmt in node .children :
29+ self .walk (stmt )
30+
31+ return False
32+
33+ # Methods --------------------------------------------------------
34+
35+ def visit_method_declaration (self , node ):
36+ outside_last , self ._last_stmts = self ._last_stmts , (node ,)
37+ outside_returns = self ._returns_from
38+ self ._returns_from = []
39+
40+ self .walk (
41+ node .child_by_field_name ("body" )
42+ )
43+
44+ for stmt in self ._last_stmts :
45+ self .graph .add_relation (stmt , node , "return_from" )
46+
47+ for stmt in self ._returns_from :
48+ self .graph .add_relation (stmt , node , "return_from" )
49+
50+ self ._returns_from = outside_returns
51+ self ._last_stmts = outside_last
52+ return False
53+
54+ def visit_return_statement (self , node ):
55+ self ._add_next (node )
56+ self ._returns_from .append (node )
57+ self ._last_stmts = tuple ()
58+ return False
59+
60+ # Labeled statements --------------------------------
61+
62+ def visit_labeled_statement (self , node ):
63+ name_node , _ , body = node .children
64+ name = self .graph .add_or_get_node (name_node ) # has to be a token
65+ name = name .token .text
66+
67+ self .walk (body )
68+
69+ current_last = self ._last_stmts
70+ self ._last_stmts = tuple (self ._continue_from [name ])
71+ self ._add_next (body )
72+ self ._continue_from [name ] = []
73+
74+ self ._last_stmts = current_last + tuple (self ._break_from [name ])
75+ self ._break_from [name ] = []
76+ return False
77+
78+ def visit_break_statement (self , node ):
79+ self ._add_next (node )
80+
81+ jump_label = "__LOOP__"
82+ if node .child_count > 2 :
83+ name_node = node .children [1 ]
84+ name_token = self .graph .add_or_get_node (name_node )
85+ jump_label = name_token .token .text
86+
87+ self ._break_from [jump_label ].append (node )
88+ self ._last_stmts = tuple ()
89+ return False
90+
91+ def visit_continue_statement (self , node ):
92+ self ._add_next (node )
93+
94+ jump_label = "__LOOP__"
95+ if node .child_count > 2 :
96+ name_node = node .children [1 ]
97+ name_token = self .graph .add_or_get_node (name_node )
98+ jump_label = name_token .token .text
99+
100+ self ._continue_from [jump_label ].append (node )
101+ self ._last_stmts = tuple ()
102+ return False
103+
104+ # Control structures --------------------------------
105+
106+ def visit_if_statement (self , node ):
107+ self ._add_next (node )
108+
109+ # Consequences
110+ self .walk (node .child_by_field_name ("consequence" ))
111+ left_last_stmts = self ._reset_last_stmts (node )
112+
113+ # Alternative
114+ self .walk (node .child_by_field_name ("alternative" ))
115+ right_last_stmts = self ._reset_last_stmts (node )
116+
117+ self ._last_stmts = left_last_stmts + right_last_stmts
118+ return False
119+
120+ def visit_for_statement (self , node ):
121+
122+ jump_label = "__LOOP__"
123+ prev_break , prev_continue = self ._break_from [jump_label ], self ._continue_from [jump_label ]
124+ self ._break_from [jump_label ], self ._continue_from [jump_label ] = [], []
125+
126+ self ._add_next (node )
127+ self .walk (node .child_by_field_name ("body" ))
128+ self ._last_stmts += tuple (self ._continue_from [jump_label ])
129+ self ._add_next (node )
130+
131+ self ._last_stmts += tuple (self ._break_from [jump_label ])
132+
133+ self ._break_from [jump_label ], self ._continue_from [jump_label ] = prev_break , prev_continue
134+ return False
135+
136+ def visit_while_statement (self , node ):
137+
138+ jump_label = "__LOOP__"
139+ prev_break , prev_continue = self ._break_from [jump_label ], self ._continue_from [jump_label ]
140+ self ._break_from [jump_label ], self ._continue_from [jump_label ] = [], []
141+
142+ self ._add_next (node )
143+ self .walk (node .child_by_field_name ("body" ))
144+ self ._last_stmts += tuple (self ._continue_from [jump_label ])
145+ self ._add_next (node )
146+
147+ self ._last_stmts += tuple (self ._break_from [jump_label ])
148+
149+ self ._break_from [jump_label ], self ._continue_from [jump_label ] = prev_break , prev_continue
150+ return False
151+
152+ def visit_do_statement (self , node ):
153+
154+ jump_label = "__LOOP__"
155+ prev_break , prev_continue = self ._break_from [jump_label ], self ._continue_from [jump_label ]
156+ self ._break_from [jump_label ], self ._continue_from [jump_label ] = [], []
157+
158+ self ._add_next (node )
159+ self .walk (node .child_by_field_name ("body" ))
160+ self ._last_stmts += tuple (self ._continue_from [jump_label ])
161+ self ._add_next (node )
162+
163+ self ._last_stmts += tuple (self ._break_from [jump_label ])
164+
165+ self ._break_from [jump_label ], self ._continue_from [jump_label ] = prev_break , prev_continue
166+ return False
167+
168+
169+ def visit_try_statement (self , node ):
170+ self ._add_next (node )
171+ starting_stmt = self ._last_stmts
172+
173+ self .walk (node .child_by_field_name ("body" ))
174+
175+ exception_starting_stmts = self ._last_stmts + starting_stmt
176+ self ._last_stmts = exception_starting_stmts
177+ out_last_stmts = tuple ()
178+
179+ finally_clauses = []
180+ for possible_exception in node .children :
181+ if possible_exception .type == "catch_clause" :
182+ self .walk (possible_exception )
183+ out_last_stmts += self ._last_stmts
184+ self ._last_stmts = exception_starting_stmts
185+ if possible_exception .type == "finally_clause" :
186+ finally_clauses .append (possible_exception )
187+
188+ self ._last_stmts += out_last_stmts
189+
190+ for finally_clause in finally_clauses :
191+ self .walk (finally_clause )
192+
193+ return False
194+
195+
196+ def visit (self , node ):
197+ # All statement node type end with statement
198+ # in tree-sitter
199+ # Therefore, we can savely do this hack.
200+ if node .type .endswith ("statement" ):
201+ self ._add_next (node )
202+ return False
203+ return True
0 commit comments