Skip to content

Commit ee97e87

Browse files
committed
More cleanup, don't add jumps
1 parent d2c9ae9 commit ee97e87

File tree

2 files changed

+79
-171
lines changed

2 files changed

+79
-171
lines changed

Tools/jit/_optimizers.py

Lines changed: 61 additions & 139 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import re
55
import typing
66

7-
_RE_NEVER_MATCH = re.compile(r"(?!)")
7+
_RE_NEVER_MATCH = re.compile(r"(?!)") # Same as saying "not string.startswith('')".
88

99
_X86_BRANCHES = {
1010
"ja": "jna",
@@ -36,7 +36,7 @@
3636

3737
@dataclasses.dataclass
3838
class _Block:
39-
label: str
39+
label: str | None = None
4040
noise: list[str] = dataclasses.field(default_factory=list)
4141
instructions: list[str] = dataclasses.field(default_factory=list)
4242
target: typing.Self | None = None
@@ -71,94 +71,61 @@ class Optimizer:
7171
)
7272
_re_jump: typing.ClassVar[re.Pattern[str]] = _RE_NEVER_MATCH # One group: target.
7373
_re_label: typing.ClassVar[re.Pattern[str]] = re.compile(
74-
r'\s*(?P<label>[\w"$.?@]+):'
75-
) # One group: label.
74+
r"\s*(?P<label>[\w.]+):" # One group: label.
75+
)
7676
_re_noise: typing.ClassVar[re.Pattern[str]] = re.compile(
77-
r"\s*(?:[#.]|$)"
78-
) # No groups.
77+
r"\s*(?:[#.]|$)" # No groups.
78+
)
7979
_re_return: typing.ClassVar[re.Pattern[str]] = _RE_NEVER_MATCH # No groups.
8080

8181
def __post_init__(self) -> None:
8282
text = self._preprocess(self.path.read_text())
83-
self._graph = block = self._new_block()
83+
self._graph = block = _Block()
8484
for line in text.splitlines():
85-
if label := self._parse_label(line):
86-
block.link = block = self._lookup_label(label)
85+
if match := self._re_label.match(line):
86+
block.link = block = self._lookup_label(match["label"])
8787
block.noise.append(line)
8888
continue
89-
elif block.target or not block.fallthrough:
90-
block.link = block = self._new_block()
91-
if self._parse_noise(line):
89+
if self._re_noise.match(line):
9290
if block.instructions:
93-
block.link = block = self._new_block()
91+
block.link = block = _Block()
9492
block.noise.append(line)
9593
continue
94+
if block.target or not block.fallthrough:
95+
block.link = block = _Block()
9696
block.instructions.append(line)
97-
if target := self._parse_branch(line):
98-
block.target = self._lookup_label(target)
97+
if match := self._re_branch.match(line):
98+
block.target = self._lookup_label(match["target"])
9999
assert block.fallthrough
100-
elif target := self._parse_jump(line):
101-
block.target = self._lookup_label(target)
100+
elif match := self._re_jump.match(line):
101+
block.target = self._lookup_label(match["target"])
102102
block.fallthrough = False
103-
elif self._parse_return(line):
103+
elif self._re_return.match(line):
104104
assert not block.target
105105
block.fallthrough = False
106106

107-
@classmethod
108-
def _parse_branch(cls, line: str) -> str | None:
109-
branch = cls._re_branch.match(line)
110-
return branch and branch["target"]
111-
112-
@classmethod
113-
def _parse_jump(cls, line: str) -> str | None:
114-
jump = cls._re_jump.match(line)
115-
return jump and jump["target"]
116-
117-
@classmethod
118-
def _parse_label(cls, line: str) -> str | None:
119-
label = cls._re_label.match(line)
120-
return label and label["label"]
121-
122-
@classmethod
123-
def _parse_noise(cls, line: str) -> bool:
124-
return cls._re_noise.match(line) is not None
125-
126-
@classmethod
127-
def _parse_return(cls, line: str) -> bool:
128-
return cls._re_return.match(line) is not None
129-
130107
@classmethod
131108
def _invert_branch(cls, line: str, target: str) -> str | None:
132-
branch = cls._re_branch.match(line)
133-
assert branch
134-
inverted = cls._branches.get(branch["instruction"])
109+
match = cls._re_branch.match(line)
110+
assert match
111+
inverted = cls._branches.get(match["instruction"])
135112
if not inverted:
136113
return None
137-
(a, b), (c, d) = branch.span("instruction"), branch.span("target")
114+
(a, b), (c, d) = match.span("instruction"), match.span("target")
138115
return "".join([line[:a], inverted, line[b:c], target, line[d:]])
139116

140117
@classmethod
141-
def _thread_jump(cls, line: str, target: str) -> str:
142-
jump = cls._re_jump.match(line) or cls._re_branch.match(line)
143-
assert jump
144-
a, b = jump.span("target")
118+
def _update_jump(cls, line: str, target: str) -> str:
119+
match = cls._re_jump.match(line)
120+
assert match
121+
a, b = match.span("target")
145122
return "".join([line[:a], target, line[b:]])
146123

147-
@staticmethod
148-
def _create_jump(target: str) -> str | None:
149-
return None
150-
151124
def _lookup_label(self, label: str) -> _Block:
152125
if label not in self._labels:
153126
self._labels[label] = _Block(label)
154127
return self._labels[label]
155128

156-
def _new_block(self) -> _Block:
157-
label = f"{self.prefix}_JIT_LABEL_{len(self._labels)}"
158-
block = self._lookup_label(label)
159-
block.noise.append(f"{label}:")
160-
return block
161-
162129
@staticmethod
163130
def _preprocess(text: str) -> str:
164131
return text
@@ -169,28 +136,31 @@ def _blocks(self) -> typing.Generator[_Block, None, None]:
169136
yield block
170137
block = block.link
171138

172-
def _lines(self) -> typing.Generator[str, None, None]:
139+
def _body(self) -> str:
140+
lines = []
173141
for block in self._blocks():
174-
yield from block.noise
175-
yield from block.instructions
142+
lines.extend(block.noise)
143+
lines.extend(block.instructions)
144+
return "\n".join(lines)
176145

177146
def _insert_continue_label(self) -> None:
178147
for end in reversed(list(self._blocks())):
179148
if end.instructions:
180149
break
181-
align = self._new_block()
150+
align = _Block()
182151
align.noise.append(f"\t.balign\t{self._alignment}")
183152
continuation = self._lookup_label(f"{self.prefix}_JIT_CONTINUE")
153+
assert continuation.label
184154
continuation.noise.append(f"{continuation.label}:")
185155
end.link, align.link, continuation.link = align, continuation, end.link
186156

187157
def _mark_hot_blocks(self) -> None:
188-
predecessors = collections.defaultdict(set)
158+
predecessors = collections.defaultdict(list)
189159
for block in self._blocks():
190160
if block.target:
191-
predecessors[block.target].add(block)
161+
predecessors[block.target].append(block)
192162
if block.fallthrough and block.link:
193-
predecessors[block.link].add(block)
163+
predecessors[block.link].append(block)
194164
todo = [self._lookup_label(f"{self.prefix}_JIT_CONTINUE")]
195165
while todo:
196166
block = todo.pop()
@@ -202,64 +172,33 @@ def _mark_hot_blocks(self) -> None:
202172
)
203173

204174
def _invert_hot_branches(self) -> None:
205-
for block in self._blocks():
175+
for branch in self._blocks():
176+
jump = branch.link
206177
if (
207-
block.fallthrough
208-
and block.target
209-
and block.target.hot
210-
and block.link
211-
and not block.link.hot
178+
# block ends with a branch to hot code...
179+
branch.target
180+
and branch.fallthrough
181+
and branch.target.hot
182+
# ...followed by a jump to cold code:
183+
and jump
184+
and jump.target
185+
and not jump.fallthrough
186+
and not jump.target.hot
187+
and len(jump.instructions) == 1
212188
):
213-
label_block = self._new_block()
214-
branch = block.instructions[-1]
215-
inverted = self._invert_branch(branch, label_block.label)
216-
jump = self._create_jump(block.target.label)
217-
if inverted is None or jump is None:
189+
assert jump.target.label
190+
assert branch.target.label
191+
inverted = self._invert_branch(
192+
branch.instructions[-1], jump.target.label
193+
)
194+
if inverted is None:
218195
continue
219-
jump_block = self._new_block()
220-
jump_block.instructions.append(jump)
221-
jump_block.target = block.target
222-
jump_block.fallthrough = False
223-
block.instructions[-1] = inverted
224-
block.target = label_block
225-
block.link, jump_block.link, label_block.link = (
226-
jump_block,
227-
label_block,
228-
block.link,
196+
branch.instructions[-1] = inverted
197+
jump.instructions[-1] = self._update_jump(
198+
jump.instructions[-1], branch.target.label
229199
)
230-
231-
def _thread_jumps(self) -> None:
232-
for block in self._blocks():
233-
while block.target:
234-
target = block.target.resolve()
235-
if (
236-
not target.fallthrough
237-
and target.target
238-
and len(target.instructions) == 1
239-
):
240-
jump = block.instructions[-1]
241-
block.instructions[-1] = self._thread_jump(
242-
jump, target.target.label
243-
)
244-
block.target = target.target
245-
else:
246-
break
247-
248-
def _remove_dead_code(self) -> None:
249-
reachable = set()
250-
todo = [self._graph]
251-
while todo:
252-
block = todo.pop()
253-
reachable.add(block)
254-
if block.target and block.target not in reachable:
255-
todo.append(block.target)
256-
if block.fallthrough and block.link and block.link not in reachable:
257-
todo.append(block.link)
258-
for block in self._blocks():
259-
if block not in reachable:
260-
block.target = None
261-
block.fallthrough = True
262-
block.instructions.clear()
200+
branch.target, jump.target = jump.target, branch.target
201+
jump.hot = True
263202

264203
def _remove_redundant_jumps(self) -> None:
265204
for block in self._blocks():
@@ -272,26 +211,12 @@ def _remove_redundant_jumps(self) -> None:
272211
block.fallthrough = True
273212
block.instructions.pop()
274213

275-
def _remove_unused_labels(self) -> None:
276-
used = set()
277-
for block in self._blocks():
278-
if block.target:
279-
used.add(block.target)
280-
for block in self._blocks():
281-
if block not in used and block.label.startswith(
282-
f"{self.prefix}_JIT_LABEL_"
283-
):
284-
del block.noise[0]
285-
286214
def run(self) -> None:
287215
self._insert_continue_label()
288216
self._mark_hot_blocks()
289217
self._invert_hot_branches()
290-
self._thread_jumps()
291-
# self._remove_dead_code() # XXX: Need calls to reason about this!
292218
self._remove_redundant_jumps()
293-
self._remove_unused_labels()
294-
self.path.write_text("\n".join(self._lines()))
219+
self.path.write_text(self._body())
295220

296221

297222
class OptimizerAArch64(Optimizer):
@@ -308,15 +233,12 @@ class OptimizerX86(Optimizer):
308233
_re_jump = re.compile(r"\s*jmp\s+(?P<target>[\w.]+)")
309234
_re_return = re.compile(r"\s*ret\b")
310235

311-
@staticmethod
312-
def _create_jump(target: str) -> str:
313-
return f"\tjmp\t{target}"
314-
315236

316237
class OptimizerX86Windows(OptimizerX86):
317238

318239
def _preprocess(self, text: str) -> str:
319240
text = super()._preprocess(text)
241+
# rex64 jumpq *__imp__JIT_CONTINUE(%rip) -> jmp _JIT_CONTINUE
320242
far_indirect_jump = (
321243
rf"rex64\s+jmpq\s+\*__imp_(?P<target>{self.prefix}_JIT_\w+)\(%rip\)"
322244
)

Tools/jit/_targets.py

Lines changed: 18 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -531,53 +531,39 @@ def get_target(host: str) -> _COFF | _ELF | _MachO:
531531
target: _COFF | _ELF | _MachO
532532
if re.fullmatch(r"aarch64-apple-darwin.*", host):
533533
condition = "defined(__aarch64__) && defined(__APPLE__)"
534-
target = _MachO(
535-
host, condition, optimizer=_optimizers.OptimizerAArch64, prefix="_"
536-
)
534+
optimizer = _optimizers.OptimizerAArch64
535+
target = _MachO(host, condition, optimizer=optimizer, prefix="_")
537536
elif re.fullmatch(r"aarch64-pc-windows-msvc", host):
538537
args = ["-fms-runtime-lib=dll", "-fplt"]
539538
condition = "defined(_M_ARM64)"
540-
target = _COFF(
541-
host, condition, args=args, optimizer=_optimizers.OptimizerAArch64
542-
)
539+
optimizer = _optimizers.OptimizerAArch64
540+
target = _COFF(host, condition, args=args, optimizer=optimizer)
543541
elif re.fullmatch(r"aarch64-.*-linux-gnu", host):
544-
args = [
545-
"-fpic",
546-
# On aarch64 Linux, intrinsics were being emitted and this flag
547-
# was required to disable them.
548-
"-mno-outline-atomics",
549-
]
542+
# -mno-outline-atomics: Keep intrinsics from being emitted.
543+
args = ["-fpic", "-mno-outline-atomics"]
550544
condition = "defined(__aarch64__) && defined(__linux__)"
551-
target = _ELF(
552-
host, condition, args=args, optimizer=_optimizers.OptimizerAArch64
553-
)
545+
optimizer = _optimizers.OptimizerAArch64
546+
target = _ELF(host, condition, args=args, optimizer=optimizer)
554547
elif re.fullmatch(r"i686-pc-windows-msvc", host):
555-
args = [
556-
"-DPy_NO_ENABLE_SHARED",
557-
# __attribute__((preserve_none)) is not supported
558-
"-Wno-ignored-attributes",
559-
]
548+
# -Wno-ignored-attributes: __attribute__((preserve_none)) is not supported here.
549+
args = ["-DPy_NO_ENABLE_SHARED", "-Wno-ignored-attributes"]
550+
optimizer = _optimizers.OptimizerX86Windows
560551
condition = "defined(_M_IX86)"
561-
target = _COFF(
562-
host,
563-
condition,
564-
args=args,
565-
optimizer=_optimizers.OptimizerX86Windows,
566-
prefix="_",
567-
)
552+
target = _COFF(host, condition, args=args, optimizer=optimizer, prefix="_")
568553
elif re.fullmatch(r"x86_64-apple-darwin.*", host):
569554
condition = "defined(__x86_64__) && defined(__APPLE__)"
570-
target = _MachO(host, condition, optimizer=_optimizers.OptimizerX86, prefix="_")
555+
optimizer = _optimizers.OptimizerX86
556+
target = _MachO(host, condition, optimizer=optimizer, prefix="_")
571557
elif re.fullmatch(r"x86_64-pc-windows-msvc", host):
572558
args = ["-fms-runtime-lib=dll"]
573559
condition = "defined(_M_X64)"
574-
target = _COFF(
575-
host, condition, args=args, optimizer=_optimizers.OptimizerX86Windows
576-
)
560+
optimizer = _optimizers.OptimizerX86Windows
561+
target = _COFF(host, condition, args=args, optimizer=optimizer)
577562
elif re.fullmatch(r"x86_64-.*-linux-gnu", host):
578563
args = ["-fno-pic", "-mcmodel=medium", "-mlarge-data-threshold=0"]
579564
condition = "defined(__x86_64__) && defined(__linux__)"
580-
target = _ELF(host, condition, args=args, optimizer=_optimizers.OptimizerX86)
565+
optimizer = _optimizers.OptimizerX86
566+
target = _ELF(host, condition, args=args, optimizer=optimizer)
581567
else:
582568
raise ValueError(host)
583569
return target

0 commit comments

Comments
 (0)