1- import collections
21import dataclasses
32import pathlib
43import re
@@ -44,12 +43,6 @@ class _Block:
4443 fallthrough : bool = True
4544 hot : bool = False
4645
47- def __eq__ (self , other : object ) -> bool :
48- return self is other
49-
50- def __hash__ (self ) -> int :
51- return super ().__hash__ ()
52-
5346 def resolve (self ) -> typing .Self :
5447 while self .link and not self .instructions :
5548 self = self .link
@@ -71,7 +64,7 @@ class Optimizer:
7164 )
7265 _re_jump : typing .ClassVar [re .Pattern [str ]] = _RE_NEVER_MATCH # One group: target.
7366 _re_label : typing .ClassVar [re .Pattern [str ]] = re .compile (
74- r" \s*(?P<label>[\w.]+):" # One group: label.
67+ r' \s*(?P<label>[\w."$?@ ]+):' # One group: label.
7568 )
7669 _re_noise : typing .ClassVar [re .Pattern [str ]] = re .compile (
7770 r"\s*(?:[#.]|$)" # No groups.
@@ -104,6 +97,10 @@ def __post_init__(self) -> None:
10497 assert not block .target
10598 block .fallthrough = False
10699
100+ @staticmethod
101+ def _preprocess (text : str ) -> str :
102+ return text
103+
107104 @classmethod
108105 def _invert_branch (cls , line : str , target : str ) -> str | None :
109106 match = cls ._re_branch .match (line )
@@ -126,10 +123,6 @@ def _lookup_label(self, label: str) -> _Block:
126123 self ._labels [label ] = _Block (label )
127124 return self ._labels [label ]
128125
129- @staticmethod
130- def _preprocess (text : str ) -> str :
131- return text
132-
133126 def _blocks (self ) -> typing .Generator [_Block , None , None ]:
134127 block = self ._graph
135128 while block :
@@ -143,6 +136,11 @@ def _body(self) -> str:
143136 lines .extend (block .instructions )
144137 return "\n " .join (lines )
145138
139+ def _predecessors (self , block : _Block ) -> typing .Generator [_Block , None , None ]:
140+ for block in self ._blocks ():
141+ if block .target is block or (block .fallthrough and block .link is block ):
142+ yield block
143+
146144 def _insert_continue_label (self ) -> None :
147145 for end in reversed (list (self ._blocks ())):
148146 if end .instructions :
@@ -155,36 +153,39 @@ def _insert_continue_label(self) -> None:
155153 end .link , align .link , continuation .link = align , continuation , end .link
156154
157155 def _mark_hot_blocks (self ) -> None :
158- predecessors = collections .defaultdict (list )
159- for block in self ._blocks ():
160- if block .target :
161- predecessors [block .target ].append (block )
162- if block .fallthrough and block .link :
163- predecessors [block .link ].append (block )
164156 todo = [self ._lookup_label (f"{ self .prefix } _JIT_CONTINUE" )]
165157 while todo :
166158 block = todo .pop ()
167159 block .hot = True
168160 todo .extend (
169161 predecessor
170- for predecessor in predecessors [ block ]
162+ for predecessor in self . _predecessors ( block )
171163 if not predecessor .hot
172164 )
173165
174166 def _invert_hot_branches (self ) -> None :
167+ # Turn:
168+ # branch <hot>
169+ # jump <cold>
170+ # Into:
171+ # opposite-branch <cold>
172+ # jump <hot>
175173 for branch in self ._blocks ():
176- jump = branch .link
174+ link = branch .link
175+ if link is None :
176+ continue
177+ jump = link .resolve ()
177178 if (
178179 # block ends with a branch to hot code...
179180 branch .target
180181 and branch .fallthrough
181182 and branch .target .hot
182- # ...followed by a jump to cold code:
183- and jump
183+ # ...followed by a jump to cold code with no other predecessors:
184184 and jump .target
185185 and not jump .fallthrough
186186 and not jump .target .hot
187187 and len (jump .instructions ) == 1
188+ and list (self ._predecessors (jump )) == [branch ]
188189 ):
189190 assert jump .target .label
190191 assert branch .target .label
0 commit comments