Skip to content

Commit 858624a

Browse files
committed
Switch to a linked list
1 parent 5606078 commit 858624a

File tree

1 file changed

+123
-92
lines changed

1 file changed

+123
-92
lines changed

Tools/jit/_targets.py

Lines changed: 123 additions & 92 deletions
Original file line numberDiff line numberDiff line change
@@ -35,44 +35,59 @@
3535
"_R", _schema.COFFRelocation, _schema.ELFRelocation, _schema.MachORelocation
3636
)
3737

38-
inverted_branches = {}
38+
branches = {}
3939
for op, nop in [
4040
("ja", "jna"),
4141
("jae", "jnae"),
4242
("jb", "jnb"),
4343
("jbe", "jnbe"),
4444
("jc", "jnc"),
45+
("jcxz", None),
4546
("je", "jne"),
47+
("jecxz", None),
4648
("jg", "jng"),
4749
("jge", "jnge"),
4850
("jl", "jnl"),
4951
("jle", "jnle"),
5052
("jo", "jno"),
5153
("jp", "jnp"),
52-
("js", "jns"),
53-
("jz", "jnz"),
5454
("jpe", "jpo"),
55-
("jcxz", None),
56-
("jecxz", None),
5755
("jrxz", None),
56+
("js", "jns"),
57+
("jz", "jnz"),
5858
("loop", None),
5959
("loope", None),
6060
("loopne", None),
6161
("loopnz", None),
6262
("loopz", None),
6363
]:
64-
inverted_branches[op] = nop
64+
branches[op] = nop
6565
if nop is not None:
66-
inverted_branches[nop] = op
66+
branches[nop] = op
6767

6868

6969
@dataclasses.dataclass
7070
class _Line:
71+
fallthrough: typing.ClassVar[bool] = True
7172
text: str
7273
hot: bool = dataclasses.field(init=False, default=False)
7374
predecessors: list["_Line"] = dataclasses.field(
7475
init=False, repr=False, default_factory=list
7576
)
77+
link: "_Line | None" = dataclasses.field(init=False, repr=False, default=None)
78+
79+
def heat(self) -> None:
80+
if self.hot:
81+
return
82+
self.hot = True
83+
for predecessor in self.predecessors:
84+
predecessor.heat()
85+
if self.fallthrough and self.link is not None:
86+
self.link.heat()
87+
88+
def optimize(self) -> None:
89+
if self.link is not None:
90+
self.link.optimize()
7691

7792

7893
@dataclasses.dataclass
@@ -82,18 +97,23 @@ class _Label(_Line):
8297

8398
@dataclasses.dataclass
8499
class _Jump(_Line):
85-
target: _Label | None
100+
fallthrough = False
101+
target: _Label
102+
103+
def optimize(self) -> None:
104+
super().optimize()
105+
target_aliases = _aliases(self.target)
106+
if any(alias in target_aliases for alias in _aliases(self.link)):
107+
self.remove()
86108

87109
def remove(self) -> None:
88-
self.text = ""
89-
if self.target is not None:
90-
self.target.predecessors.remove(self)
91-
if not self.target.predecessors:
92-
self.target.text = ""
93-
self.target = None
110+
[predecessor] = self.predecessors
111+
assert predecessor.link is self
112+
self.target.predecessors.remove(self)
113+
predecessor.link = self.link
114+
self.target.predecessors.append(predecessor)
94115

95116
def update(self, target: _Label) -> None:
96-
assert self.target is not None
97117
self.target.predecessors.remove(self)
98118
assert self.target.label in self.text
99119
self.text = self.text.replace(self.target.label, target.label)
@@ -105,22 +125,28 @@ def update(self, target: _Label) -> None:
105125
class _Branch(_Line):
106126
op: str
107127
target: _Label
108-
fallthrough: _Line | None = None
128+
129+
def optimize(self) -> None:
130+
super().optimize()
131+
if self.target.hot:
132+
for jump in _aliases(self.link):
133+
if isinstance(jump, _Jump) and self.invert(jump):
134+
jump.optimize()
109135

110136
def update(self, target: _Label) -> None:
111-
assert self.target is not None
112137
self.target.predecessors.remove(self)
113138
assert self.target.label in self.text
114139
self.text = self.text.replace(self.target.label, target.label)
115140
self.target = target
116141
self.target.predecessors.append(self)
117142

118143
def invert(self, jump: _Jump) -> bool:
119-
inverted = inverted_branches[self.op]
120-
if inverted is None or jump.target is None:
144+
inverted = branches[self.op]
145+
if inverted is None:
121146
return False
122147
assert self.op in self.text
123148
self.text = self.text.replace(self.op, inverted)
149+
self.op = inverted
124150
old_target = self.target
125151
self.update(jump.target)
126152
jump.update(old_target)
@@ -129,7 +155,7 @@ def invert(self, jump: _Jump) -> bool:
129155

130156
@dataclasses.dataclass
131157
class _Return(_Line):
132-
pass
158+
fallthrough = False
133159

134160

135161
@dataclasses.dataclass
@@ -140,7 +166,7 @@ class _Noise(_Line):
140166
def _branch(
    line: str, use_label: typing.Callable[[str], _Label]
) -> tuple[str, _Label] | None:
    """Parse *line* as a conditional branch instruction.

    Returns ``(mnemonic, label)`` when the line starts (after optional
    whitespace) with one of the mnemonics in ``branches`` followed by a
    target name, where the label object is obtained by passing the target
    name to *use_label*. Returns None when the line is not a branch.
    """
    pattern = rf"\s*({'|'.join(branches)})\s+([\w\.]+)"
    found = re.match(pattern, line)
    if found is None:
        return None
    mnemonic, target_name = found.groups()
    return mnemonic, use_label(target_name)
145171

146172

@@ -163,25 +189,57 @@ def _noise(line: str) -> bool:
163189
return re.match(r"\s*[#\.]|\s*$", line) is not None
164190

165191

166-
def _apply_asm_transformations(path: pathlib.Path) -> None:
167-
labels = {}
192+
def _aliases(line: _Line | None) -> list[_Line]:
193+
aliases = []
194+
while line is not None and isinstance(line, (_Label, _Noise)):
195+
aliases.append(line)
196+
line = line.link
197+
if line is not None:
198+
aliases.append(line)
199+
return aliases
168200

169-
def use_label(label: str) -> _Label:
170-
if label not in labels:
171-
labels[label] = _Label("", label)
172-
return labels[label]
173201

174-
def new_line(text: str) -> _Line:
175-
if branch := _branch(text, use_label):
202+
@dataclasses.dataclass
203+
class _AssemblyTransformer:
204+
_path: pathlib.Path
205+
_alignment: int = 1
206+
_lines: _Line = dataclasses.field(init=False)
207+
_labels: dict[str, _Label] = dataclasses.field(init=False, default_factory=dict)
208+
_ran: bool = dataclasses.field(init=False, default=False)
209+
210+
def __post_init__(self) -> None:
211+
dummy = current = _Noise("")
212+
for line in self._path.read_text().splitlines(True):
213+
new = self._new_line(line)
214+
if current.fallthrough:
215+
new.predecessors.append(current)
216+
current.link = new
217+
current = new
218+
assert dummy.link is not None
219+
self._lines = dummy.link
220+
221+
def __iter__(self) -> typing.Iterator[_Line]:
222+
line = self._lines
223+
while line is not None:
224+
yield line
225+
line = line.link
226+
227+
def _use_label(self, label: str) -> _Label:
228+
if label not in self._labels:
229+
self._labels[label] = _Label("", label)
230+
return self._labels[label]
231+
232+
def _new_line(self, text: str) -> _Line:
233+
if branch := _branch(text, self._use_label):
176234
op, label = branch
177235
line = _Branch(text, op, label)
178236
label.predecessors.append(line)
179237
return line
180-
if label := _jump(text, use_label):
238+
if label := _jump(text, self._use_label):
181239
line = _Jump(text, label)
182240
label.predecessors.append(line)
183241
return line
184-
if line := _label(text, use_label):
242+
if line := _label(text, self._use_label):
185243
assert line.text == ""
186244
line.text = text
187245
return line
@@ -191,66 +249,39 @@ def new_line(text: str) -> _Line:
191249
return _Noise(text)
192250
return _Line(text)
193251

194-
# Build graph:
195-
lines = []
196-
line = _Noise("") # Dummy.
197-
with path.open() as file:
198-
for i, text in enumerate(file):
199-
new = new_line(text)
200-
if not isinstance(line, (_Jump, _Return)):
201-
new.predecessors.append(line)
202-
lines.append(new)
203-
line = new
204-
for i, line in enumerate(reversed(lines)):
205-
if not isinstance(line, (_Label, _Noise)):
206-
break
207-
new = new_line("_JIT_CONTINUE:\n")
208-
lines.insert(len(lines) - i, new)
209-
line = new
210-
# Mark hot lines:
211-
todo = labels["_JIT_CONTINUE"].predecessors.copy()
212-
while todo:
213-
line = todo.pop()
214-
line.hot = True
215-
for predecessor in line.predecessors:
216-
if not predecessor.hot:
217-
todo.append(predecessor)
218-
for pair in itertools.pairwise(
219-
filter(lambda line: line.text and not isinstance(line, _Noise), lines)
220-
):
221-
match pair:
222-
case (_Branch(hot=True) as branch, _Jump(hot=False) as jump):
223-
branch.invert(jump)
224-
jump.hot = True
225-
for pair in itertools.pairwise(lines):
226-
match pair:
227-
case (_Jump() | _Return(), _):
228-
pass
229-
case (_Line(hot=True), _Line(hot=False) as cold):
230-
cold.hot = True
231-
# Reorder blocks:
232-
hot = []
233-
cold = []
234-
for line in lines:
235-
if line.hot:
236-
hot.append(line)
237-
else:
238-
cold.append(line)
239-
lines = hot + cold
240-
# Remove zero-length jumps:
241-
again = True
242-
while again:
243-
again = False
244-
for pair in itertools.pairwise(
245-
filter(lambda line: line.text and not isinstance(line, _Noise), lines)
246-
):
247-
match pair:
248-
case (_Jump(target=target) as jump, label) if target is label:
249-
jump.remove()
250-
again = True
251-
# Write new assembly:
252-
with path.open("w") as file:
253-
file.writelines(line.text for line in lines)
252+
def _dump(self) -> str:
253+
return "".join(line.text for line in self)
254+
255+
def _break_on(self, name: str) -> None:
256+
if self._path.stem == name:
257+
print(self._dump())
258+
breakpoint()
259+
260+
def run(self) -> None:
261+
assert not self._ran
262+
self._ran = True
263+
last_line = None
264+
for line in self:
265+
if not isinstance(line, (_Label, _Noise)):
266+
last_line = line
267+
assert last_line is not None
268+
new = self._new_line(f".balign {self._alignment}\n")
269+
new.link = last_line.link
270+
last_line.link = new
271+
new = self._new_line("_JIT_CONTINUE:\n")
272+
new.link = last_line.link
273+
last_line.link = new
274+
# Mark hot lines and optimize:
275+
recursion_limit = sys.getrecursionlimit()
276+
sys.setrecursionlimit(10_000)
277+
try:
278+
self._labels["_JIT_CONTINUE"].heat()
279+
# self._break_on("_BUILD_TUPLE")
280+
self._lines.optimize()
281+
finally:
282+
sys.setrecursionlimit(recursion_limit)
283+
# Write new assembly:
284+
self._path.write_text(self._dump())
254285

255286

256287
@dataclasses.dataclass
@@ -377,7 +408,7 @@ async def _compile(
377408
*self.args,
378409
]
379410
await _llvm.run("clang", args_s, echo=self.verbose)
380-
_apply_asm_transformations(s)
411+
_AssemblyTransformer(s, self.alignment).run()
381412
args_o = [
382413
f"--target={self.triple}",
383414
"-c",

0 commit comments

Comments
 (0)