44import re
55import typing
66
7- branches = {}
8- for op , nop in [
9- ("ja" , "jna" ),
10- ("jae" , "jnae" ),
11- ("jb" , "jnb" ),
12- ("jbe" , "jnbe" ),
13- ("jc" , "jnc" ),
14- ("jcxz" , None ),
15- ("je" , "jne" ),
16- ("jecxz" , None ),
17- ("jg" , "jng" ),
18- ("jge" , "jnge" ),
19- ("jl" , "jnl" ),
20- ("jle" , "jnle" ),
21- ("jo" , "jno" ),
22- ("jp" , "jnp" ),
23- ("jpe" , "jpo" ),
24- ("jrxz" , None ),
25- ("js" , "jns" ),
26- ("jz" , "jnz" ),
27- ("loop" , None ),
28- ("loope" , None ),
29- ("loopne" , None ),
30- ("loopnz" , None ),
31- ("loopz" , None ),
32- ]:
33- branches [op ] = nop
34- if nop :
35- branches [nop ] = op
36-
37-
38- def _get_branch (line : str ) -> str | None :
39- branch = re .match (rf"\s*({ '|' .join (branches )} )\s+([\w\.]+)" , line )
40- return branch and branch [2 ]
41-
42-
43- def _invert_branch (line : str , label : str ) -> str | None :
44- branch = re .match (rf"\s*({ '|' .join (branches )} )\s+([\w\.]+)" , line )
45- assert branch
46- inverted = branches .get (branch [1 ])
47- if not inverted :
48- return None
49- line = line .replace (branch [1 ], inverted , 1 ) # XXX
50- line = line .replace (branch [2 ], label , 1 ) # XXX
51- return line
52-
53-
54- def _get_jump (line : str ) -> str | None :
55- jump = re .match (r"\s*(?:rex64\s+)?jmpq?\s+\*?([\w\.]+)" , line )
56- return jump and jump [1 ]
57-
58-
59- def _get_label (line : str ) -> str | None :
60- label = re .match (r"\s*([\w\.]+):" , line )
61- return label and label [1 ]
62-
63-
64- def _is_return (line : str ) -> bool :
65- return re .match (r"\s*ret\s+" , line ) is not None
66-
67-
68- def _is_noise (line : str ) -> bool :
69- return re .match (r"\s*([#\.]|$)" , line ) is not None
7+ _RE_NEVER_MATCH = re .compile (r"(?!)" )
8+
9+ _X86_BRANCHES = {
10+ "ja" : "jna" ,
11+ "jae" : "jnae" ,
12+ "jb" : "jnb" ,
13+ "jbe" : "jnbe" ,
14+ "jc" : "jnc" ,
15+ "jcxz" : None ,
16+ "je" : "jne" ,
17+ "jecxz" : None ,
18+ "jg" : "jng" ,
19+ "jge" : "jnge" ,
20+ "jl" : "jnl" ,
21+ "jle" : "jnle" ,
22+ "jo" : "jno" ,
23+ "jp" : "jnp" ,
24+ "jpe" : "jpo" ,
25+ "jrxz" : None ,
26+ "js" : "jns" ,
27+ "jz" : "jnz" ,
28+ "loop" : None ,
29+ "loope" : None ,
30+ "loopne" : None ,
31+ "loopnz" : None ,
32+ "loopz" : None ,
33+ }
34+ _X86_BRANCHES |= {v : k for k , v in _X86_BRANCHES .items () if v }
7035
7136
7237@dataclasses .dataclass
@@ -91,57 +56,111 @@ def resolve(self) -> typing.Self:
9156 return self
9257
9358
94- class _Labels (dict ):
95- def __missing__ (self , key : str ) -> _Block :
96- self [key ] = _Block (key )
97- return self [key ]
98-
99-
10059@dataclasses .dataclass
10160class Optimizer :
10261
10362 path : pathlib .Path
10463 _ : dataclasses .KW_ONLY
10564 prefix : str = ""
10665 _graph : _Block = dataclasses .field (init = False )
107- _labels : _Labels = dataclasses .field (init = False , default_factory = _Labels )
108-
109- _re_branch : typing .ClassVar [re .Pattern [str ]] # Two groups: instruction and target.
110- _re_jump : typing .ClassVar [re .Pattern [str ]] # One group: target.
111- _re_return : typing .ClassVar [re .Pattern [str ]] # No groups.
66+ _labels : dict = dataclasses .field (init = False , default_factory = dict )
67+ _alignment : typing .ClassVar [int ] = 1
68+ _branches : typing .ClassVar [dict [str , str | None ]] = {}
69+ _re_branch : typing .ClassVar [re .Pattern [str ]] = (
70+ _RE_NEVER_MATCH # Two groups: instruction and target.
71+ )
72+ _re_jump : typing .ClassVar [re .Pattern [str ]] = _RE_NEVER_MATCH # One group: target.
73+ _re_label : typing .ClassVar [re .Pattern [str ]] = re .compile (
74+ r"\s*(?P<label>[\w\.]+):"
75+ ) # One group: label.
76+ _re_noise : typing .ClassVar [re .Pattern [str ]] = re .compile (
77+ r"\s*(?:[#\.]|$)"
78+ ) # No groups.
79+ _re_return : typing .ClassVar [re .Pattern [str ]] = _RE_NEVER_MATCH # No groups.
11280
11381 def __post_init__ (self ) -> None :
11482 text = self ._preprocess (self .path .read_text ())
11583 self ._graph = block = self ._new_block ()
11684 for line in text .splitlines ():
117- if label := _get_label (line ):
118- block .link = block = self ._labels [label ]
85+ if label := self ._parse_label (line ):
86+ block .link = block = self ._lookup_label (label )
87+ block .noise .append (line )
88+ continue
11989 elif block .target or not block .fallthrough :
12090 block .link = block = self ._new_block ()
121- if _is_noise ( line ) or _get_label (line ):
91+ if self . _parse_noise (line ):
12292 if block .instructions :
12393 block .link = block = self ._new_block ()
12494 block .noise .append (line )
12595 continue
12696 block .instructions .append (line )
127- if target := _get_branch (line ):
128- block .target = self ._labels [ target ]
97+ if target := self . _parse_branch (line ):
98+ block .target = self ._lookup_label ( target )
12999 assert block .fallthrough
130- elif target := _get_jump (line ):
131- block .target = self ._labels [ target ]
100+ elif target := self . _parse_jump (line ):
101+ block .target = self ._lookup_label ( target )
132102 block .fallthrough = False
133- elif _is_return (line ):
103+ elif self . _parse_return (line ):
134104 assert not block .target
135105 block .fallthrough = False
136106
137- def _new_block (self , label : str | None = None ) -> _Block :
138- if not label :
139- label = f"{ self .prefix } _JIT_LABEL_{ len (self ._labels )} "
140- assert label not in self ._labels , label
141- block = self ._labels [label ] = _Block (label , [f"{ label } :" ])
107+ @classmethod
108+ def _parse_branch (cls , line : str ) -> str | None :
109+ branch = cls ._re_branch .match (line )
110+ return branch and branch ["target" ]
111+
112+ @classmethod
113+ def _parse_jump (cls , line : str ) -> str | None :
114+ jump = cls ._re_jump .match (line )
115+ return jump and jump ["target" ]
116+
117+ @classmethod
118+ def _parse_label (cls , line : str ) -> str | None :
119+ label = cls ._re_label .match (line )
120+ return label and label ["label" ]
121+
122+ @classmethod
123+ def _parse_noise (cls , line : str ) -> bool :
124+ return cls ._re_noise .match (line ) is not None
125+
126+ @classmethod
127+ def _parse_return (cls , line : str ) -> bool :
128+ return cls ._re_return .match (line ) is not None
129+
130+ @classmethod
131+ def _invert_branch (cls , line : str , target : str ) -> str | None :
132+ branch = cls ._re_branch .match (line )
133+ assert branch
134+ inverted = cls ._branches .get (branch ["instruction" ])
135+ if not inverted :
136+ return None
137+ (a , b ), (c , d ) = branch .span ("instruction" ), branch .span ("target" )
138+ return "" .join ([line [:a ], inverted , line [b :c ], target , line [d :]])
139+
140+ @classmethod
141+ def _thread_jump (cls , line : str , target : str ) -> str :
142+ jump = cls ._re_jump .match (line ) or cls ._re_branch .match (line )
143+ assert jump
144+ a , b = jump .span ("target" )
145+ return "" .join ([line [:a ], target , line [b :]])
146+
147+ @staticmethod
148+ def _create_jump (target : str ) -> str | None :
149+ return None
150+
151+ def _lookup_label (self , label : str ) -> _Block :
152+ if label not in self ._labels :
153+ self ._labels [label ] = _Block (label )
154+ return self ._labels [label ]
155+
156+ def _new_block (self ) -> _Block :
157+ label = f"{ self .prefix } _JIT_LABEL_{ len (self ._labels )} "
158+ block = self ._lookup_label (label )
159+ block .noise .append (f"{ label } :" )
142160 return block
143161
144- def _preprocess (self , text : str ) -> str :
162+ @staticmethod
163+ def _preprocess (text : str ) -> str :
145164 return text
146165
147166 def _blocks (self ) -> typing .Generator [_Block , None , None ]:
@@ -159,9 +178,11 @@ def _insert_continue_label(self) -> None:
159178 for end in reversed (list (self ._blocks ())):
160179 if end .instructions :
161180 break
162- continuation = self ._labels [f"{ self .prefix } _JIT_CONTINUE" ]
181+ align = self ._new_block ()
182+ align .noise .append (f"\t .balign\t { self ._alignment } " )
183+ continuation = self ._lookup_label (f"{ self .prefix } _JIT_CONTINUE" )
163184 continuation .noise .append (f"{ continuation .label } :" )
164- end .link , continuation .link = continuation , end .link
185+ end .link , align . link , continuation .link = align , continuation , end .link
165186
166187 def _mark_hot_blocks (self ) -> None :
167188 predecessors = collections .defaultdict (set )
@@ -170,7 +191,7 @@ def _mark_hot_blocks(self) -> None:
170191 predecessors [block .target ].add (block )
171192 if block .fallthrough and block .link :
172193 predecessors [block .link ].add (block )
173- todo = [self ._labels [ f"{ self .prefix } _JIT_CONTINUE" ] ]
194+ todo = [self ._lookup_label ( f"{ self .prefix } _JIT_CONTINUE" ) ]
174195 while todo :
175196 block = todo .pop ()
176197 block .hot = True
@@ -185,43 +206,41 @@ def _invert_hot_branches(self) -> None:
185206 if (
186207 block .fallthrough
187208 and block .target
188- and block .link
189209 and block .target .hot
210+ and block .link
190211 and not block .link .hot
191212 ):
192- # Turn...
193- # branch hot
194- # ...into..
195- # opposite-branch ._JIT_LABEL_N
196- # jmp hot
197- # ._JIT_LABEL_N:
198213 label_block = self ._new_block ()
199- inverted = _invert_branch (block .instructions [- 1 ], label_block .label )
200- if inverted is None :
214+ branch = block .instructions [- 1 ]
215+ inverted = self ._invert_branch (branch , label_block .label )
216+ jump = self ._create_jump (block .target .label )
217+ if inverted is None or jump is None :
201218 continue
202219 jump_block = self ._new_block ()
203- jump_block .instructions .append (f" \t jmp \t { block . target . label } " )
220+ jump_block .instructions .append (jump )
204221 jump_block .target = block .target
205222 jump_block .fallthrough = False
206223 block .instructions [- 1 ] = inverted
207224 block .target = label_block
208- label_block .link = block .link
209- jump_block .link = label_block
210- block .link = jump_block
225+ block .link , jump_block .link , label_block .link = (
226+ jump_block ,
227+ label_block ,
228+ block .link ,
229+ )
211230
212231 def _thread_jumps (self ) -> None :
213232 for block in self ._blocks ():
214233 while block .target :
215- label = block .target .label
216234 target = block .target .resolve ()
217235 if (
218236 not target .fallthrough
219237 and target .target
220238 and len (target .instructions ) == 1
221239 ):
222- block .instructions [- 1 ] = block .instructions [- 1 ].replace (
223- label , target .target .label
224- ) # XXX
240+ jump = block .instructions [- 1 ]
241+ block .instructions [- 1 ] = self ._thread_jump (
242+ jump , target .target .label
243+ )
225244 block .target = target .target
226245 else :
227246 break
@@ -273,14 +292,24 @@ def run(self) -> None:
273292 self .path .write_text ("\n " .join (self ._lines ()))
274293
275294
295+ class OptimizerAArch64 (Optimizer ):
296+ # TODO: @diegorusso
297+ _alignment = 8
298+
299+
276300class OptimizerX86 (Optimizer ):
277301
302+ _branches = _X86_BRANCHES
278303 _re_branch = re .compile (
279- rf"\s*(?P<instruction>{ '|' .join (branches )} )\s+(?P<target>[\w\.]+)"
304+ rf"\s*(?P<instruction>{ '|' .join (_X86_BRANCHES )} )\s+(?P<target>[\w\.]+)"
280305 )
281306 _re_jump = re .compile (r"\s*jmp\s+(?P<target>[\w\.]+)" )
282307 _re_return = re .compile (r"\s*ret\b" )
283308
309+ @staticmethod
310+ def _create_jump (target : str ) -> str :
311+ return f"\t jmp\t { target } "
312+
284313
285314class OptimizerX86Windows (OptimizerX86 ):
286315
0 commit comments