Skip to content

Commit 5606078

Browse files
committed
Do some textual assembly magic
1 parent 0f866cb commit 5606078

File tree

2 files changed

+250
-66
lines changed

2 files changed

+250
-66
lines changed

Tools/jit/_stencils.py

Lines changed: 17 additions & 62 deletions
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,6 @@ class HoleValue(enum.Enum):
1717

1818
# The base address of the machine code for the current uop (exposed as _JIT_ENTRY):
1919
CODE = enum.auto()
20-
# The base address of the machine code for the next uop (exposed as _JIT_CONTINUE):
21-
CONTINUE = enum.auto()
2220
# The base address of the read-only data for this uop:
2321
DATA = enum.auto()
2422
# The address of the current executor (exposed as _JIT_EXECUTOR):
@@ -97,7 +95,6 @@ class HoleValue(enum.Enum):
9795
# Translate HoleValues to C expressions:
9896
_HOLE_EXPRS = {
9997
HoleValue.CODE: "(uintptr_t)code",
100-
HoleValue.CONTINUE: "(uintptr_t)code + sizeof(code_body)",
10198
HoleValue.DATA: "(uintptr_t)data",
10299
HoleValue.EXECUTOR: "(uintptr_t)executor",
103100
# These should all have been turned into DATA values by process_relocations:
@@ -209,63 +206,22 @@ def pad(self, alignment: int) -> None:
209206
self.disassembly.append(f"{offset:x}: {' '.join(['00'] * padding)}")
210207
self.body.extend([0] * padding)
211208

212-
def add_nops(self, nop: bytes, alignment: int) -> None:
213-
"""Add NOPs until there is alignment. Fail if it is not possible."""
214-
offset = len(self.body)
215-
nop_size = len(nop)
216-
217-
# Calculate the gap to the next multiple of alignment.
218-
gap = -offset % alignment
219-
if gap:
220-
if gap % nop_size == 0:
221-
count = gap // nop_size
222-
self.body.extend(nop * count)
223-
else:
224-
raise ValueError(
225-
f"Cannot add nops of size '{nop_size}' to a body with "
226-
f"offset '{offset}' to align with '{alignment}'"
227-
)
228-
229-
def remove_jump(self) -> None:
230-
"""Remove a zero-length continuation jump, if it exists."""
231-
hole = max(self.holes, key=lambda hole: hole.offset)
232-
match hole:
233-
case Hole(
234-
offset=offset,
235-
kind="IMAGE_REL_AMD64_REL32",
236-
value=HoleValue.GOT,
237-
symbol="_JIT_CONTINUE",
238-
addend=-4,
239-
) as hole:
240-
# jmp qword ptr [rip]
241-
jump = b"\x48\xff\x25\x00\x00\x00\x00"
242-
offset -= 3
243-
case Hole(
244-
offset=offset,
245-
kind="IMAGE_REL_I386_REL32" | "R_X86_64_PLT32" | "X86_64_RELOC_BRANCH",
246-
value=HoleValue.CONTINUE,
247-
symbol=None,
248-
addend=addend,
249-
) as hole if (
250-
_signed(addend) == -4
251-
):
252-
# jmp 5
253-
jump = b"\xe9\x00\x00\x00\x00"
254-
offset -= 1
255-
case Hole(
256-
offset=offset,
257-
kind="R_AARCH64_JUMP26",
258-
value=HoleValue.CONTINUE,
259-
symbol=None,
260-
addend=0,
261-
) as hole:
262-
# b #4
263-
jump = b"\x00\x00\x00\x14"
264-
case _:
265-
return
266-
if self.body[offset:] == jump:
267-
self.body = self.body[:offset]
268-
self.holes.remove(hole)
209+
# def add_nops(self, nop: bytes, alignment: int) -> None:
210+
# """Add NOPs until there is alignment. Fail if it is not possible."""
211+
# offset = len(self.body)
212+
# nop_size = len(nop)
213+
214+
# # Calculate the gap to the next multiple of alignment.
215+
# gap = -offset % alignment
216+
# if gap:
217+
# if gap % nop_size == 0:
218+
# count = gap // nop_size
219+
# self.body.extend(nop * count)
220+
# else:
221+
# raise ValueError(
222+
# f"Cannot add nops of size '{nop_size}' to a body with "
223+
# f"offset '{offset}' to align with '{alignment}'"
224+
# )
269225

270226

271227
@dataclasses.dataclass
@@ -306,8 +262,7 @@ def process_relocations(
306262
self._trampolines.add(ordinal)
307263
hole.addend = ordinal
308264
hole.symbol = None
309-
self.code.remove_jump()
310-
self.code.add_nops(nop=nop, alignment=alignment)
265+
# self.code.add_nops(nop=nop, alignment=alignment)
311266
self.data.pad(8)
312267
for stencil in [self.code, self.data]:
313268
for hole in stencil.holes:

Tools/jit/_targets.py

Lines changed: 233 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import asyncio
44
import dataclasses
55
import hashlib
6+
import itertools
67
import json
78
import os
89
import pathlib
@@ -34,6 +35,223 @@
3435
"_R", _schema.COFFRelocation, _schema.ELFRelocation, _schema.MachORelocation
3536
)
3637

38+
inverted_branches = {}
39+
for op, nop in [
40+
("ja", "jna"),
41+
("jae", "jnae"),
42+
("jb", "jnb"),
43+
("jbe", "jnbe"),
44+
("jc", "jnc"),
45+
("je", "jne"),
46+
("jg", "jng"),
47+
("jge", "jnge"),
48+
("jl", "jnl"),
49+
("jle", "jnle"),
50+
("jo", "jno"),
51+
("jp", "jnp"),
52+
("js", "jns"),
53+
("jz", "jnz"),
54+
("jpe", "jpo"),
55+
("jcxz", None),
56+
("jecxz", None),
57+
("jrxz", None),
58+
("loop", None),
59+
("loope", None),
60+
("loopne", None),
61+
("loopnz", None),
62+
("loopz", None),
63+
]:
64+
inverted_branches[op] = nop
65+
if nop is not None:
66+
inverted_branches[nop] = op
67+
68+
69+
@dataclasses.dataclass
70+
class _Line:
71+
text: str
72+
hot: bool = dataclasses.field(init=False, default=False)
73+
predecessors: list["_Line"] = dataclasses.field(
74+
init=False, repr=False, default_factory=list
75+
)
76+
77+
78+
@dataclasses.dataclass
79+
class _Label(_Line):
80+
label: str
81+
82+
83+
@dataclasses.dataclass
84+
class _Jump(_Line):
85+
target: _Label | None
86+
87+
def remove(self) -> None:
88+
self.text = ""
89+
if self.target is not None:
90+
self.target.predecessors.remove(self)
91+
if not self.target.predecessors:
92+
self.target.text = ""
93+
self.target = None
94+
95+
def update(self, target: _Label) -> None:
96+
assert self.target is not None
97+
self.target.predecessors.remove(self)
98+
assert self.target.label in self.text
99+
self.text = self.text.replace(self.target.label, target.label)
100+
self.target = target
101+
self.target.predecessors.append(self)
102+
103+
104+
@dataclasses.dataclass
105+
class _Branch(_Line):
106+
op: str
107+
target: _Label
108+
fallthrough: _Line | None = None
109+
110+
def update(self, target: _Label) -> None:
111+
assert self.target is not None
112+
self.target.predecessors.remove(self)
113+
assert self.target.label in self.text
114+
self.text = self.text.replace(self.target.label, target.label)
115+
self.target = target
116+
self.target.predecessors.append(self)
117+
118+
def invert(self, jump: _Jump) -> bool:
119+
inverted = inverted_branches[self.op]
120+
if inverted is None or jump.target is None:
121+
return False
122+
assert self.op in self.text
123+
self.text = self.text.replace(self.op, inverted)
124+
old_target = self.target
125+
self.update(jump.target)
126+
jump.update(old_target)
127+
return True
128+
129+
130+
@dataclasses.dataclass
131+
class _Return(_Line):
132+
pass
133+
134+
135+
@dataclasses.dataclass
136+
class _Noise(_Line):
137+
pass
138+
139+
140+
def _branch(
141+
line: str, use_label: typing.Callable[[str], _Label]
142+
) -> tuple[str, _Label] | None:
143+
branch = re.match(rf"\s*({'|'.join(inverted_branches)})\s+([\w\.]+)", line)
144+
return branch and (branch.group(1), use_label(branch.group(2)))
145+
146+
147+
def _jump(line: str, use_label: typing.Callable[[str], _Label]) -> _Label | None:
148+
if jump := re.match(r"\s*jmp\s+([\w\.]+)", line):
149+
return use_label(jump.group(1))
150+
return None
151+
152+
153+
def _label(line: str, use_label: typing.Callable[[str], _Label]) -> _Label | None:
154+
label = re.match(r"\s*([\w\.]+):", line)
155+
return label and use_label(label.group(1))
156+
157+
158+
def _return(line: str) -> bool:
159+
return re.match(r"\s*ret\s+", line) is not None
160+
161+
162+
def _noise(line: str) -> bool:
163+
return re.match(r"\s*[#\.]|\s*$", line) is not None
164+
165+
166+
def _apply_asm_transformations(path: pathlib.Path) -> None:
167+
labels = {}
168+
169+
def use_label(label: str) -> _Label:
170+
if label not in labels:
171+
labels[label] = _Label("", label)
172+
return labels[label]
173+
174+
def new_line(text: str) -> _Line:
175+
if branch := _branch(text, use_label):
176+
op, label = branch
177+
line = _Branch(text, op, label)
178+
label.predecessors.append(line)
179+
return line
180+
if label := _jump(text, use_label):
181+
line = _Jump(text, label)
182+
label.predecessors.append(line)
183+
return line
184+
if line := _label(text, use_label):
185+
assert line.text == ""
186+
line.text = text
187+
return line
188+
if _return(text):
189+
return _Return(text)
190+
if _noise(text):
191+
return _Noise(text)
192+
return _Line(text)
193+
194+
# Build graph:
195+
lines = []
196+
line = _Noise("") # Dummy.
197+
with path.open() as file:
198+
for i, text in enumerate(file):
199+
new = new_line(text)
200+
if not isinstance(line, (_Jump, _Return)):
201+
new.predecessors.append(line)
202+
lines.append(new)
203+
line = new
204+
for i, line in enumerate(reversed(lines)):
205+
if not isinstance(line, (_Label, _Noise)):
206+
break
207+
new = new_line("_JIT_CONTINUE:\n")
208+
lines.insert(len(lines) - i, new)
209+
line = new
210+
# Mark hot lines:
211+
todo = labels["_JIT_CONTINUE"].predecessors.copy()
212+
while todo:
213+
line = todo.pop()
214+
line.hot = True
215+
for predecessor in line.predecessors:
216+
if not predecessor.hot:
217+
todo.append(predecessor)
218+
for pair in itertools.pairwise(
219+
filter(lambda line: line.text and not isinstance(line, _Noise), lines)
220+
):
221+
match pair:
222+
case (_Branch(hot=True) as branch, _Jump(hot=False) as jump):
223+
branch.invert(jump)
224+
jump.hot = True
225+
for pair in itertools.pairwise(lines):
226+
match pair:
227+
case (_Jump() | _Return(), _):
228+
pass
229+
case (_Line(hot=True), _Line(hot=False) as cold):
230+
cold.hot = True
231+
# Reorder blocks:
232+
hot = []
233+
cold = []
234+
for line in lines:
235+
if line.hot:
236+
hot.append(line)
237+
else:
238+
cold.append(line)
239+
lines = hot + cold
240+
# Remove zero-length jumps:
241+
again = True
242+
while again:
243+
again = False
244+
for pair in itertools.pairwise(
245+
filter(lambda line: line.text and not isinstance(line, _Noise), lines)
246+
):
247+
match pair:
248+
case (_Jump(target=target) as jump, label) if target is label:
249+
jump.remove()
250+
again = True
251+
# Write new assembly:
252+
with path.open("w") as file:
253+
file.writelines(line.text for line in lines)
254+
37255

38256
@dataclasses.dataclass
39257
class _Target(typing.Generic[_S, _R]):
@@ -118,8 +336,9 @@ def _handle_relocation(
118336
async def _compile(
119337
self, opname: str, c: pathlib.Path, tempdir: pathlib.Path
120338
) -> _stencils.StencilGroup:
339+
s = tempdir / f"{opname}.s"
121340
o = tempdir / f"{opname}.o"
122-
args = [
341+
args_s = [
123342
f"--target={self.triple}",
124343
"-DPy_BUILD_CORE_MODULE",
125344
"-D_DEBUG" if self.debug else "-DNDEBUG",
@@ -133,7 +352,8 @@ async def _compile(
133352
f"-I{CPYTHON / 'Python'}",
134353
f"-I{CPYTHON / 'Tools' / 'jit'}",
135354
"-O3",
136-
"-c",
355+
"-S",
356+
# "-c",
137357
# Shorten full absolute file paths in the generated code (like the
138358
# __FILE__ macro and assert failure messages) for reproducibility:
139359
f"-ffile-prefix-map={CPYTHON}=.",
@@ -152,11 +372,20 @@ async def _compile(
152372
"-fno-stack-protector",
153373
"-std=c11",
154374
"-o",
155-
f"{o}",
375+
f"{s}",
156376
f"{c}",
157377
*self.args,
158378
]
159-
await _llvm.run("clang", args, echo=self.verbose)
379+
await _llvm.run("clang", args_s, echo=self.verbose)
380+
_apply_asm_transformations(s)
381+
args_o = [
382+
f"--target={self.triple}",
383+
"-c",
384+
"-o",
385+
f"{o}",
386+
f"{s}",
387+
]
388+
await _llvm.run("clang", args_o, echo=self.verbose)
160389
return await self._parse(o)
161390

162391
async def _build_stencils(self) -> dict[str, _stencils.StencilGroup]:

0 commit comments

Comments
 (0)