Skip to content

Commit 9ebb32c

Browse files
committed
Add AArch64 stub
1 parent da72079 commit 9ebb32c

File tree

3 files changed

+151
-142
lines changed

3 files changed

+151
-142
lines changed

Tools/jit/_optimizers.py

Lines changed: 138 additions & 109 deletions
Original file line numberDiff line numberDiff line change
@@ -4,69 +4,34 @@
44
import re
55
import typing
66

7-
branches = {}
8-
for op, nop in [
9-
("ja", "jna"),
10-
("jae", "jnae"),
11-
("jb", "jnb"),
12-
("jbe", "jnbe"),
13-
("jc", "jnc"),
14-
("jcxz", None),
15-
("je", "jne"),
16-
("jecxz", None),
17-
("jg", "jng"),
18-
("jge", "jnge"),
19-
("jl", "jnl"),
20-
("jle", "jnle"),
21-
("jo", "jno"),
22-
("jp", "jnp"),
23-
("jpe", "jpo"),
24-
("jrxz", None),
25-
("js", "jns"),
26-
("jz", "jnz"),
27-
("loop", None),
28-
("loope", None),
29-
("loopne", None),
30-
("loopnz", None),
31-
("loopz", None),
32-
]:
33-
branches[op] = nop
34-
if nop:
35-
branches[nop] = op
36-
37-
38-
def _get_branch(line: str) -> str | None:
39-
branch = re.match(rf"\s*({'|'.join(branches)})\s+([\w\.]+)", line)
40-
return branch and branch[2]
41-
42-
43-
def _invert_branch(line: str, label: str) -> str | None:
44-
branch = re.match(rf"\s*({'|'.join(branches)})\s+([\w\.]+)", line)
45-
assert branch
46-
inverted = branches.get(branch[1])
47-
if not inverted:
48-
return None
49-
line = line.replace(branch[1], inverted, 1) # XXX
50-
line = line.replace(branch[2], label, 1) # XXX
51-
return line
52-
53-
54-
def _get_jump(line: str) -> str | None:
55-
jump = re.match(r"\s*(?:rex64\s+)?jmpq?\s+\*?([\w\.]+)", line)
56-
return jump and jump[1]
57-
58-
59-
def _get_label(line: str) -> str | None:
60-
label = re.match(r"\s*([\w\.]+):", line)
61-
return label and label[1]
62-
63-
64-
def _is_return(line: str) -> bool:
65-
return re.match(r"\s*ret\s+", line) is not None
66-
67-
68-
def _is_noise(line: str) -> bool:
69-
return re.match(r"\s*([#\.]|$)", line) is not None
7+
_RE_NEVER_MATCH = re.compile(r"(?!)")
8+
9+
# Maps each x86 conditional-branch mnemonic to the mnemonic that tests the
# opposite condition. A value of None marks a branch that has no
# single-instruction inverse; Optimizer._invert_branch() gives up on those.
# The keys are also joined into the alternation used by OptimizerX86._re_branch,
# so a mnemonic that is missing (or misspelled) here is never recognised as a
# branch at all.
_X86_BRANCHES = {
    "ja": "jna",
    "jae": "jnae",
    "jb": "jnb",
    "jbe": "jnbe",
    "jc": "jnc",
    "jcxz": None,
    "je": "jne",
    "jecxz": None,
    "jg": "jng",
    "jge": "jnge",
    "jl": "jnl",
    "jle": "jnle",
    "jo": "jno",
    "jp": "jnp",
    "jpe": "jpo",
    # 64-bit counterpart of jcxz/jecxz (jump if RCX is zero). The mnemonic is
    # "jrcxz"; there is no "jrxz" instruction.
    "jrcxz": None,
    "js": "jns",
    "jz": "jnz",
    "loop": None,
    "loope": None,
    "loopne": None,
    "loopnz": None,
    "loopz": None,
}
# Make the table symmetric: every inverse mnemonic maps back to its original.
# Entries whose value is None contribute nothing.
_X86_BRANCHES |= {v: k for k, v in _X86_BRANCHES.items() if v}
7035

7136

7237
@dataclasses.dataclass
@@ -91,57 +56,111 @@ def resolve(self) -> typing.Self:
9156
return self
9257

9358

94-
class _Labels(dict):
95-
def __missing__(self, key: str) -> _Block:
96-
self[key] = _Block(key)
97-
return self[key]
98-
99-
10059
@dataclasses.dataclass
10160
class Optimizer:
10261

10362
path: pathlib.Path
10463
_: dataclasses.KW_ONLY
10564
prefix: str = ""
10665
_graph: _Block = dataclasses.field(init=False)
107-
_labels: _Labels = dataclasses.field(init=False, default_factory=_Labels)
108-
109-
_re_branch: typing.ClassVar[re.Pattern[str]] # Two groups: instruction and target.
110-
_re_jump: typing.ClassVar[re.Pattern[str]] # One group: target.
111-
_re_return: typing.ClassVar[re.Pattern[str]] # No groups.
66+
_labels: dict = dataclasses.field(init=False, default_factory=dict)
67+
_alignment: typing.ClassVar[int] = 1
68+
_branches: typing.ClassVar[dict[str, str | None]] = {}
69+
_re_branch: typing.ClassVar[re.Pattern[str]] = (
70+
_RE_NEVER_MATCH # Two groups: instruction and target.
71+
)
72+
_re_jump: typing.ClassVar[re.Pattern[str]] = _RE_NEVER_MATCH # One group: target.
73+
_re_label: typing.ClassVar[re.Pattern[str]] = re.compile(
74+
r"\s*(?P<label>[\w\.]+):"
75+
) # One group: label.
76+
_re_noise: typing.ClassVar[re.Pattern[str]] = re.compile(
77+
r"\s*(?:[#\.]|$)"
78+
) # No groups.
79+
_re_return: typing.ClassVar[re.Pattern[str]] = _RE_NEVER_MATCH # No groups.
11280

11381
def __post_init__(self) -> None:
11482
text = self._preprocess(self.path.read_text())
11583
self._graph = block = self._new_block()
11684
for line in text.splitlines():
117-
if label := _get_label(line):
118-
block.link = block = self._labels[label]
85+
if label := self._parse_label(line):
86+
block.link = block = self._lookup_label(label)
87+
block.noise.append(line)
88+
continue
11989
elif block.target or not block.fallthrough:
12090
block.link = block = self._new_block()
121-
if _is_noise(line) or _get_label(line):
91+
if self._parse_noise(line):
12292
if block.instructions:
12393
block.link = block = self._new_block()
12494
block.noise.append(line)
12595
continue
12696
block.instructions.append(line)
127-
if target := _get_branch(line):
128-
block.target = self._labels[target]
97+
if target := self._parse_branch(line):
98+
block.target = self._lookup_label(target)
12999
assert block.fallthrough
130-
elif target := _get_jump(line):
131-
block.target = self._labels[target]
100+
elif target := self._parse_jump(line):
101+
block.target = self._lookup_label(target)
132102
block.fallthrough = False
133-
elif _is_return(line):
103+
elif self._parse_return(line):
134104
assert not block.target
135105
block.fallthrough = False
136106

137-
def _new_block(self, label: str | None = None) -> _Block:
138-
if not label:
139-
label = f"{self.prefix}_JIT_LABEL_{len(self._labels)}"
140-
assert label not in self._labels, label
141-
block = self._labels[label] = _Block(label, [f"{label}:"])
107+
@classmethod
108+
def _parse_branch(cls, line: str) -> str | None:
109+
branch = cls._re_branch.match(line)
110+
return branch and branch["target"]
111+
112+
@classmethod
113+
def _parse_jump(cls, line: str) -> str | None:
114+
jump = cls._re_jump.match(line)
115+
return jump and jump["target"]
116+
117+
@classmethod
118+
def _parse_label(cls, line: str) -> str | None:
119+
label = cls._re_label.match(line)
120+
return label and label["label"]
121+
122+
@classmethod
123+
def _parse_noise(cls, line: str) -> bool:
124+
return cls._re_noise.match(line) is not None
125+
126+
@classmethod
127+
def _parse_return(cls, line: str) -> bool:
128+
return cls._re_return.match(line) is not None
129+
130+
@classmethod
131+
def _invert_branch(cls, line: str, target: str) -> str | None:
132+
branch = cls._re_branch.match(line)
133+
assert branch
134+
inverted = cls._branches.get(branch["instruction"])
135+
if not inverted:
136+
return None
137+
(a, b), (c, d) = branch.span("instruction"), branch.span("target")
138+
return "".join([line[:a], inverted, line[b:c], target, line[d:]])
139+
140+
@classmethod
141+
def _thread_jump(cls, line: str, target: str) -> str:
142+
jump = cls._re_jump.match(line) or cls._re_branch.match(line)
143+
assert jump
144+
a, b = jump.span("target")
145+
return "".join([line[:a], target, line[b:]])
146+
147+
@staticmethod
148+
def _create_jump(target: str) -> str | None:
149+
return None
150+
151+
def _lookup_label(self, label: str) -> _Block:
152+
if label not in self._labels:
153+
self._labels[label] = _Block(label)
154+
return self._labels[label]
155+
156+
def _new_block(self) -> _Block:
157+
label = f"{self.prefix}_JIT_LABEL_{len(self._labels)}"
158+
block = self._lookup_label(label)
159+
block.noise.append(f"{label}:")
142160
return block
143161

144-
def _preprocess(self, text: str) -> str:
162+
@staticmethod
163+
def _preprocess(text: str) -> str:
145164
return text
146165

147166
def _blocks(self) -> typing.Generator[_Block, None, None]:
@@ -159,9 +178,11 @@ def _insert_continue_label(self) -> None:
159178
for end in reversed(list(self._blocks())):
160179
if end.instructions:
161180
break
162-
continuation = self._labels[f"{self.prefix}_JIT_CONTINUE"]
181+
align = self._new_block()
182+
align.noise.append(f"\t.balign\t{self._alignment}")
183+
continuation = self._lookup_label(f"{self.prefix}_JIT_CONTINUE")
163184
continuation.noise.append(f"{continuation.label}:")
164-
end.link, continuation.link = continuation, end.link
185+
end.link, align.link, continuation.link = align, continuation, end.link
165186

166187
def _mark_hot_blocks(self) -> None:
167188
predecessors = collections.defaultdict(set)
@@ -170,7 +191,7 @@ def _mark_hot_blocks(self) -> None:
170191
predecessors[block.target].add(block)
171192
if block.fallthrough and block.link:
172193
predecessors[block.link].add(block)
173-
todo = [self._labels[f"{self.prefix}_JIT_CONTINUE"]]
194+
todo = [self._lookup_label(f"{self.prefix}_JIT_CONTINUE")]
174195
while todo:
175196
block = todo.pop()
176197
block.hot = True
@@ -185,43 +206,41 @@ def _invert_hot_branches(self) -> None:
185206
if (
186207
block.fallthrough
187208
and block.target
188-
and block.link
189209
and block.target.hot
210+
and block.link
190211
and not block.link.hot
191212
):
192-
# Turn...
193-
# branch hot
194-
# ...into..
195-
# opposite-branch ._JIT_LABEL_N
196-
# jmp hot
197-
# ._JIT_LABEL_N:
198213
label_block = self._new_block()
199-
inverted = _invert_branch(block.instructions[-1], label_block.label)
200-
if inverted is None:
214+
branch = block.instructions[-1]
215+
inverted = self._invert_branch(branch, label_block.label)
216+
jump = self._create_jump(block.target.label)
217+
if inverted is None or jump is None:
201218
continue
202219
jump_block = self._new_block()
203-
jump_block.instructions.append(f"\tjmp\t{block.target.label}")
220+
jump_block.instructions.append(jump)
204221
jump_block.target = block.target
205222
jump_block.fallthrough = False
206223
block.instructions[-1] = inverted
207224
block.target = label_block
208-
label_block.link = block.link
209-
jump_block.link = label_block
210-
block.link = jump_block
225+
block.link, jump_block.link, label_block.link = (
226+
jump_block,
227+
label_block,
228+
block.link,
229+
)
211230

212231
def _thread_jumps(self) -> None:
213232
for block in self._blocks():
214233
while block.target:
215-
label = block.target.label
216234
target = block.target.resolve()
217235
if (
218236
not target.fallthrough
219237
and target.target
220238
and len(target.instructions) == 1
221239
):
222-
block.instructions[-1] = block.instructions[-1].replace(
223-
label, target.target.label
224-
) # XXX
240+
jump = block.instructions[-1]
241+
block.instructions[-1] = self._thread_jump(
242+
jump, target.target.label
243+
)
225244
block.target = target.target
226245
else:
227246
break
@@ -273,14 +292,24 @@ def run(self) -> None:
273292
self.path.write_text("\n".join(self._lines()))
274293

275294

295+
class OptimizerAArch64(Optimizer):
296+
# TODO: @diegorusso
297+
_alignment = 8
298+
299+
276300
class OptimizerX86(Optimizer):
277301

302+
_branches = _X86_BRANCHES
278303
_re_branch = re.compile(
279-
rf"\s*(?P<instruction>{'|'.join(branches)})\s+(?P<target>[\w\.]+)"
304+
rf"\s*(?P<instruction>{'|'.join(_X86_BRANCHES)})\s+(?P<target>[\w\.]+)"
280305
)
281306
_re_jump = re.compile(r"\s*jmp\s+(?P<target>[\w\.]+)")
282307
_re_return = re.compile(r"\s*ret\b")
283308

309+
@staticmethod
310+
def _create_jump(target: str) -> str:
311+
return f"\tjmp\t{target}"
312+
284313

285314
class OptimizerX86Windows(OptimizerX86):
286315

Tools/jit/_stencils.py

Lines changed: 1 addition & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -206,23 +206,6 @@ def pad(self, alignment: int) -> None:
206206
self.disassembly.append(f"{offset:x}: {' '.join(['00'] * padding)}")
207207
self.body.extend([0] * padding)
208208

209-
# def add_nops(self, nop: bytes, alignment: int) -> None:
210-
# """Add NOPs until there is alignment. Fail if it is not possible."""
211-
# offset = len(self.body)
212-
# nop_size = len(nop)
213-
214-
# # Calculate the gap to the next multiple of alignment.
215-
# gap = -offset % alignment
216-
# if gap:
217-
# if gap % nop_size == 0:
218-
# count = gap // nop_size
219-
# self.body.extend(nop * count)
220-
# else:
221-
# raise ValueError(
222-
# f"Cannot add nops of size '{nop_size}' to a body with "
223-
# f"offset '{offset}' to align with '{alignment}'"
224-
# )
225-
226209

227210
@dataclasses.dataclass
228211
class StencilGroup:
@@ -240,9 +223,7 @@ class StencilGroup:
240223
_got: dict[str, int] = dataclasses.field(default_factory=dict, init=False)
241224
_trampolines: set[int] = dataclasses.field(default_factory=set, init=False)
242225

243-
def process_relocations(
244-
self, known_symbols: dict[str, int], *, alignment: int = 1, nop: bytes = b""
245-
) -> None:
226+
def process_relocations(self, known_symbols: dict[str, int]) -> None:
246227
"""Fix up all GOT and internal relocations for this stencil group."""
247228
for hole in self.code.holes.copy():
248229
if (
@@ -262,7 +243,6 @@ def process_relocations(
262243
self._trampolines.add(ordinal)
263244
hole.addend = ordinal
264245
hole.symbol = None
265-
# self.code.add_nops(nop=nop, alignment=alignment)
266246
self.data.pad(8)
267247
for stencil in [self.code, self.data]:
268248
for hole in stencil.holes:

0 commit comments

Comments
 (0)