Skip to content

Commit 4de9580

Browse files
thejcannonyouknowone
authored andcommitted
Generate a visitor trait to ast_gen.rs
1 parent 5cf85f0 commit 4de9580

File tree

2 files changed

+153
-1
lines changed

2 files changed

+153
-1
lines changed

ast/Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ constant-optimization = ["fold"]
1313
source-code = ["fold"]
1414
fold = []
1515
unparse = ["rustpython-literal"]
16+
visitor = []
1617

1718
[dependencies]
1819
rustpython-parser-core = { workspace = true }

ast/asdl_rs.py

Lines changed: 152 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -543,6 +543,153 @@ def visitModule(self, mod):
543543
self.emit("}", depth)
544544

545545

546+
class VisitorStructsDefVisitor(StructVisitor):
547+
def visitModule(self, mod, depth):
548+
for dfn in mod.dfns:
549+
self.visit(dfn, depth)
550+
551+
def visitProduct(self, product, name, depth):
552+
pass
553+
554+
def visitSum(self, sum, name, depth):
555+
if not is_simple(sum):
556+
typeinfo = self.typeinfo[name]
557+
if not sum.attributes:
558+
return
559+
for t in sum.types:
560+
typename = t.name + "Node"
561+
562+
has_userdata = any(
563+
getattr(self.typeinfo.get(f.type), "has_userdata", False)
564+
for f in t.fields
565+
)
566+
self.emit(
567+
f"pub struct {typename}Data<{'U=()' if has_userdata else ''}> {{",
568+
depth,
569+
)
570+
for f in t.fields:
571+
self.visit(f, typeinfo, "pub ", depth + 1, t.name)
572+
self.emit("}", depth)
573+
self.emit(
574+
f"pub type {typename}<U = ()> = Located<{typename}Data<{'U' if has_userdata else ''}>, U>;",
575+
depth,
576+
)
577+
self.emit("", depth)
578+
579+
580+
class VisitorTraitDefVisitor(StructVisitor):
581+
def visitModule(self, mod, depth):
582+
self.emit("pub trait Visitor<U=()> {", depth)
583+
for dfn in mod.dfns:
584+
self.visit(dfn, depth + 1)
585+
self.emit("}", depth)
586+
587+
def visitType(self, type, depth=0):
588+
self.visit(type.value, type.name, depth)
589+
590+
def emit_visitor(self, nodename, rusttype, depth):
591+
self.emit(f"fn visit_{nodename}(&mut self, node: {rusttype}) {{", depth)
592+
self.emit(f"self.generic_visit_{nodename}(node);", depth + 1)
593+
self.emit("}", depth)
594+
595+
def emit_generic_visitor_signature(self, nodename, rusttype, depth):
596+
self.emit(f"fn generic_visit_{nodename}(&mut self, node: {rusttype}) {{", depth)
597+
598+
def emit_empty_generic_visitor(self, nodename, rusttype, depth):
599+
self.emit_generic_visitor_signature(nodename, rusttype, depth)
600+
self.emit("}", depth)
601+
602+
def simple_sum(self, sum, name, depth):
603+
rustname = get_rust_type(name)
604+
self.emit_visitor(name, rustname, depth)
605+
self.emit_empty_generic_visitor(name, rustname, depth)
606+
607+
def visit_match_for_type(self, enumname, type_, depth):
608+
self.emit(f"{enumname}::{type_.name} {{", depth)
609+
for field in type_.fields:
610+
self.emit(f"{rust_field(field.name)},", depth + 1)
611+
self.emit(f"}} => self.visit_{type_.name}(", depth)
612+
self.emit(f"{type_.name}Node {{", depth + 2)
613+
self.emit("location: node.location,", depth + 2)
614+
self.emit("end_location: node.end_location,", depth + 2)
615+
self.emit("custom: node.custom,", depth + 2)
616+
self.emit(f"node: {type_.name}NodeData {{", depth + 2)
617+
for field in type_.fields:
618+
self.emit(f"{rust_field(field.name)},", depth + 3)
619+
self.emit("},", depth + 2)
620+
self.emit("}", depth + 1)
621+
self.emit("),", depth)
622+
623+
def visit_sumtype(self, type_, depth):
624+
rustname = get_rust_type(type_.name) + "Node"
625+
self.emit_visitor(type_.name, rustname, depth)
626+
self.emit_generic_visitor_signature(type_.name, rustname, depth)
627+
for f in type_.fields:
628+
fieldname = rust_field(f.name)
629+
fieldtype = self.typeinfo.get(f.type)
630+
if not (fieldtype and fieldtype.has_userdata):
631+
continue
632+
633+
if f.opt:
634+
self.emit(f"if let Some(value) = node.node.{fieldname} {{", depth + 1)
635+
elif f.seq:
636+
iterable = f"node.node.{fieldname}"
637+
if type_.name == "Dict" and f.name == "keys":
638+
iterable = f"{iterable}.into_iter().flatten()"
639+
self.emit(f"for value in {iterable} {{", depth + 1)
640+
else:
641+
self.emit("{", depth + 1)
642+
self.emit(f"let value = node.node.{fieldname};", depth + 2)
643+
644+
variable = "value"
645+
if fieldtype.boxed and (not f.seq or f.opt):
646+
variable = "*" + variable
647+
self.emit(f"self.visit_{fieldtype.name}({variable});", depth + 2)
648+
649+
self.emit("}", depth + 1)
650+
651+
self.emit("}", depth)
652+
653+
def sum_with_constructors(self, sum, name, depth):
654+
if not sum.attributes:
655+
return
656+
657+
rustname = enumname = get_rust_type(name)
658+
if sum.attributes:
659+
enumname += "Kind"
660+
self.emit_visitor(name, rustname, depth)
661+
self.emit_generic_visitor_signature(name, rustname, depth)
662+
depth += 1
663+
self.emit("match node.node {", depth)
664+
for t in sum.types:
665+
self.visit_match_for_type(enumname, t, depth + 1)
666+
self.emit("}", depth)
667+
depth -= 1
668+
self.emit("}", depth)
669+
670+
# Now for the visitors for the types
671+
for t in sum.types:
672+
self.visit_sumtype(t, depth)
673+
674+
def visitProduct(self, product, name, depth):
675+
rusttype = get_rust_type(name)
676+
self.emit_visitor(name, rusttype, depth)
677+
self.emit_empty_generic_visitor(name, rusttype, depth)
678+
679+
680+
class VisitorModuleVisitor(EmitVisitor):
681+
def visitModule(self, mod):
682+
depth = 0
683+
self.emit('#[cfg(feature = "visitor")]', depth)
684+
self.emit("#[allow(unused_variables, non_snake_case)]", depth)
685+
self.emit("pub mod visitor {", depth)
686+
self.emit("use super::*;", depth + 1)
687+
VisitorStructsDefVisitor(self.file, self.typeinfo).visit(mod, depth + 1)
688+
VisitorTraitDefVisitor(self.file, self.typeinfo).visit(mod, depth + 1)
689+
self.emit("}", depth)
690+
self.emit("", depth)
691+
692+
546693
class ClassDefVisitor(EmitVisitor):
547694
def visitModule(self, mod):
548695
for dfn in mod.dfns:
@@ -811,7 +958,11 @@ def write_generic_def(mod, typeinfo, f):
811958
)
812959
)
813960

814-
c = ChainOfVisitors(StructVisitor(f, typeinfo), FoldModuleVisitor(f, typeinfo))
961+
c = ChainOfVisitors(
962+
StructVisitor(f, typeinfo),
963+
FoldModuleVisitor(f, typeinfo),
964+
VisitorModuleVisitor(f, typeinfo),
965+
)
815966
c.visit(mod)
816967

817968

0 commit comments

Comments
 (0)