@@ -68,6 +68,7 @@ class TypeInfo:
6868 enum_name : Optional [str ]
6969 has_userdata : Optional [bool ]
7070 has_attributes : bool
71+ empty_field : bool
7172 children : set
7273 boxed : bool
7374 product : bool
@@ -78,6 +79,7 @@ def __init__(self, name):
7879 self .enum_name = None
7980 self .has_userdata = None
8081 self .has_attributes = False
82+ self .empty_field = False
8183 self .children = set ()
8284 self .boxed = False
8385 self .product = False
@@ -192,10 +194,9 @@ def visitSum(self, sum, name):
192194 info .has_userdata = False
193195 else :
194196 for t in sum .types :
195- if not t .fields :
196- continue
197197 t_info = TypeInfo (t .name )
198198 t_info .enum_name = name
199+ t_info .empty_field = not t .fields
199200 self .typeinfo [t .name ] = t_info
200201 self .add_children (t .name , t .fields )
201202 if len (sum .types ) > 1 :
@@ -543,111 +544,93 @@ def visitModule(self, mod):
543544 self .emit ("}" , depth )
544545
545546
546- class VisitorStructsDefVisitor (StructVisitor ):
547- def visitModule (self , mod , depth ):
548- for dfn in mod .dfns :
549- self .visit (dfn , depth )
550-
551- def visitProduct (self , product , name , depth ):
552- pass
553-
554- def visitSum (self , sum , name , depth ):
555- if not is_simple (sum ):
556- typeinfo = self .typeinfo [name ]
557- if not sum .attributes :
558- return
559- for t in sum .types :
560- typename = t .name + "Node"
561-
562- has_userdata = any (
563- getattr (self .typeinfo .get (f .type ), "has_userdata" , False )
564- for f in t .fields
565- )
566- self .emit (
567- f"pub struct { typename } Data<{ 'U=()' if has_userdata else '' } > {{" ,
568- depth ,
569- )
570- for f in t .fields :
571- self .visit (f , typeinfo , "pub " , depth + 1 , t .name )
572- self .emit ("}" , depth )
573- self .emit (
574- f"pub type { typename } <U = ()> = Located<{ typename } Data<{ 'U' if has_userdata else '' } >, U>;" ,
575- depth ,
576- )
577- self .emit ("" , depth )
547+ class VisitorTraitDefVisitor (StructVisitor ):
548+ def full_name (self , name ):
549+ typeinfo = self .typeinfo [name ]
550+ if typeinfo .enum_name :
551+ return f"{ typeinfo .enum_name } _{ name } "
552+ else :
553+ return name
578554
555+ def node_type_name (self , name ):
556+ typeinfo = self .typeinfo [name ]
557+ if typeinfo .enum_name :
558+ return f"{ get_rust_type (typeinfo .enum_name )} { get_rust_type (name )} "
559+ else :
560+ return get_rust_type (name )
579561
580- class VisitorTraitDefVisitor (StructVisitor ):
581562 def visitModule (self , mod , depth ):
582563 self .emit ("pub trait Visitor<U=()> {" , depth )
564+
583565 for dfn in mod .dfns :
584566 self .visit (dfn , depth + 1 )
585567 self .emit ("}" , depth )
586568
587569 def visitType (self , type , depth = 0 ):
588570 self .visit (type .value , type .name , depth )
589571
590- def emit_visitor (self , nodename , rusttype , depth ):
591- self .emit (f"fn visit_{ nodename } (&mut self, node: { rusttype } ) {{" , depth )
592- self .emit (f"self.generic_visit_{ nodename } (node);" , depth + 1 )
572+ def emit_visitor (self , nodename , depth , has_node = True ):
573+ typeinfo = self .typeinfo [nodename ]
574+ if has_node :
575+ node_type = typeinfo .rust_sum_name
576+ node_value = "node"
577+ else :
578+ node_type = "()"
579+ node_value = "()"
580+ self .emit (f"fn visit_{ typeinfo .sum_name } (&mut self, node: { node_type } ) {{" , depth )
581+ self .emit (f"self.generic_visit_{ typeinfo .sum_name } ({ node_value } )" , depth + 1 )
593582 self .emit ("}" , depth )
594583
595- def emit_generic_visitor_signature (self , nodename , rusttype , depth ):
596- self .emit (f"fn generic_visit_{ nodename } (&mut self, node: { rusttype } ) {{" , depth )
584+ def emit_generic_visitor_signature (self , nodename , depth , has_node = True ):
585+ typeinfo = self .typeinfo [nodename ]
586+ if has_node :
587+ node_type = typeinfo .rust_sum_name
588+ else :
589+ node_type = "()"
590+ self .emit (f"fn generic_visit_{ typeinfo .sum_name } (&mut self, node: { node_type } ) {{" , depth )
597591
598- def emit_empty_generic_visitor (self , nodename , rusttype , depth ):
599- self .emit_generic_visitor_signature (nodename , rusttype , depth )
592+ def emit_empty_generic_visitor (self , nodename , depth ):
593+ self .emit_generic_visitor_signature (nodename , depth )
600594 self .emit ("}" , depth )
601595
602596 def simple_sum (self , sum , name , depth ):
603- enumname = get_rust_type (name )
604- self .emit_visitor (name , enumname , depth )
605- self .emit_empty_generic_visitor (name , enumname , depth )
597+ self .emit_visitor (name , depth )
598+ self .emit_empty_generic_visitor (name , depth )
606599
607- def visit_match_for_type (self , enumname , rustname , type_ , depth ):
600+ def visit_match_for_type (self , nodename , rustname , type_ , depth ):
608601 self .emit (f"{ rustname } ::{ type_ .name } " , depth )
609602 if type_ .fields :
610- self .emit (f"({ enumname } { type_ .name } {{" , depth )
611- for field in type_ .fields :
612- self .emit (f"{ rust_field (field .name )} ," , depth + 1 )
613- self .emit ("})" , depth )
614- self .emit (f"=> self.visit_{ type_ .name } (" , depth )
615- self .emit (f"{ type_ .name } Node {{" , depth + 2 )
616- self .emit ("location: node.location," , depth + 2 )
617- self .emit ("end_location: node.end_location," , depth + 2 )
618- self .emit ("custom: node.custom," , depth + 2 )
619- self .emit (f"node: { type_ .name } NodeData {{" , depth + 2 )
620- for field in type_ .fields :
621- self .emit (f"{ rust_field (field .name )} ," , depth + 3 )
622- self .emit ("}," , depth + 2 )
623- self .emit ("}" , depth + 1 )
624- self .emit (")," , depth )
603+ self .emit ("(data)" , depth )
604+ data = "data"
605+ else :
606+ data = "()"
607+ self .emit (f"=> self.visit_{ nodename } _{ type_ .name } ({ data } )," , depth )
625608
626- def visit_sumtype (self , type_ , depth ):
627- rustname = get_rust_type (type_ .name ) + "Node"
628- self .emit_visitor (type_ .name , rustname , depth )
629- self .emit_generic_visitor_signature (type_ .name , rustname , depth )
609+ def visit_sumtype (self , name , type_ , depth ):
610+ self .emit_visitor (type_ .name , depth , has_node = type_ .fields )
611+ self .emit_generic_visitor_signature (type_ .name , depth , has_node = type_ .fields )
630612 for f in type_ .fields :
631613 fieldname = rust_field (f .name )
632614 fieldtype = self .typeinfo .get (f .type )
633615 if not (fieldtype and fieldtype .has_userdata ):
634616 continue
635617
636618 if f .opt :
637- self .emit (f"if let Some(value) = node.node. { fieldname } {{" , depth + 1 )
619+ self .emit (f"if let Some(value) = node.{ fieldname } {{" , depth + 1 )
638620 elif f .seq :
639- iterable = f"node.node. { fieldname } "
621+ iterable = f"node.{ fieldname } "
640622 if type_ .name == "Dict" and f .name == "keys" :
641623 iterable = f"{ iterable } .into_iter().flatten()"
642624 self .emit (f"for value in { iterable } {{" , depth + 1 )
643625 else :
644626 self .emit ("{" , depth + 1 )
645- self .emit (f"let value = node.node. { fieldname } ;" , depth + 2 )
627+ self .emit (f"let value = node.{ fieldname } ;" , depth + 2 )
646628
647629 variable = "value"
648630 if fieldtype .boxed and (not f .seq or f .opt ):
649631 variable = "*" + variable
650- self .emit (f"self.visit_{ fieldtype .name } ({ variable } );" , depth + 2 )
632+ typeinfo = self .typeinfo [fieldtype .name ]
633+ self .emit (f"self.visit_{ typeinfo .sum_name } ({ variable } );" , depth + 2 )
651634
652635 self .emit ("}" , depth + 1 )
653636
@@ -660,24 +643,23 @@ def sum_with_constructors(self, sum, name, depth):
660643 rustname = enumname = get_rust_type (name )
661644 if sum .attributes :
662645 rustname = enumname + "Kind"
663- self .emit_visitor (name , enumname , depth )
664- self .emit_generic_visitor_signature (name , enumname , depth )
646+ self .emit_visitor (name , depth )
647+ self .emit_generic_visitor_signature (name , depth )
665648 depth += 1
666649 self .emit ("match node.node {" , depth )
667650 for t in sum .types :
668- self .visit_match_for_type (enumname , rustname , t , depth + 1 )
651+ self .visit_match_for_type (name , rustname , t , depth + 1 )
669652 self .emit ("}" , depth )
670653 depth -= 1
671654 self .emit ("}" , depth )
672655
673656 # Now for the visitors for the types
674657 for t in sum .types :
675- self .visit_sumtype (t , depth )
658+ self .visit_sumtype (name , t , depth )
676659
677660 def visitProduct (self , product , name , depth ):
678- rusttype = get_rust_type (name )
679- self .emit_visitor (name , rusttype , depth )
680- self .emit_empty_generic_visitor (name , rusttype , depth )
661+ self .emit_visitor (name , depth )
662+ self .emit_empty_generic_visitor (name , depth )
681663
682664
683665class VisitorModuleVisitor (EmitVisitor ):
@@ -687,7 +669,6 @@ def visitModule(self, mod):
687669 self .emit ("#[allow(unused_variables, non_snake_case)]" , depth )
688670 self .emit ("pub mod visitor {" , depth )
689671 self .emit ("use super::*;" , depth + 1 )
690- VisitorStructsDefVisitor (self .file , self .typeinfo ).visit (mod , depth + 1 )
691672 VisitorTraitDefVisitor (self .file , self .typeinfo ).visit (mod , depth + 1 )
692673 self .emit ("}" , depth )
693674 self .emit ("" , depth )
@@ -980,6 +961,8 @@ def write_located_def(typeinfo, f):
980961 )
981962 )
982963 for info in typeinfo .values ():
964+ if info .empty_field :
965+ continue
983966 if info .has_userdata :
984967 generics = "::<SourceRange>"
985968 else :
0 commit comments