3535 "_R" , _schema .COFFRelocation , _schema .ELFRelocation , _schema .MachORelocation
3636)
3737
38- inverted_branches = {}
38+ branches = {}
3939for op , nop in [
4040 ("ja" , "jna" ),
4141 ("jae" , "jnae" ),
4242 ("jb" , "jnb" ),
4343 ("jbe" , "jnbe" ),
4444 ("jc" , "jnc" ),
45+ ("jcxz" , None ),
4546 ("je" , "jne" ),
47+ ("jecxz" , None ),
4648 ("jg" , "jng" ),
4749 ("jge" , "jnge" ),
4850 ("jl" , "jnl" ),
4951 ("jle" , "jnle" ),
5052 ("jo" , "jno" ),
5153 ("jp" , "jnp" ),
52- ("js" , "jns" ),
53- ("jz" , "jnz" ),
5454 ("jpe" , "jpo" ),
55- ("jcxz" , None ),
56- ("jecxz" , None ),
5755 ("jrxz" , None ),
56+ ("js" , "jns" ),
57+ ("jz" , "jnz" ),
5858 ("loop" , None ),
5959 ("loope" , None ),
6060 ("loopne" , None ),
6161 ("loopnz" , None ),
6262 ("loopz" , None ),
6363]:
64- inverted_branches [op ] = nop
64+ branches [op ] = nop
6565 if nop is not None :
66- inverted_branches [nop ] = op
66+ branches [nop ] = op
6767
6868
6969@dataclasses .dataclass
7070class _Line :
71+ fallthrough : typing .ClassVar [bool ] = True
7172 text : str
7273 hot : bool = dataclasses .field (init = False , default = False )
7374 predecessors : list ["_Line" ] = dataclasses .field (
7475 init = False , repr = False , default_factory = list
7576 )
77+ link : "_Line | None" = dataclasses .field (init = False , repr = False , default = None )
78+
79+ def heat (self ) -> None :
80+ if self .hot :
81+ return
82+ self .hot = True
83+ for predecessor in self .predecessors :
84+ predecessor .heat ()
85+ if self .fallthrough and self .link is not None :
86+ self .link .heat ()
87+
88+ def optimize (self ) -> None :
89+ if self .link is not None :
90+ self .link .optimize ()
7691
7792
7893@dataclasses .dataclass
@@ -82,18 +97,23 @@ class _Label(_Line):
8297
8398@dataclasses .dataclass
8499class _Jump (_Line ):
85- target : _Label | None
100+ fallthrough = False
101+ target : _Label
102+
103+ def optimize (self ) -> None :
104+ super ().optimize ()
105+ target_aliases = _aliases (self .target )
106+ if any (alias in target_aliases for alias in _aliases (self .link )):
107+ self .remove ()
86108
87109 def remove (self ) -> None :
88- self .text = ""
89- if self .target is not None :
90- self .target .predecessors .remove (self )
91- if not self .target .predecessors :
92- self .target .text = ""
93- self .target = None
110+ [predecessor ] = self .predecessors
111+ assert predecessor .link is self
112+ self .target .predecessors .remove (self )
113+ predecessor .link = self .link
114+ self .target .predecessors .append (predecessor )
94115
95116 def update (self , target : _Label ) -> None :
96- assert self .target is not None
97117 self .target .predecessors .remove (self )
98118 assert self .target .label in self .text
99119 self .text = self .text .replace (self .target .label , target .label )
@@ -105,22 +125,28 @@ def update(self, target: _Label) -> None:
105125class _Branch (_Line ):
106126 op : str
107127 target : _Label
108- fallthrough : _Line | None = None
128+
129+ def optimize (self ) -> None :
130+ super ().optimize ()
131+ if self .target .hot :
132+ for jump in _aliases (self .link ):
133+ if isinstance (jump , _Jump ) and self .invert (jump ):
134+ jump .optimize ()
109135
110136 def update (self , target : _Label ) -> None :
111- assert self .target is not None
112137 self .target .predecessors .remove (self )
113138 assert self .target .label in self .text
114139 self .text = self .text .replace (self .target .label , target .label )
115140 self .target = target
116141 self .target .predecessors .append (self )
117142
118143 def invert (self , jump : _Jump ) -> bool :
119- inverted = inverted_branches [self .op ]
120- if inverted is None or jump . target is None :
144+ inverted = branches [self .op ]
145+ if inverted is None :
121146 return False
122147 assert self .op in self .text
123148 self .text = self .text .replace (self .op , inverted )
149+ self .op = inverted
124150 old_target = self .target
125151 self .update (jump .target )
126152 jump .update (old_target )
@@ -129,7 +155,7 @@ def invert(self, jump: _Jump) -> bool:
129155
130156@dataclasses .dataclass
131157class _Return (_Line ):
132- pass
158+ fallthrough = False
133159
134160
135161@dataclasses .dataclass
@@ -140,7 +166,7 @@ class _Noise(_Line):
140166def _branch (
141167 line : str , use_label : typing .Callable [[str ], _Label ]
142168) -> tuple [str , _Label ] | None :
143- branch = re .match (rf"\s*({ '|' .join (inverted_branches )} )\s+([\w\.]+)" , line )
169+ branch = re .match (rf"\s*({ '|' .join (branches )} )\s+([\w\.]+)" , line )
144170 return branch and (branch .group (1 ), use_label (branch .group (2 )))
145171
146172
@@ -163,25 +189,57 @@ def _noise(line: str) -> bool:
163189 return re .match (r"\s*[#\.]|\s*$" , line ) is not None
164190
165191
166- def _apply_asm_transformations (path : pathlib .Path ) -> None :
167- labels = {}
192+ def _aliases (line : _Line | None ) -> list [_Line ]:
193+ aliases = []
194+ while line is not None and isinstance (line , (_Label , _Noise )):
195+ aliases .append (line )
196+ line = line .link
197+ if line is not None :
198+ aliases .append (line )
199+ return aliases
168200
169- def use_label (label : str ) -> _Label :
170- if label not in labels :
171- labels [label ] = _Label ("" , label )
172- return labels [label ]
173201
174- def new_line (text : str ) -> _Line :
175- if branch := _branch (text , use_label ):
202+ @dataclasses .dataclass
203+ class _AssemblyTransformer :
204+ _path : pathlib .Path
205+ _alignment : int = 1
206+ _lines : _Line = dataclasses .field (init = False )
207+ _labels : dict [str , _Label ] = dataclasses .field (init = False , default_factory = dict )
208+ _ran : bool = dataclasses .field (init = False , default = False )
209+
210+ def __post_init__ (self ) -> None :
211+ dummy = current = _Noise ("" )
212+ for line in self ._path .read_text ().splitlines (True ):
213+ new = self ._new_line (line )
214+ if current .fallthrough :
215+ new .predecessors .append (current )
216+ current .link = new
217+ current = new
218+ assert dummy .link is not None
219+ self ._lines = dummy .link
220+
221+ def __iter__ (self ) -> typing .Iterator [_Line ]:
222+ line = self ._lines
223+ while line is not None :
224+ yield line
225+ line = line .link
226+
227+ def _use_label (self , label : str ) -> _Label :
228+ if label not in self ._labels :
229+ self ._labels [label ] = _Label ("" , label )
230+ return self ._labels [label ]
231+
232+ def _new_line (self , text : str ) -> _Line :
233+ if branch := _branch (text , self ._use_label ):
176234 op , label = branch
177235 line = _Branch (text , op , label )
178236 label .predecessors .append (line )
179237 return line
180- if label := _jump (text , use_label ):
238+ if label := _jump (text , self . _use_label ):
181239 line = _Jump (text , label )
182240 label .predecessors .append (line )
183241 return line
184- if line := _label (text , use_label ):
242+ if line := _label (text , self . _use_label ):
185243 assert line .text == ""
186244 line .text = text
187245 return line
@@ -191,66 +249,39 @@ def new_line(text: str) -> _Line:
191249 return _Noise (text )
192250 return _Line (text )
193251
194- # Build graph:
195- lines = []
196- line = _Noise ("" ) # Dummy.
197- with path .open () as file :
198- for i , text in enumerate (file ):
199- new = new_line (text )
200- if not isinstance (line , (_Jump , _Return )):
201- new .predecessors .append (line )
202- lines .append (new )
203- line = new
204- for i , line in enumerate (reversed (lines )):
205- if not isinstance (line , (_Label , _Noise )):
206- break
207- new = new_line ("_JIT_CONTINUE:\n " )
208- lines .insert (len (lines ) - i , new )
209- line = new
210- # Mark hot lines:
211- todo = labels ["_JIT_CONTINUE" ].predecessors .copy ()
212- while todo :
213- line = todo .pop ()
214- line .hot = True
215- for predecessor in line .predecessors :
216- if not predecessor .hot :
217- todo .append (predecessor )
218- for pair in itertools .pairwise (
219- filter (lambda line : line .text and not isinstance (line , _Noise ), lines )
220- ):
221- match pair :
222- case (_Branch (hot = True ) as branch , _Jump (hot = False ) as jump ):
223- branch .invert (jump )
224- jump .hot = True
225- for pair in itertools .pairwise (lines ):
226- match pair :
227- case (_Jump () | _Return (), _):
228- pass
229- case (_Line (hot = True ), _Line (hot = False ) as cold ):
230- cold .hot = True
231- # Reorder blocks:
232- hot = []
233- cold = []
234- for line in lines :
235- if line .hot :
236- hot .append (line )
237- else :
238- cold .append (line )
239- lines = hot + cold
240- # Remove zero-length jumps:
241- again = True
242- while again :
243- again = False
244- for pair in itertools .pairwise (
245- filter (lambda line : line .text and not isinstance (line , _Noise ), lines )
246- ):
247- match pair :
248- case (_Jump (target = target ) as jump , label ) if target is label :
249- jump .remove ()
250- again = True
251- # Write new assembly:
252- with path .open ("w" ) as file :
253- file .writelines (line .text for line in lines )
252+ def _dump (self ) -> str :
253+ return "" .join (line .text for line in self )
254+
255+ def _break_on (self , name : str ) -> None :
256+ if self ._path .stem == name :
257+ print (self ._dump ())
258+ breakpoint ()
259+
260+ def run (self ) -> None :
261+ assert not self ._ran
262+ self ._ran = True
263+ last_line = None
264+ for line in self :
265+ if not isinstance (line , (_Label , _Noise )):
266+ last_line = line
267+ assert last_line is not None
268+ new = self ._new_line (f".balign { self ._alignment } \n " )
269+ new .link = last_line .link
270+ last_line .link = new
271+ new = self ._new_line ("_JIT_CONTINUE:\n " )
272+ new .link = last_line .link
273+ last_line .link = new
274+ # Mark hot lines and optimize:
275+ recursion_limit = sys .getrecursionlimit ()
276+ sys .setrecursionlimit (10_000 )
277+ try :
278+ self ._labels ["_JIT_CONTINUE" ].heat ()
279+ # self._break_on("_BUILD_TUPLE")
280+ self ._lines .optimize ()
281+ finally :
282+ sys .setrecursionlimit (recursion_limit )
283+ # Write new assembly:
284+ self ._path .write_text (self ._dump ())
254285
255286
256287@dataclasses .dataclass
@@ -377,7 +408,7 @@ async def _compile(
377408 * self .args ,
378409 ]
379410 await _llvm .run ("clang" , args_s , echo = self .verbose )
380- _apply_asm_transformations ( s )
411+ _AssemblyTransformer ( s , self . alignment ). run ( )
381412 args_o = [
382413 f"--target={ self .triple } " ,
383414 "-c" ,
0 commit comments