Skip to content

Commit e52d3d0

Browse files
committed
use dataclasses in x86_ast
that improved typing information and removes a lot of boilerplate The semantics should remain completely the same, with the exception, that the classes are now immutable. It was expected anyway, given that they had a `__hash__` method.
1 parent d4d1bde commit e52d3d0

File tree

1 file changed

+65
-136
lines changed

1 file changed

+65
-136
lines changed

x86_ast.py

Lines changed: 65 additions & 136 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,19 @@
1+
from __future__ import annotations
2+
3+
import ast
4+
from dataclasses import dataclass
15
from typing import List
2-
from utils import label_name, indent, dedent, indent_stmt
36

7+
from utils import dedent, indent, indent_stmt, label_name
8+
9+
10+
@dataclass(frozen=True)
411
class X86Program:
5-
__match_args__ = ("body",)
6-
def __init__(self, body):
7-
self.body = body
12+
body: dict[str, list[instr]] | list[instr]
13+
814
def __str__(self):
915
result = ''
10-
if type(self.body) == dict:
16+
if isinstance(self.body, dict):
1117
for (l,ss) in self.body.items():
1218
if l == label_name('main'):
1319
result += '\t.globl ' + label_name('main') + '\n'
@@ -24,189 +30,112 @@ def __str__(self):
2430
dedent()
2531
result += '\n'
2632
return result
27-
def __repr__(self):
28-
return 'X86Program(' + repr(self.body) + ')'
2933

34+
@dataclass(frozen=True)
3035
class X86ProgramDefs:
31-
__match_args__ = ("defs",)
32-
def __init__(self, defs):
33-
self.defs = defs
36+
defs: list[ast.FunctionDef]
37+
3438
def __str__(self):
35-
return '\n'.join([str(d) for d in self.defs])
36-
def __repr__(self):
37-
return 'X86ProgramDefs(' + repr(self.defs) + ')'
38-
39+
return "\n".join([str(d) for d in self.defs])
40+
3941
class instr: ...
4042
class arg: ...
4143
class location(arg): ...
42-
44+
45+
@dataclass(frozen=True)
4346
class Instr(instr):
4447
instr: str
4548
args: List[arg]
46-
47-
__match_args__ = ("instr", "args")
48-
def __init__(self, instr, args):
49-
self.instr = instr
50-
self.args = args
49+
5150
def source(self):
5251
return self.args[0]
5352
def target(self):
5453
return self.args[-1]
5554
def __str__(self):
5655
return indent_stmt() + self.instr + ' ' + ', '.join(str(a) for a in self.args) + '\n'
57-
def __repr__(self):
58-
return 'Instr(' + repr(self.instr) + ', ' + repr(self.args) + ')'
59-
56+
57+
@dataclass(frozen=True)
6058
class Callq(instr):
61-
__match_args__ = ("func", "num_args")
62-
def __init__(self, func, num_args):
63-
self.func = func
64-
self.num_args = num_args
59+
func: str
60+
num_args: int
61+
6562
def __str__(self):
6663
return indent_stmt() + 'callq' + ' ' + self.func + '\n'
67-
def __repr__(self):
68-
return 'Callq(' + repr(self.func) + ', ' + repr(self.num_args) + ')'
6964

65+
@dataclass(frozen=True)
7066
class IndirectCallq(instr):
71-
__match_args__ = ("func", "num_args")
72-
def __init__(self, func, num_args):
73-
self.func = func
74-
self.num_args = num_args
67+
func: arg
68+
num_args: int
69+
7570
def __str__(self):
7671
return indent_stmt() + 'callq' + ' *' + str(self.func) + '\n'
77-
def __repr__(self):
78-
return 'IndirectCallq(' + repr(self.func) + ', ' + repr(self.num_args) + ')'
79-
72+
73+
@dataclass(frozen=True)
8074
class JumpIf(instr):
8175
cc: str
8276
label: str
83-
84-
__match_args__ = ("cc", "label")
85-
def __init__(self, cc, label):
86-
self.cc = cc
87-
self.label = label
77+
8878
def __str__(self):
8979
return indent_stmt() + 'j' + self.cc + ' ' + self.label + '\n'
90-
def __repr__(self):
91-
return 'JumpIf(' + repr(self.cc) + ', ' + repr(self.label) + ')'
9280

81+
@dataclass(frozen=True)
9382
class Jump(instr):
9483
label: str
95-
96-
__match_args__ = ("label",)
97-
def __init__(self, label):
98-
self.label = label
84+
9985
def __str__(self):
10086
return indent_stmt() + 'jmp ' + self.label + '\n'
101-
def __repr__(self):
102-
return 'Jump(' + repr(self.label) + ')'
10387

88+
@dataclass(frozen=True)
10489
class IndirectJump(instr):
105-
__match_args__ = ("target",)
106-
def __init__(self, target):
107-
self.target = target
90+
target: location
91+
10892
def __str__(self):
10993
return indent_stmt() + 'jmp *' + str(self.target) + '\n'
110-
def __repr__(self):
111-
return 'IndirectJump(' + repr(self.target) + ')'
112-
94+
95+
@dataclass(frozen=True)
11396
class TailJump(instr):
114-
__match_args__ = ("func","arity")
115-
def __init__(self, func, arity):
116-
self.func = func
117-
self.arity = arity
97+
func: arg
98+
arity: int
99+
118100
def __str__(self):
119101
return indent_stmt() + 'tailjmp ' + str(self.func) + '\n'
120-
def __repr__(self):
121-
return 'TailJump(' + repr(self.func) + ',' + repr(self.arity) + ')'
122-
102+
103+
@dataclass(frozen=True)
123104
class Variable(location):
124-
__match_args__ = ("id",)
125-
def __init__(self, id):
126-
self.id = id
105+
id: str
106+
127107
def __str__(self):
128108
return self.id
129-
def __repr__(self):
130-
return 'Variable(' + repr(self.id) + ')'
131-
def __eq__(self, other):
132-
if isinstance(other, Variable):
133-
return self.id == other.id
134-
else:
135-
return False
136-
def __hash__(self):
137-
return hash(self.id)
138109

110+
@dataclass(frozen=True)
139111
class Immediate(arg):
140-
__match_args__ = ("value",)
141-
def __init__(self, value):
142-
self.value = value
112+
value: int
113+
143114
def __str__(self):
144115
return '$' + str(self.value)
145-
def __repr__(self):
146-
return 'Immediate(' + repr(self.value) + ')'
147-
def __eq__(self, other):
148-
if isinstance(other, Immediate):
149-
return self.value == other.value
150-
else:
151-
return False
152-
def __hash__(self):
153-
return hash(self.value)
154-
116+
117+
@dataclass(frozen=True)
155118
class Reg(location):
156-
__match_args__ = ("id",)
157-
def __init__(self, id):
158-
self.id = id
159-
def __str__(self):
160-
return '%' + self.id
161-
def __repr__(self):
162-
return 'Reg(' + repr(self.id) + ')'
163-
def __eq__(self, other):
164-
if isinstance(other, Reg):
165-
return self.id == other.id
166-
else:
167-
return False
168-
def __hash__(self):
169-
return hash(self.id)
170-
171-
class ByteReg(arg):
172-
__match_args__ = ("id",)
173-
def __init__(self, id):
174-
self.id = id
119+
id: str
120+
175121
def __str__(self):
176122
return '%' + self.id
177-
def __repr__(self):
178-
return 'ByteReg(' + repr(self.id) + ')'
179-
def __eq__(self, other):
180-
if isinstance(other, ByteReg):
181-
return self.id == other.id
182-
else:
183-
return False
184-
def __hash__(self):
185-
return hash(self.id)
186-
123+
124+
@dataclass(frozen=True)
125+
class ByteReg(Reg):
126+
pass
127+
128+
@dataclass(frozen=True)
187129
class Deref(arg):
188-
__match_args__ = ("reg", "offset")
189-
def __init__(self, reg, offset):
190-
self.reg = reg
191-
self.offset = offset
130+
reg: str
131+
offset: int
132+
192133
def __str__(self):
193134
return str(self.offset) + '(%' + self.reg + ')'
194-
def __repr__(self):
195-
return 'Deref(' + repr(self.reg) + ', ' + repr(self.offset) + ')'
196-
def __eq__(self, other):
197-
if isinstance(other, Deref):
198-
return self.reg == other.reg and self.offset == other.offset
199-
else:
200-
return False
201-
def __hash__(self):
202-
return hash((self.reg, self.offset))
203135

136+
@dataclass(frozen=True)
204137
class Global(arg):
205-
__match_args__ = ("name",)
206-
def __init__(self, name):
207-
self.name = name
138+
name: str
139+
208140
def __str__(self):
209-
return str(self.name) + "(%rip)"
210-
def __repr__(self):
211-
return 'Global(' + repr(self.name) + ')'
212-
141+
return self.name + "(%rip)"

0 commit comments

Comments
 (0)