Skip to content

Commit da72079

Browse files
committed
Rework optimizer and break it out into its own module
1 parent 77886d0 commit da72079

File tree

2 files changed

+309
-255
lines changed

2 files changed

+309
-255
lines changed

Tools/jit/_optimizers.py

Lines changed: 292 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,292 @@
1+
import collections
2+
import dataclasses
3+
import pathlib
4+
import re
5+
import typing
6+
7+
branches = {}
8+
for op, nop in [
9+
("ja", "jna"),
10+
("jae", "jnae"),
11+
("jb", "jnb"),
12+
("jbe", "jnbe"),
13+
("jc", "jnc"),
14+
("jcxz", None),
15+
("je", "jne"),
16+
("jecxz", None),
17+
("jg", "jng"),
18+
("jge", "jnge"),
19+
("jl", "jnl"),
20+
("jle", "jnle"),
21+
("jo", "jno"),
22+
("jp", "jnp"),
23+
("jpe", "jpo"),
24+
("jrxz", None),
25+
("js", "jns"),
26+
("jz", "jnz"),
27+
("loop", None),
28+
("loope", None),
29+
("loopne", None),
30+
("loopnz", None),
31+
("loopz", None),
32+
]:
33+
branches[op] = nop
34+
if nop:
35+
branches[nop] = op
36+
37+
38+
def _get_branch(line: str) -> str | None:
    """Return the target label of a conditional branch.

    *line* is matched from its start (after optional whitespace) against
    every mnemonic in ``branches`` followed by a label operand. Returns
    None when the line is not such a branch.
    """
    pattern = rf"\s*({'|'.join(branches)})\s+([\w\.]+)"
    found = re.match(pattern, line)
    if found is None:
        return None
    # Group 1 is the mnemonic, group 2 is the label being branched to.
    return found[2]
41+
42+
43+
def _invert_branch(line: str, label: str) -> str | None:
    """Rewrite a conditional branch to test the opposite condition.

    *line* must be a branch instruction whose mnemonic is a key of
    ``branches`` (this is asserted). The mnemonic is replaced by its
    inverse and the target operand by *label*; everything else on the line
    (leading whitespace, separators, trailing text) is kept as is.

    Returns the rewritten line, or None if the mnemonic has no inverse.
    """
    branch = re.match(rf"\s*({'|'.join(branches)})\s+([\w\.]+)", line)
    assert branch
    inverted = branches.get(branch[1])
    if not inverted:
        return None
    # Splice by match position instead of using str.replace(): a textual
    # replace of the target can hit the freshly written mnemonic first
    # (e.g. target "n" inside "jne"), corrupting the instruction.
    return "".join(
        (
            line[: branch.start(1)],
            inverted,
            line[branch.end(1) : branch.start(2)],
            label,
            line[branch.end(2) :],
        )
    )
52+
53+
54+
def _get_jump(line: str) -> str | None:
55+
jump = re.match(r"\s*(?:rex64\s+)?jmpq?\s+\*?([\w\.]+)", line)
56+
return jump and jump[1]
57+
58+
59+
def _get_label(line: str) -> str | None:
60+
label = re.match(r"\s*([\w\.]+):", line)
61+
return label and label[1]
62+
63+
64+
def _is_return(line: str) -> bool:
65+
return re.match(r"\s*ret\s+", line) is not None
66+
67+
68+
def _is_noise(line: str) -> bool:
69+
return re.match(r"\s*([#\.]|$)", line) is not None
70+
71+
72+
@dataclasses.dataclass
73+
class _Block:
74+
label: str
75+
noise: list[str] = dataclasses.field(default_factory=list)
76+
instructions: list[str] = dataclasses.field(default_factory=list)
77+
target: typing.Self | None = None
78+
link: typing.Self | None = None
79+
fallthrough: bool = True
80+
hot: bool = False
81+
82+
def __eq__(self, other: object) -> bool:
83+
return self is other
84+
85+
def __hash__(self) -> int:
86+
return super().__hash__()
87+
88+
def resolve(self) -> typing.Self:
89+
while self.link and not self.instructions:
90+
self = self.link
91+
return self
92+
93+
94+
class _Labels(dict):
    """Mapping from label name to ``_Block`` that creates blocks on demand."""

    def __missing__(self, key: str) -> _Block:
        # Looking up an unknown label creates, stores and returns an empty
        # block named after it, so forward references work naturally.
        created = _Block(key)
        self[key] = created
        return created
98+
99+
100+
@dataclasses.dataclass
class Optimizer:
    """Control-flow optimizer for a single assembly file.

    Constructing an instance reads ``path`` and splits it into a chain of
    ``_Block`` objects (linked through ``_Block.link`` in file order).
    ``run()`` then applies a fixed sequence of passes to that chain and
    overwrites ``path`` with the result.
    """

    # File that is read on construction and overwritten by run().
    path: pathlib.Path
    _: dataclasses.KW_ONLY
    # Prepended to the "_JIT_..." label names this class looks up or creates.
    prefix: str = ""
    # First block of the chain; the rest are reached by following .link.
    _graph: _Block = dataclasses.field(init=False)
    # Label name -> block. Unknown names get a fresh block on first lookup.
    _labels: _Labels = dataclasses.field(init=False, default_factory=_Labels)

    _re_branch: typing.ClassVar[re.Pattern[str]]  # Two groups: instruction and target.
    _re_jump: typing.ClassVar[re.Pattern[str]]  # One group: target.
    _re_return: typing.ClassVar[re.Pattern[str]]  # No groups.

    def __post_init__(self) -> None:
        """Read ``path`` and build the block chain from its lines."""
        text = self._preprocess(self.path.read_text())
        self._graph = block = self._new_block()
        for line in text.splitlines():
            if label := _get_label(line):
                # A label line always starts the block carrying its name.
                block.link = block = self._labels[label]
            elif block.target or not block.fallthrough:
                # The previous instruction was a branch, jump or return, so
                # whatever follows belongs to a new block.
                block.link = block = self._new_block()
            if _is_noise(line) or _get_label(line):
                if block.instructions:
                    # Blocks are emitted as "noise, then instructions", so
                    # noise after instructions has to open a new block.
                    block.link = block = self._new_block()
                block.noise.append(line)
                continue
            block.instructions.append(line)
            if target := _get_branch(line):
                block.target = self._labels[target]
                assert block.fallthrough
            elif target := _get_jump(line):
                block.target = self._labels[target]
                block.fallthrough = False
            elif _is_return(line):
                assert not block.target
                block.fallthrough = False

    def _new_block(self, label: str | None = None) -> _Block:
        """Create and register a block whose noise starts with its label line.

        Without *label*, a name of the form ``{prefix}_JIT_LABEL_{n}`` is
        generated from the current number of known labels. The label must
        not be registered yet (this is asserted).
        """
        if not label:
            label = f"{self.prefix}_JIT_LABEL_{len(self._labels)}"
        assert label not in self._labels, label
        block = self._labels[label] = _Block(label, [f"{label}:"])
        return block

    def _preprocess(self, text: str) -> str:
        """Hook for subclasses to rewrite the raw text before it is parsed."""
        return text

    def _blocks(self) -> typing.Generator[_Block, None, None]:
        """Yield every block of the chain, in link order."""
        block = self._graph
        while block:
            yield block
            block = block.link

    def _lines(self) -> typing.Generator[str, None, None]:
        """Yield the output lines: each block's noise, then its instructions."""
        for block in self._blocks():
            yield from block.noise
            yield from block.instructions

    def _insert_continue_label(self) -> None:
        """Link the ``{prefix}_JIT_CONTINUE`` block after the last instructions."""
        # Find the last block that has instructions. If none has any, the
        # loop runs to completion and `end` is the first block of the chain.
        for end in reversed(list(self._blocks())):
            if end.instructions:
                break
        continuation = self._labels[f"{self.prefix}_JIT_CONTINUE"]
        continuation.noise.append(f"{continuation.label}:")
        end.link, continuation.link = continuation, end.link

    def _mark_hot_blocks(self) -> None:
        """Mark as hot every block from which the continue label is reachable."""
        predecessors = collections.defaultdict(set)
        for block in self._blocks():
            if block.target:
                predecessors[block.target].add(block)
            if block.fallthrough and block.link:
                predecessors[block.link].add(block)
        # Walk the control-flow edges backwards, starting at the continuation.
        todo = [self._labels[f"{self.prefix}_JIT_CONTINUE"]]
        while todo:
            block = todo.pop()
            block.hot = True
            todo.extend(
                predecessor
                for predecessor in predecessors[block]
                if not predecessor.hot
            )

    def _invert_hot_branches(self) -> None:
        """Make the hot successor of a branch the one that is not branched to."""
        for block in self._blocks():
            if (
                block.fallthrough
                and block.target
                and block.link
                and block.target.hot
                and not block.link.hot
            ):
                # Turn...
                #     branch hot
                # ...into..
                #     opposite-branch ._JIT_LABEL_N
                #     jmp hot
                #     ._JIT_LABEL_N:
                label_block = self._new_block()
                inverted = _invert_branch(block.instructions[-1], label_block.label)
                if inverted is None:
                    # No opposite mnemonic exists; leave the branch alone.
                    continue
                jump_block = self._new_block()
                jump_block.instructions.append(f"\tjmp\t{block.target.label}")
                jump_block.target = block.target
                jump_block.fallthrough = False
                block.instructions[-1] = inverted
                block.target = label_block
                # Splice: block -> jump_block -> label_block -> old link.
                label_block.link = block.link
                jump_block.link = label_block
                block.link = jump_block

    def _thread_jumps(self) -> None:
        """Retarget branches and jumps that land on a lone unconditional jump."""
        for block in self._blocks():
            # Trampolines already threaded through for this block. Without
            # this, a cycle of lone jumps (the simplest being "L: jmp L")
            # would keep the loop below spinning forever.
            seen: set[_Block] = set()
            while block.target:
                label = block.target.label
                target = block.target.resolve()
                if (
                    not target.fallthrough
                    and target.target
                    and len(target.instructions) == 1
                    and target not in seen
                ):
                    seen.add(target)
                    block.instructions[-1] = block.instructions[-1].replace(
                        label, target.target.label
                    )  # XXX
                    block.target = target.target
                else:
                    break

    def _remove_dead_code(self) -> None:
        """Drop the instructions of blocks unreachable from the first block."""
        reachable = set()
        todo = [self._graph]
        while todo:
            block = todo.pop()
            reachable.add(block)
            if block.target and block.target not in reachable:
                todo.append(block.target)
            if block.fallthrough and block.link and block.link not in reachable:
                todo.append(block.link)
        for block in self._blocks():
            if block not in reachable:
                # Only the instructions go; noise, target and link stay.
                block.instructions.clear()

    def _remove_redundant_jumps(self) -> None:
        """Delete branches and jumps whose target is what follows anyway."""
        for block in self._blocks():
            if (
                block.target
                # Blocks emptied by _remove_dead_code() still have a target
                # but no instruction left to pop.
                and block.instructions
                and block.link
                and block.target.resolve() is block.link.resolve()
            ):
                block.target = None
                block.fallthrough = True
                block.instructions.pop()

    def _remove_unused_labels(self) -> None:
        """Drop generated ``_JIT_LABEL_`` label lines that nothing targets."""
        used = set()
        for block in self._blocks():
            if block.target:
                used.add(block.target)
        for block in self._blocks():
            if block not in used and block.label.startswith(
                f"{self.prefix}_JIT_LABEL_"
            ):
                # _new_block() puts the label line first in the noise list.
                del block.noise[0]

    def run(self) -> None:
        """Apply all passes in order and overwrite ``path`` with the result."""
        self._insert_continue_label()
        self._mark_hot_blocks()
        self._invert_hot_branches()
        self._thread_jumps()
        self._remove_dead_code()
        self._remove_redundant_jumps()
        self._remove_unused_labels()
        self.path.write_text("\n".join(self._lines()))
274+
275+
276+
class OptimizerX86(Optimizer):
    """Optimizer configured with the x86 instruction patterns."""

    # Conditional branch: any mnemonic listed in `branches`, then a label.
    _re_branch = re.compile(
        r"\s*"
        rf"(?P<instruction>{'|'.join(branches)})"
        r"\s+"
        r"(?P<target>[\w\.]+)"
    )
    # Unconditional jump to a label.
    _re_jump = re.compile(
        r"\s*jmp\s+(?P<target>[\w\.]+)",
    )
    # Return instruction.
    _re_return = re.compile(
        r"\s*ret\b",
    )
283+
284+
285+
class OptimizerX86Windows(OptimizerX86):
    """x86 optimizer that first rewrites import-table jumps to ``_JIT_`` symbols."""

    def _preprocess(self, text: str) -> str:
        """Replace ``rex64 jmpq *__imp_<sym>(%rip)`` with ``jmp <sym>``.

        Only symbols of the form ``{prefix}_JIT_...`` are rewritten; all
        other text is returned unchanged.
        """
        text = super()._preprocess(text)
        # The prefix is literal symbol text, so escape it before embedding
        # it in the pattern; otherwise a regex metacharacter in the prefix
        # would change what is matched.
        far_indirect_jump = (
            rf"rex64\s+jmpq\s+\*__imp_(?P<target>{re.escape(self.prefix)}_JIT_\w+)\(%rip\)"
        )
        return re.sub(far_indirect_jump, r"jmp\t\g<target>", text)

0 commit comments

Comments
 (0)