Skip to content

Commit 05e56e0

Browse files
committed
Add support for analysis of Java methods
1 parent e7e0f1c commit 05e56e0

File tree

7 files changed

+1163
-12
lines changed

7 files changed

+1163
-12
lines changed

code_graph/__init__.py

Lines changed: 13 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -2,35 +2,36 @@
22

33
from .graph import CodeGraph
44

5-
from .ast import ASTRelationVisitor
6-
from .cfg import ControlFlowVisitor
7-
from .dataflow import DataFlowVisitor
5+
from .pylang import pylang_analyses
6+
from .javalang import javalang_analyses
87

98

10-
GRAPH_ANALYSES = {
11-
"ast": ASTRelationVisitor,
12-
"cfg": ControlFlowVisitor,
13-
"dataflow": DataFlowVisitor,
14-
}
15-
169
def codegraph(source_code, lang = "guess", analyses = None, **kwargs):
1710
tokens = ctok.tokenize(source_code, lang = lang, **kwargs)
1811
root_node = _root_node(tokens)
1912

13+
graph_analyses = load_lang_analyses(tokens[0].config.lang)
14+
2015
if analyses is None:
21-
analyses = GRAPH_ANALYSES.keys()
16+
analyses = graph_analyses.keys()
2217
else:
23-
assert all(a in GRAPH_ANALYSES for a in analyses), \
18+
assert all(a in graph_analyses.keys() for a in analyses), \
2419
"Not all analyses are supported. Available analyses are: %s" % ", ".join(GRAPH_ANALYSES.keys())
2520

2621
graph = CodeGraph(root_node, tokens, lang = lang)
2722

2823
for analysis in analyses:
29-
analysis_visitor = GRAPH_ANALYSES[analysis]
24+
analysis_visitor = graph_analyses[analysis]
3025
analysis_visitor(graph)(root_node)
3126

3227
return graph
3328

29+
30+
def load_lang_analyses(lang):
31+
if lang == 'python': return pylang_analyses()
32+
if lang == 'java' : return javalang_analyses()
33+
34+
raise NotImplementedError("Language %s is not supported" % lang)
3435

3536

3637
# Helper methods --------------------------------

code_graph/javalang/__init__.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
from ..ast import ASTRelationVisitor
2+
3+
from .cfg import ControlFlowVisitor
4+
from .dataflow import DataFlowVisitor
5+
6+
def javalang_analyses():
7+
return {
8+
"ast": ASTRelationVisitor,
9+
"cfg": ControlFlowVisitor,
10+
"dataflow": DataFlowVisitor,
11+
}

code_graph/javalang/cfg.py

Lines changed: 203 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,203 @@
1+
from ..visitor import ASTVisitor
2+
3+
from collections import defaultdict
4+
5+
class ControlFlowVisitor(ASTVisitor):
6+
7+
def __init__(self, graph):
8+
super().__init__()
9+
self.graph = graph
10+
self._last_stmts = tuple()
11+
12+
self._returns_from = []
13+
self._continue_from = defaultdict(list)
14+
self._break_from = defaultdict(list)
15+
16+
def _add_next(self, stmt_node):
17+
for last_stmt_node in self._last_stmts:
18+
self.graph.add_relation(last_stmt_node, stmt_node, "controlflow")
19+
self._last_stmts = (stmt_node,)
20+
21+
def _reset_last_stmts(self, reset_target):
22+
last_stmts = self._last_stmts
23+
self._last_stmts = (reset_target,)
24+
return last_stmts
25+
26+
def visit_block(self, node):
27+
28+
for stmt in node.children:
29+
self.walk(stmt)
30+
31+
return False
32+
33+
# Methods --------------------------------------------------------
34+
35+
def visit_method_declaration(self, node):
36+
outside_last, self._last_stmts = self._last_stmts, (node,)
37+
outside_returns = self._returns_from
38+
self._returns_from = []
39+
40+
self.walk(
41+
node.child_by_field_name("body")
42+
)
43+
44+
for stmt in self._last_stmts:
45+
self.graph.add_relation(stmt, node, "return_from")
46+
47+
for stmt in self._returns_from:
48+
self.graph.add_relation(stmt, node, "return_from")
49+
50+
self._returns_from = outside_returns
51+
self._last_stmts = outside_last
52+
return False
53+
54+
def visit_return_statement(self, node):
55+
self._add_next(node)
56+
self._returns_from.append(node)
57+
self._last_stmts = tuple()
58+
return False
59+
60+
# Labeled statements --------------------------------
61+
62+
def visit_labeled_statement(self, node):
63+
name_node, _, body = node.children
64+
name = self.graph.add_or_get_node(name_node) # has to be a token
65+
name = name.token.text
66+
67+
self.walk(body)
68+
69+
current_last = self._last_stmts
70+
self._last_stmts = tuple(self._continue_from[name])
71+
self._add_next(body)
72+
self._continue_from[name] = []
73+
74+
self._last_stmts = current_last + tuple(self._break_from[name])
75+
self._break_from[name] = []
76+
return False
77+
78+
def visit_break_statement(self, node):
79+
self._add_next(node)
80+
81+
jump_label = "__LOOP__"
82+
if node.child_count > 2:
83+
name_node = node.children[1]
84+
name_token = self.graph.add_or_get_node(name_node)
85+
jump_label = name_token.token.text
86+
87+
self._break_from[jump_label].append(node)
88+
self._last_stmts = tuple()
89+
return False
90+
91+
def visit_continue_statement(self, node):
92+
self._add_next(node)
93+
94+
jump_label = "__LOOP__"
95+
if node.child_count > 2:
96+
name_node = node.children[1]
97+
name_token = self.graph.add_or_get_node(name_node)
98+
jump_label = name_token.token.text
99+
100+
self._continue_from[jump_label].append(node)
101+
self._last_stmts = tuple()
102+
return False
103+
104+
# Control structures --------------------------------
105+
106+
def visit_if_statement(self, node):
107+
self._add_next(node)
108+
109+
# Consequences
110+
self.walk(node.child_by_field_name("consequence"))
111+
left_last_stmts = self._reset_last_stmts(node)
112+
113+
# Alternative
114+
self.walk(node.child_by_field_name("alternative"))
115+
right_last_stmts = self._reset_last_stmts(node)
116+
117+
self._last_stmts = left_last_stmts + right_last_stmts
118+
return False
119+
120+
def visit_for_statement(self, node):
121+
122+
jump_label = "__LOOP__"
123+
prev_break, prev_continue = self._break_from[jump_label], self._continue_from[jump_label]
124+
self._break_from[jump_label], self._continue_from[jump_label] = [], []
125+
126+
self._add_next(node)
127+
self.walk(node.child_by_field_name("body"))
128+
self._last_stmts += tuple(self._continue_from[jump_label])
129+
self._add_next(node)
130+
131+
self._last_stmts += tuple(self._break_from[jump_label])
132+
133+
self._break_from[jump_label], self._continue_from[jump_label] = prev_break, prev_continue
134+
return False
135+
136+
def visit_while_statement(self, node):
137+
138+
jump_label = "__LOOP__"
139+
prev_break, prev_continue = self._break_from[jump_label], self._continue_from[jump_label]
140+
self._break_from[jump_label], self._continue_from[jump_label] = [], []
141+
142+
self._add_next(node)
143+
self.walk(node.child_by_field_name("body"))
144+
self._last_stmts += tuple(self._continue_from[jump_label])
145+
self._add_next(node)
146+
147+
self._last_stmts += tuple(self._break_from[jump_label])
148+
149+
self._break_from[jump_label], self._continue_from[jump_label] = prev_break, prev_continue
150+
return False
151+
152+
def visit_do_statement(self, node):
153+
154+
jump_label = "__LOOP__"
155+
prev_break, prev_continue = self._break_from[jump_label], self._continue_from[jump_label]
156+
self._break_from[jump_label], self._continue_from[jump_label] = [], []
157+
158+
self._add_next(node)
159+
self.walk(node.child_by_field_name("body"))
160+
self._last_stmts += tuple(self._continue_from[jump_label])
161+
self._add_next(node)
162+
163+
self._last_stmts += tuple(self._break_from[jump_label])
164+
165+
self._break_from[jump_label], self._continue_from[jump_label] = prev_break, prev_continue
166+
return False
167+
168+
169+
def visit_try_statement(self, node):
170+
self._add_next(node)
171+
starting_stmt = self._last_stmts
172+
173+
self.walk(node.child_by_field_name("body"))
174+
175+
exception_starting_stmts = self._last_stmts + starting_stmt
176+
self._last_stmts = exception_starting_stmts
177+
out_last_stmts = tuple()
178+
179+
finally_clauses = []
180+
for possible_exception in node.children:
181+
if possible_exception.type == "catch_clause":
182+
self.walk(possible_exception)
183+
out_last_stmts += self._last_stmts
184+
self._last_stmts = exception_starting_stmts
185+
if possible_exception.type == "finally_clause":
186+
finally_clauses.append(possible_exception)
187+
188+
self._last_stmts += out_last_stmts
189+
190+
for finally_clause in finally_clauses:
191+
self.walk(finally_clause)
192+
193+
return False
194+
195+
196+
def visit(self, node):
197+
# All statement node type end with statement
198+
# in tree-sitter
199+
# Therefore, we can savely do this hack.
200+
if node.type.endswith("statement"):
201+
self._add_next(node)
202+
return False
203+
return True

0 commit comments

Comments
 (0)