@@ -543,6 +543,153 @@ def visitModule(self, mod):
543543 self .emit ("}" , depth )
544544
545545
546+ class VisitorStructsDefVisitor (StructVisitor ):
547+ def visitModule (self , mod , depth ):
548+ for dfn in mod .dfns :
549+ self .visit (dfn , depth )
550+
551+ def visitProduct (self , product , name , depth ):
552+ pass
553+
554+ def visitSum (self , sum , name , depth ):
555+ if not is_simple (sum ):
556+ typeinfo = self .typeinfo [name ]
557+ if not sum .attributes :
558+ return
559+ for t in sum .types :
560+ typename = t .name + "Node"
561+
562+ has_userdata = any (
563+ getattr (self .typeinfo .get (f .type ), "has_userdata" , False )
564+ for f in t .fields
565+ )
566+ self .emit (
567+ f"pub struct { typename } Data<{ 'U=()' if has_userdata else '' } > {{" ,
568+ depth ,
569+ )
570+ for f in t .fields :
571+ self .visit (f , typeinfo , "pub " , depth + 1 , t .name )
572+ self .emit ("}" , depth )
573+ self .emit (
574+ f"pub type { typename } <U = ()> = Located<{ typename } Data<{ 'U' if has_userdata else '' } >, U>;" ,
575+ depth ,
576+ )
577+ self .emit ("" , depth )
578+
579+
580+ class VisitorTraitDefVisitor (StructVisitor ):
581+ def visitModule (self , mod , depth ):
582+ self .emit ("pub trait Visitor<U=()> {" , depth )
583+ for dfn in mod .dfns :
584+ self .visit (dfn , depth + 1 )
585+ self .emit ("}" , depth )
586+
587+ def visitType (self , type , depth = 0 ):
588+ self .visit (type .value , type .name , depth )
589+
590+ def emit_visitor (self , nodename , rusttype , depth ):
591+ self .emit (f"fn visit_{ nodename } (&mut self, node: { rusttype } ) {{" , depth )
592+ self .emit (f"self.generic_visit_{ nodename } (node);" , depth + 1 )
593+ self .emit ("}" , depth )
594+
595+ def emit_generic_visitor_signature (self , nodename , rusttype , depth ):
596+ self .emit (f"fn generic_visit_{ nodename } (&mut self, node: { rusttype } ) {{" , depth )
597+
598+ def emit_empty_generic_visitor (self , nodename , rusttype , depth ):
599+ self .emit_generic_visitor_signature (nodename , rusttype , depth )
600+ self .emit ("}" , depth )
601+
602+ def simple_sum (self , sum , name , depth ):
603+ rustname = get_rust_type (name )
604+ self .emit_visitor (name , rustname , depth )
605+ self .emit_empty_generic_visitor (name , rustname , depth )
606+
607+ def visit_match_for_type (self , enumname , type_ , depth ):
608+ self .emit (f"{ enumname } ::{ type_ .name } {{" , depth )
609+ for field in type_ .fields :
610+ self .emit (f"{ rust_field (field .name )} ," , depth + 1 )
611+ self .emit (f"}} => self.visit_{ type_ .name } (" , depth )
612+ self .emit (f"{ type_ .name } Node {{" , depth + 2 )
613+ self .emit ("location: node.location," , depth + 2 )
614+ self .emit ("end_location: node.end_location," , depth + 2 )
615+ self .emit ("custom: node.custom," , depth + 2 )
616+ self .emit (f"node: { type_ .name } NodeData {{" , depth + 2 )
617+ for field in type_ .fields :
618+ self .emit (f"{ rust_field (field .name )} ," , depth + 3 )
619+ self .emit ("}," , depth + 2 )
620+ self .emit ("}" , depth + 1 )
621+ self .emit (")," , depth )
622+
623+ def visit_sumtype (self , type_ , depth ):
624+ rustname = get_rust_type (type_ .name ) + "Node"
625+ self .emit_visitor (type_ .name , rustname , depth )
626+ self .emit_generic_visitor_signature (type_ .name , rustname , depth )
627+ for f in type_ .fields :
628+ fieldname = rust_field (f .name )
629+ fieldtype = self .typeinfo .get (f .type )
630+ if not (fieldtype and fieldtype .has_userdata ):
631+ continue
632+
633+ if f .opt :
634+ self .emit (f"if let Some(value) = node.node.{ fieldname } {{" , depth + 1 )
635+ elif f .seq :
636+ iterable = f"node.node.{ fieldname } "
637+ if type_ .name == "Dict" and f .name == "keys" :
638+ iterable = f"{ iterable } .into_iter().flatten()"
639+ self .emit (f"for value in { iterable } {{" , depth + 1 )
640+ else :
641+ self .emit ("{" , depth + 1 )
642+ self .emit (f"let value = node.node.{ fieldname } ;" , depth + 2 )
643+
644+ variable = "value"
645+ if fieldtype .boxed and (not f .seq or f .opt ):
646+ variable = "*" + variable
647+ self .emit (f"self.visit_{ fieldtype .name } ({ variable } );" , depth + 2 )
648+
649+ self .emit ("}" , depth + 1 )
650+
651+ self .emit ("}" , depth )
652+
653+ def sum_with_constructors (self , sum , name , depth ):
654+ if not sum .attributes :
655+ return
656+
657+ rustname = enumname = get_rust_type (name )
658+ if sum .attributes :
659+ enumname += "Kind"
660+ self .emit_visitor (name , rustname , depth )
661+ self .emit_generic_visitor_signature (name , rustname , depth )
662+ depth += 1
663+ self .emit ("match node.node {" , depth )
664+ for t in sum .types :
665+ self .visit_match_for_type (enumname , t , depth + 1 )
666+ self .emit ("}" , depth )
667+ depth -= 1
668+ self .emit ("}" , depth )
669+
670+ # Now for the visitors for the types
671+ for t in sum .types :
672+ self .visit_sumtype (t , depth )
673+
674+ def visitProduct (self , product , name , depth ):
675+ rusttype = get_rust_type (name )
676+ self .emit_visitor (name , rusttype , depth )
677+ self .emit_empty_generic_visitor (name , rusttype , depth )
678+
679+
680+ class VisitorModuleVisitor (EmitVisitor ):
681+ def visitModule (self , mod ):
682+ depth = 0
683+ self .emit ('#[cfg(feature = "visitor")]' , depth )
684+ self .emit ("#[allow(unused_variables, non_snake_case)]" , depth )
685+ self .emit ("pub mod visitor {" , depth )
686+ self .emit ("use super::*;" , depth + 1 )
687+ VisitorStructsDefVisitor (self .file , self .typeinfo ).visit (mod , depth + 1 )
688+ VisitorTraitDefVisitor (self .file , self .typeinfo ).visit (mod , depth + 1 )
689+ self .emit ("}" , depth )
690+ self .emit ("" , depth )
691+
692+
546693class ClassDefVisitor (EmitVisitor ):
547694 def visitModule (self , mod ):
548695 for dfn in mod .dfns :
@@ -811,7 +958,11 @@ def write_generic_def(mod, typeinfo, f):
811958 )
812959 )
813960
814- c = ChainOfVisitors (StructVisitor (f , typeinfo ), FoldModuleVisitor (f , typeinfo ))
961+ c = ChainOfVisitors (
962+ StructVisitor (f , typeinfo ),
963+ FoldModuleVisitor (f , typeinfo ),
964+ VisitorModuleVisitor (f , typeinfo ),
965+ )
815966 c .visit (mod )
816967
817968
0 commit comments