Skip to content

Commit 02f13ab

Browse files
authored
Add Node trait for node type information (#31)
1 parent 10dda12 commit 02f13ab

File tree

4 files changed

+430
-2
lines changed

4 files changed

+430
-2
lines changed

ast/asdl_rs.py

Lines changed: 35 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -272,6 +272,17 @@ def visitSum(self, sum, name, depth):
272272
else:
273273
self.sum_with_constructors(sum, name, depth)
274274

275+
(generics_applied,) = self.apply_generics(name, "R")
276+
self.emit(
277+
f"""
278+
impl{generics_applied} Node for {rust_type_name(name)}{generics_applied} {{
279+
const NAME: &'static str = "{name}";
280+
const FIELD_NAMES: &'static [&'static str] = &[];
281+
}}
282+
""",
283+
depth,
284+
)
285+
275286
def emit_attrs(self, depth):
276287
self.emit("#[derive(Clone, Debug, PartialEq)]", depth)
277288

@@ -327,9 +338,14 @@ def sum_subtype_struct(self, sum_type_info, t, rust_name, depth):
327338
)
328339

329340
self.emit("}", depth)
341+
field_names = [f'"{f.name}"' for f in t.fields]
330342
self.emit(
331343
textwrap.dedent(
332344
f"""
345+
impl<R> Node for {payload_name}<R> {{
346+
const NAME: &'static str = "{t.name}";
347+
const FIELD_NAMES: &'static [&'static str] = &[{', '.join(field_names)}];
348+
}}
333349
impl<R> From<{payload_name}<R>> for {rust_name}<R> {{
334350
fn from(payload: {payload_name}<R>) -> Self {{
335351
{rust_name}::{t.name}(payload)
@@ -389,7 +405,19 @@ def visitProduct(self, product, name, depth):
389405
assert bool(product.attributes) == type_info.no_cfg(self.type_info)
390406
self.emit_range(product.attributes, depth + 1)
391407
self.emit("}", depth)
392-
self.emit("", depth)
408+
409+
field_names = [f'"{f.name}"' for f in product.fields]
410+
self.emit(
411+
f"""
412+
impl<R> Node for {product_name}<R> {{
413+
const NAME: &'static str = "{name}";
414+
const FIELD_NAMES: &'static [&'static str] = &[
415+
{', '.join(field_names)}
416+
];
417+
}}
418+
""",
419+
depth,
420+
)
393421

394422

395423
class FoldTraitDefVisitor(EmitVisitor):
@@ -575,6 +603,12 @@ def visitModule(self, mod, depth):
575603
def visitType(self, type, depth=0):
576604
self.visit(type.value, type.name, depth)
577605

606+
def visitSum(self, sum, name, depth):
607+
if is_simple(sum):
608+
self.simple_sum(sum, name, depth)
609+
else:
610+
self.sum_with_constructors(sum, name, depth)
611+
578612
def emit_visitor(self, nodename, depth, has_node=True):
579613
type_info = self.type_info[nodename]
580614
node_type = type_info.rust_sum_name

0 commit comments

Comments
 (0)