1212import asdl
1313
1414TABSIZE = 4
15- AUTOGEN_MESSAGE = "// File automatically generated by {}.\n "
15+ AUTOGEN_MESSAGE = "// File automatically generated by {}.\n \n "
1616
1717builtin_type_mapping = {
1818 "identifier" : "Ident" ,
@@ -68,6 +68,7 @@ class TypeInfo:
6868 enum_name : Optional [str ]
6969 has_userdata : Optional [bool ]
7070 has_attributes : bool
71+ empty_field : bool
7172 children : set
7273 boxed : bool
7374 product : bool
@@ -78,6 +79,7 @@ def __init__(self, name):
7879 self .enum_name = None
7980 self .has_userdata = None
8081 self .has_attributes = False
82+ self .empty_field = False
8183 self .children = set ()
8284 self .boxed = False
8385 self .product = False
@@ -192,10 +194,9 @@ def visitSum(self, sum, name):
192194 info .has_userdata = False
193195 else :
194196 for t in sum .types :
195- if not t .fields :
196- continue
197197 t_info = TypeInfo (t .name )
198198 t_info .enum_name = name
199+ t_info .empty_field = not t .fields
199200 self .typeinfo [t .name ] = t_info
200201 self .add_children (t .name , t .fields )
201202 if len (sum .types ) > 1 :
@@ -534,14 +535,140 @@ def gen_construction(self, header, fields, footer, depth):
534535class FoldModuleVisitor (EmitVisitor ):
535536 def visitModule (self , mod ):
536537 depth = 0
537- self .emit ('#[cfg(feature = "fold")]' , depth )
538- self .emit ("pub mod fold {" , depth )
539- self .emit ("use super::*;" , depth + 1 )
540- self .emit ("use crate::fold_helpers::Foldable;" , depth + 1 )
541- FoldTraitDefVisitor (self .file , self .typeinfo ).visit (mod , depth + 1 )
542- FoldImplVisitor (self .file , self .typeinfo ).visit (mod , depth + 1 )
538+ self .emit ("use crate::fold_helpers::Foldable;" , depth )
539+ FoldTraitDefVisitor (self .file , self .typeinfo ).visit (mod , depth )
540+ FoldImplVisitor (self .file , self .typeinfo ).visit (mod , depth )
541+
542+
543+ class VisitorTraitDefVisitor (StructVisitor ):
544+ def full_name (self , name ):
545+ typeinfo = self .typeinfo [name ]
546+ if typeinfo .enum_name :
547+ return f"{ typeinfo .enum_name } _{ name } "
548+ else :
549+ return name
550+
551+ def node_type_name (self , name ):
552+ typeinfo = self .typeinfo [name ]
553+ if typeinfo .enum_name :
554+ return f"{ get_rust_type (typeinfo .enum_name )} { get_rust_type (name )} "
555+ else :
556+ return get_rust_type (name )
557+
558+ def visitModule (self , mod , depth ):
559+ self .emit ("pub trait Visitor<U=()> {" , depth )
560+
561+ for dfn in mod .dfns :
562+ self .visit (dfn , depth + 1 )
563+ self .emit ("}" , depth )
564+
565+ def visitType (self , type , depth = 0 ):
566+ self .visit (type .value , type .name , depth )
567+
568+ def emit_visitor (self , nodename , depth , has_node = True ):
569+ typeinfo = self .typeinfo [nodename ]
570+ if has_node :
571+ node_type = typeinfo .rust_sum_name
572+ node_value = "node"
573+ else :
574+ node_type = "()"
575+ node_value = "()"
576+ self .emit (
577+ f"fn visit_{ typeinfo .sum_name } (&mut self, node: { node_type } ) {{" , depth
578+ )
579+ self .emit (f"self.generic_visit_{ typeinfo .sum_name } ({ node_value } )" , depth + 1 )
580+ self .emit ("}" , depth )
581+
582+ def emit_generic_visitor_signature (self , nodename , depth , has_node = True ):
583+ typeinfo = self .typeinfo [nodename ]
584+ if has_node :
585+ node_type = typeinfo .rust_sum_name
586+ else :
587+ node_type = "()"
588+ self .emit (
589+ f"fn generic_visit_{ typeinfo .sum_name } (&mut self, node: { node_type } ) {{" ,
590+ depth ,
591+ )
592+
593+ def emit_empty_generic_visitor (self , nodename , depth ):
594+ self .emit_generic_visitor_signature (nodename , depth )
543595 self .emit ("}" , depth )
544596
597+ def simple_sum (self , sum , name , depth ):
598+ self .emit_visitor (name , depth )
599+ self .emit_empty_generic_visitor (name , depth )
600+
601+ def visit_match_for_type (self , nodename , rustname , type_ , depth ):
602+ self .emit (f"{ rustname } ::{ type_ .name } " , depth )
603+ if type_ .fields :
604+ self .emit ("(data)" , depth )
605+ data = "data"
606+ else :
607+ data = "()"
608+ self .emit (f"=> self.visit_{ nodename } _{ type_ .name } ({ data } )," , depth )
609+
610+ def visit_sumtype (self , name , type_ , depth ):
611+ self .emit_visitor (type_ .name , depth , has_node = type_ .fields )
612+ self .emit_generic_visitor_signature (type_ .name , depth , has_node = type_ .fields )
613+ for f in type_ .fields :
614+ fieldname = rust_field (f .name )
615+ fieldtype = self .typeinfo .get (f .type )
616+ if not (fieldtype and fieldtype .has_userdata ):
617+ continue
618+
619+ if f .opt :
620+ self .emit (f"if let Some(value) = node.{ fieldname } {{" , depth + 1 )
621+ elif f .seq :
622+ iterable = f"node.{ fieldname } "
623+ if type_ .name == "Dict" and f .name == "keys" :
624+ iterable = f"{ iterable } .into_iter().flatten()"
625+ self .emit (f"for value in { iterable } {{" , depth + 1 )
626+ else :
627+ self .emit ("{" , depth + 1 )
628+ self .emit (f"let value = node.{ fieldname } ;" , depth + 2 )
629+
630+ variable = "value"
631+ if fieldtype .boxed and (not f .seq or f .opt ):
632+ variable = "*" + variable
633+ typeinfo = self .typeinfo [fieldtype .name ]
634+ self .emit (f"self.visit_{ typeinfo .sum_name } ({ variable } );" , depth + 2 )
635+
636+ self .emit ("}" , depth + 1 )
637+
638+ self .emit ("}" , depth )
639+
640+ def sum_with_constructors (self , sum , name , depth ):
641+ if not sum .attributes :
642+ return
643+
644+ rustname = enumname = get_rust_type (name )
645+ if sum .attributes :
646+ rustname = enumname + "Kind"
647+ self .emit_visitor (name , depth )
648+ self .emit_generic_visitor_signature (name , depth )
649+ depth += 1
650+ self .emit ("match node.node {" , depth )
651+ for t in sum .types :
652+ self .visit_match_for_type (name , rustname , t , depth + 1 )
653+ self .emit ("}" , depth )
654+ depth -= 1
655+ self .emit ("}" , depth )
656+
657+ # Now for the visitors for the types
658+ for t in sum .types :
659+ self .visit_sumtype (name , t , depth )
660+
661+ def visitProduct (self , product , name , depth ):
662+ self .emit_visitor (name , depth )
663+ self .emit_empty_generic_visitor (name , depth )
664+
665+
666+ class VisitorModuleVisitor (EmitVisitor ):
667+ def visitModule (self , mod ):
668+ depth = 0
669+ self .emit ("#[allow(unused_variables, non_snake_case)]" , depth )
670+ VisitorTraitDefVisitor (self .file , self .typeinfo ).visit (mod , depth )
671+
545672
546673class ClassDefVisitor (EmitVisitor ):
547674 def visitModule (self , mod ):
@@ -799,23 +926,19 @@ def visit(self, object):
799926 v .emit ("" , 0 )
800927
801928
802- def write_generic_def (mod , typeinfo , f ):
803- f .write (
804- textwrap .dedent (
805- """
806- pub use crate::{Attributed, constant::*};
929+ def write_ast_def (mod , typeinfo , f ):
930+ StructVisitor (f , typeinfo ).visit (mod )
807931
808- type Ident = String;
809- \n
810- """
811- )
812- )
813932
814- c = ChainOfVisitors (StructVisitor (f , typeinfo ), FoldModuleVisitor (f , typeinfo ))
815- c .visit (mod )
933+ def write_fold_def (mod , typeinfo , f ):
934+ FoldModuleVisitor (f , typeinfo ).visit (mod )
935+
936+
937+ def write_visitor_def (mod , typeinfo , f ):
938+ VisitorModuleVisitor (f , typeinfo ).visit (mod )
816939
817940
818- def write_located_def (typeinfo , f ):
941+ def write_located_def (mod , typeinfo , f ):
819942 f .write (
820943 textwrap .dedent (
821944 """
@@ -826,6 +949,8 @@ def write_located_def(typeinfo, f):
826949 )
827950 )
828951 for info in typeinfo .values ():
952+ if info .empty_field :
953+ continue
829954 if info .has_userdata :
830955 generics = "::<SourceRange>"
831956 else :
@@ -863,8 +988,7 @@ def write_ast_mod(mod, typeinfo, f):
863988
864989def main (
865990 input_filename ,
866- generic_filename ,
867- located_filename ,
991+ ast_dir ,
868992 module_filename ,
869993 dump_module = False ,
870994):
@@ -879,34 +1003,34 @@ def main(
8791003 typeinfo = {}
8801004 FindUserdataTypesVisitor (typeinfo ).visit (mod )
8811005
882- with generic_filename .open ("w" ) as generic_file , located_filename .open (
883- "w"
884- ) as located_file :
885- generic_file .write (auto_gen_msg )
886- write_generic_def (mod , typeinfo , generic_file )
887- located_file .write (auto_gen_msg )
888- write_located_def (typeinfo , located_file )
1006+ for filename , write in [
1007+ ("generic" , write_ast_def ),
1008+ ("fold" , write_fold_def ),
1009+ ("located" , write_located_def ),
1010+ ("visitor" , write_visitor_def ),
1011+ ]:
1012+ with (ast_dir / f"{ filename } .rs" ).open ("w" ) as f :
1013+ f .write (auto_gen_msg )
1014+ write (mod , typeinfo , f )
8891015
8901016 with module_filename .open ("w" ) as module_file :
8911017 module_file .write (auto_gen_msg )
8921018 write_ast_mod (mod , typeinfo , module_file )
8931019
894- print (f"{ generic_filename } , { located_filename } , { module_filename } regenerated." )
1020+ print (f"{ ast_dir } , { module_filename } regenerated." )
8951021
8961022
8971023if __name__ == "__main__" :
8981024 parser = ArgumentParser ()
8991025 parser .add_argument ("input_file" , type = Path )
900- parser .add_argument ("-G" , "--generic-file" , type = Path , required = True )
901- parser .add_argument ("-L" , "--located-file" , type = Path , required = True )
1026+ parser .add_argument ("-A" , "--ast-dir" , type = Path , required = True )
9021027 parser .add_argument ("-M" , "--module-file" , type = Path , required = True )
9031028 parser .add_argument ("-d" , "--dump-module" , action = "store_true" )
9041029
9051030 args = parser .parse_args ()
9061031 main (
9071032 args .input_file ,
908- args .generic_file ,
909- args .located_file ,
1033+ args .ast_dir ,
9101034 args .module_file ,
9111035 args .dump_module ,
9121036 )
0 commit comments