44import re
55import typing
66
7- _RE_NEVER_MATCH = re .compile (r"(?!)" )
7+ _RE_NEVER_MATCH = re .compile (r"(?!)" ) # Same as saying "not string.startswith('')".
88
99_X86_BRANCHES = {
1010 "ja" : "jna" ,
3636
3737@dataclasses .dataclass
3838class _Block :
39- label : str
39+ label : str | None = None
4040 noise : list [str ] = dataclasses .field (default_factory = list )
4141 instructions : list [str ] = dataclasses .field (default_factory = list )
4242 target : typing .Self | None = None
@@ -71,94 +71,61 @@ class Optimizer:
7171 )
7272 _re_jump : typing .ClassVar [re .Pattern [str ]] = _RE_NEVER_MATCH # One group: target.
7373 _re_label : typing .ClassVar [re .Pattern [str ]] = re .compile (
74- r' \s*(?P<label>[\w"$.?@ ]+):'
75- ) # One group: label.
74+ r" \s*(?P<label>[\w. ]+):" # One group: label.
75+ )
7676 _re_noise : typing .ClassVar [re .Pattern [str ]] = re .compile (
77- r"\s*(?:[#.]|$)"
78- ) # No groups.
77+ r"\s*(?:[#.]|$)" # No groups.
78+ )
7979 _re_return : typing .ClassVar [re .Pattern [str ]] = _RE_NEVER_MATCH # No groups.
8080
8181 def __post_init__ (self ) -> None :
8282 text = self ._preprocess (self .path .read_text ())
83- self ._graph = block = self . _new_block ()
83+ self ._graph = block = _Block ()
8484 for line in text .splitlines ():
85- if label := self ._parse_label (line ):
86- block .link = block = self ._lookup_label (label )
85+ if match := self ._re_label . match (line ):
86+ block .link = block = self ._lookup_label (match [ " label" ] )
8787 block .noise .append (line )
8888 continue
89- elif block .target or not block .fallthrough :
90- block .link = block = self ._new_block ()
91- if self ._parse_noise (line ):
89+ if self ._re_noise .match (line ):
9290 if block .instructions :
93- block .link = block = self . _new_block ()
91+ block .link = block = _Block ()
9492 block .noise .append (line )
9593 continue
94+ if block .target or not block .fallthrough :
95+ block .link = block = _Block ()
9696 block .instructions .append (line )
97- if target := self ._parse_branch (line ):
98- block .target = self ._lookup_label (target )
97+ if match := self ._re_branch . match (line ):
98+ block .target = self ._lookup_label (match [ " target" ] )
9999 assert block .fallthrough
100- elif target := self ._parse_jump (line ):
101- block .target = self ._lookup_label (target )
100+ elif match := self ._re_jump . match (line ):
101+ block .target = self ._lookup_label (match [ " target" ] )
102102 block .fallthrough = False
103- elif self ._parse_return (line ):
103+ elif self ._re_return . match (line ):
104104 assert not block .target
105105 block .fallthrough = False
106106
107- @classmethod
108- def _parse_branch (cls , line : str ) -> str | None :
109- branch = cls ._re_branch .match (line )
110- return branch and branch ["target" ]
111-
112- @classmethod
113- def _parse_jump (cls , line : str ) -> str | None :
114- jump = cls ._re_jump .match (line )
115- return jump and jump ["target" ]
116-
117- @classmethod
118- def _parse_label (cls , line : str ) -> str | None :
119- label = cls ._re_label .match (line )
120- return label and label ["label" ]
121-
122- @classmethod
123- def _parse_noise (cls , line : str ) -> bool :
124- return cls ._re_noise .match (line ) is not None
125-
126- @classmethod
127- def _parse_return (cls , line : str ) -> bool :
128- return cls ._re_return .match (line ) is not None
129-
130107 @classmethod
131108 def _invert_branch (cls , line : str , target : str ) -> str | None :
132- branch = cls ._re_branch .match (line )
133- assert branch
134- inverted = cls ._branches .get (branch ["instruction" ])
109+ match = cls ._re_branch .match (line )
110+ assert match
111+ inverted = cls ._branches .get (match ["instruction" ])
135112 if not inverted :
136113 return None
137- (a , b ), (c , d ) = branch .span ("instruction" ), branch .span ("target" )
114+ (a , b ), (c , d ) = match .span ("instruction" ), match .span ("target" )
138115 return "" .join ([line [:a ], inverted , line [b :c ], target , line [d :]])
139116
140117 @classmethod
141- def _thread_jump (cls , line : str , target : str ) -> str :
142- jump = cls ._re_jump . match ( line ) or cls . _re_branch .match (line )
143- assert jump
144- a , b = jump .span ("target" )
118+ def _update_jump (cls , line : str , target : str ) -> str :
119+ match = cls ._re_jump .match (line )
120+ assert match
121+ a , b = match .span ("target" )
145122 return "" .join ([line [:a ], target , line [b :]])
146123
147- @staticmethod
148- def _create_jump (target : str ) -> str | None :
149- return None
150-
151124 def _lookup_label (self , label : str ) -> _Block :
152125 if label not in self ._labels :
153126 self ._labels [label ] = _Block (label )
154127 return self ._labels [label ]
155128
156- def _new_block (self ) -> _Block :
157- label = f"{ self .prefix } _JIT_LABEL_{ len (self ._labels )} "
158- block = self ._lookup_label (label )
159- block .noise .append (f"{ label } :" )
160- return block
161-
162129 @staticmethod
163130 def _preprocess (text : str ) -> str :
164131 return text
@@ -169,28 +136,31 @@ def _blocks(self) -> typing.Generator[_Block, None, None]:
169136 yield block
170137 block = block .link
171138
172- def _lines (self ) -> typing .Generator [str , None , None ]:
139+ def _body (self ) -> str :
140+ lines = []
173141 for block in self ._blocks ():
174- yield from block .noise
175- yield from block .instructions
142+ lines .extend (block .noise )
143+ lines .extend (block .instructions )
144+ return "\n " .join (lines )
176145
177146 def _insert_continue_label (self ) -> None :
178147 for end in reversed (list (self ._blocks ())):
179148 if end .instructions :
180149 break
181- align = self . _new_block ()
150+ align = _Block ()
182151 align .noise .append (f"\t .balign\t { self ._alignment } " )
183152 continuation = self ._lookup_label (f"{ self .prefix } _JIT_CONTINUE" )
153+ assert continuation .label
184154 continuation .noise .append (f"{ continuation .label } :" )
185155 end .link , align .link , continuation .link = align , continuation , end .link
186156
187157 def _mark_hot_blocks (self ) -> None :
188- predecessors = collections .defaultdict (set )
158+ predecessors = collections .defaultdict (list )
189159 for block in self ._blocks ():
190160 if block .target :
191- predecessors [block .target ].add (block )
161+ predecessors [block .target ].append (block )
192162 if block .fallthrough and block .link :
193- predecessors [block .link ].add (block )
163+ predecessors [block .link ].append (block )
194164 todo = [self ._lookup_label (f"{ self .prefix } _JIT_CONTINUE" )]
195165 while todo :
196166 block = todo .pop ()
@@ -202,64 +172,33 @@ def _mark_hot_blocks(self) -> None:
202172 )
203173
204174 def _invert_hot_branches (self ) -> None :
205- for block in self ._blocks ():
175+ for branch in self ._blocks ():
176+ jump = branch .link
206177 if (
207- block .fallthrough
208- and block .target
209- and block .target .hot
210- and block .link
211- and not block .link .hot
178+ # block ends with a branch to hot code...
179+ branch .target
180+ and branch .fallthrough
181+ and branch .target .hot
182+ # ...followed by a jump to cold code:
183+ and jump
184+ and jump .target
185+ and not jump .fallthrough
186+ and not jump .target .hot
187+ and len (jump .instructions ) == 1
212188 ):
213- label_block = self ._new_block ()
214- branch = block .instructions [- 1 ]
215- inverted = self ._invert_branch (branch , label_block .label )
216- jump = self ._create_jump (block .target .label )
217- if inverted is None or jump is None :
189+ assert jump .target .label
190+ assert branch .target .label
191+ inverted = self ._invert_branch (
192+ branch .instructions [- 1 ], jump .target .label
193+ )
194+ if inverted is None :
218195 continue
219- jump_block = self ._new_block ()
220- jump_block .instructions .append (jump )
221- jump_block .target = block .target
222- jump_block .fallthrough = False
223- block .instructions [- 1 ] = inverted
224- block .target = label_block
225- block .link , jump_block .link , label_block .link = (
226- jump_block ,
227- label_block ,
228- block .link ,
196+ branch .instructions [- 1 ] = inverted
197+ jump .instructions [- 1 ] = self ._update_jump (
198+ jump .instructions [- 1 ], branch .target .label
229199 )
230-
231- def _thread_jumps (self ) -> None :
232- for block in self ._blocks ():
233- while block .target :
234- target = block .target .resolve ()
235- if (
236- not target .fallthrough
237- and target .target
238- and len (target .instructions ) == 1
239- ):
240- jump = block .instructions [- 1 ]
241- block .instructions [- 1 ] = self ._thread_jump (
242- jump , target .target .label
243- )
244- block .target = target .target
245- else :
246- break
247-
248- def _remove_dead_code (self ) -> None :
249- reachable = set ()
250- todo = [self ._graph ]
251- while todo :
252- block = todo .pop ()
253- reachable .add (block )
254- if block .target and block .target not in reachable :
255- todo .append (block .target )
256- if block .fallthrough and block .link and block .link not in reachable :
257- todo .append (block .link )
258- for block in self ._blocks ():
259- if block not in reachable :
260- block .target = None
261- block .fallthrough = True
262- block .instructions .clear ()
200+ branch .target , jump .target = jump .target , branch .target
201+ jump .hot = True
263202
264203 def _remove_redundant_jumps (self ) -> None :
265204 for block in self ._blocks ():
@@ -272,26 +211,12 @@ def _remove_redundant_jumps(self) -> None:
272211 block .fallthrough = True
273212 block .instructions .pop ()
274213
275- def _remove_unused_labels (self ) -> None :
276- used = set ()
277- for block in self ._blocks ():
278- if block .target :
279- used .add (block .target )
280- for block in self ._blocks ():
281- if block not in used and block .label .startswith (
282- f"{ self .prefix } _JIT_LABEL_"
283- ):
284- del block .noise [0 ]
285-
286214 def run (self ) -> None :
287215 self ._insert_continue_label ()
288216 self ._mark_hot_blocks ()
289217 self ._invert_hot_branches ()
290- self ._thread_jumps ()
291- # self._remove_dead_code() # XXX: Need calls to reason about this!
292218 self ._remove_redundant_jumps ()
293- self ._remove_unused_labels ()
294- self .path .write_text ("\n " .join (self ._lines ()))
219+ self .path .write_text (self ._body ())
295220
296221
297222class OptimizerAArch64 (Optimizer ):
@@ -308,15 +233,12 @@ class OptimizerX86(Optimizer):
308233 _re_jump = re .compile (r"\s*jmp\s+(?P<target>[\w.]+)" )
309234 _re_return = re .compile (r"\s*ret\b" )
310235
311- @staticmethod
312- def _create_jump (target : str ) -> str :
313- return f"\t jmp\t { target } "
314-
315236
316237class OptimizerX86Windows (OptimizerX86 ):
317238
318239 def _preprocess (self , text : str ) -> str :
319240 text = super ()._preprocess (text )
241+ # rex64 jumpq *__imp__JIT_CONTINUE(%rip) -> jmp _JIT_CONTINUE
320242 far_indirect_jump = (
321243 rf"rex64\s+jmpq\s+\*__imp_(?P<target>{ self .prefix } _JIT_\w+)\(%rip\)"
322244 )
0 commit comments