@@ -131,6 +131,9 @@ def visitSum(self, sum, name):
131131 if is_simple (sum ):
132132 info .has_userdata = False
133133 else :
134+ for t in sum .types :
135+ self .typeinfo [t .name ] = TypeInfo (t .name )
136+ self .add_children (t .name , t .fields )
134137 if len (sum .types ) > 1 :
135138 info .boxed = True
136139 if sum .attributes :
@@ -205,16 +208,49 @@ def simple_sum(self, sum, name, depth):
205208
206209 def sum_with_constructors (self , sum , name , depth ):
207210 typeinfo = self .typeinfo [name ]
208- generics , generics_applied = self .get_generics (name , "U = ()" , "U" )
209211 enumname = rustname = get_rust_type (name )
210212 # all the attributes right now are for location, so if it has attrs we
211213 # can just wrap it in Located<>
212214 if sum .attributes :
213215 enumname = rustname + "Kind"
216+
217+ for t in sum .types :
218+ if not t .fields :
219+ continue
220+ self .emit_attrs (depth )
221+ self .typeinfo [t ] = TypeInfo (t )
222+ t_generics , t_generics_applied = self .get_generics (t .name , "U = ()" , "U" )
223+ payload_name = f"{ rustname } { t .name } "
224+ self .emit (f"pub struct { payload_name } { t_generics } {{" , depth )
225+ for f in t .fields :
226+ self .visit (f , typeinfo , "pub " , depth + 1 , t .name )
227+ self .emit ("}" , depth )
228+ self .emit (
229+ textwrap .dedent (
230+ f"""
231+ impl{ t_generics_applied } From<{ payload_name } { t_generics_applied } > for { enumname } { t_generics_applied } {{
232+ fn from(payload: { payload_name } { t_generics_applied } ) -> Self {{
233+ { enumname } ::{ t .name } (payload)
234+ }}
235+ }}
236+ """
237+ ),
238+ depth ,
239+ )
240+
241+ generics , generics_applied = self .get_generics (name , "U = ()" , "U" )
214242 self .emit_attrs (depth )
215243 self .emit (f"pub enum { enumname } { generics } {{" , depth )
216244 for t in sum .types :
217- self .visit (t , typeinfo , depth + 1 )
245+ if t .fields :
246+ t_generics , t_generics_applied = self .get_generics (
247+ t .name , "U = ()" , "U"
248+ )
249+ self .emit (
250+ f"{ t .name } ({ rustname } { t .name } { t_generics_applied } )," , depth + 1
251+ )
252+ else :
253+ self .emit (f"{ t .name } ," , depth + 1 )
218254 self .emit ("}" , depth )
219255 if sum .attributes :
220256 self .emit (
@@ -238,13 +274,18 @@ def visitField(self, field, parent, vis, depth, constructor=None):
238274 if fieldtype and fieldtype .has_userdata :
239275 typ = f"{ typ } <U>"
240276 # don't box if we're doing Vec<T>, but do box if we're doing Vec<Option<Box<T>>>
241- if fieldtype and fieldtype .boxed and (not (parent .product or field .seq ) or field .opt ):
277+ if (
278+ fieldtype
279+ and fieldtype .boxed
280+ and (not (parent .product or field .seq ) or field .opt )
281+ ):
242282 typ = f"Box<{ typ } >"
243283 if field .opt or (
244284 # When a dictionary literal contains dictionary unpacking (e.g., `{**d}`),
245285 # the expression to be unpacked goes in `values` with a `None` at the corresponding
246286 # position in `keys`. To handle this, the type of `keys` needs to be `Option<Vec<T>>`.
247- constructor == "Dict" and field .name == "keys"
287+ constructor == "Dict"
288+ and field .name == "keys"
248289 ):
249290 typ = f"Option<{ typ } >"
250291 if field .seq :
@@ -344,14 +385,21 @@ def visitSum(self, sum, name, depth):
344385 )
345386 if is_located :
346387 self .emit ("fold_located(folder, node, |folder, node| {" , depth )
347- enumname += "Kind"
388+ rustname = enumname + "Kind"
389+ else :
390+ rustname = enumname
348391 self .emit ("match node {" , depth + 1 )
349392 for cons in sum .types :
350- fields_pattern = self .make_pattern (cons .fields )
393+ fields_pattern = self .make_pattern (
394+ enumname , rustname , cons .name , cons .fields
395+ )
351396 self .emit (
352- f"{ enumname } ::{ cons .name } {{ { fields_pattern } }} => {{" , depth + 2
397+ f"{ fields_pattern [0 ]} {{ { fields_pattern [1 ]} }} { fields_pattern [2 ]} => {{" ,
398+ depth + 2 ,
399+ )
400+ self .gen_construction (
401+ fields_pattern [0 ], cons .fields , fields_pattern [2 ], depth + 3
353402 )
354- self .gen_construction (f"{ enumname } ::{ cons .name } " , cons .fields , depth + 3 )
355403 self .emit ("}" , depth + 2 )
356404 self .emit ("}" , depth + 1 )
357405 if is_located :
@@ -381,23 +429,33 @@ def visitProduct(self, product, name, depth):
381429 )
382430 if is_located :
383431 self .emit ("fold_located(folder, node, |folder, node| {" , depth )
384- structname += "Data"
385- fields_pattern = self .make_pattern (product .fields )
386- self .emit (f"let { structname } {{ { fields_pattern } }} = node;" , depth + 1 )
387- self .gen_construction (structname , product .fields , depth + 1 )
432+ rustname = structname + "Data"
433+ else :
434+ rustname = structname
435+ fields_pattern = self .make_pattern (rustname , structname , None , product .fields )
436+ self .emit (f"let { rustname } {{ { fields_pattern [1 ]} }} = node;" , depth + 1 )
437+ self .gen_construction (rustname , product .fields , "" , depth + 1 )
388438 if is_located :
389439 self .emit ("})" , depth )
390440 self .emit ("}" , depth )
391441
392- def make_pattern (self , fields ):
393- return "," .join (rust_field (f .name ) for f in fields )
442+ def make_pattern (self , rustname , pyname , fieldname , fields ):
443+ if fields :
444+ header = f"{ pyname } ::{ fieldname } ({ rustname } { fieldname } "
445+ footer = ")"
446+ else :
447+ header = f"{ pyname } ::{ fieldname } "
448+ footer = ""
394449
395- def gen_construction (self , cons_path , fields , depth ):
396- self .emit (f"Ok({ cons_path } {{" , depth )
450+ body = "," .join (rust_field (f .name ) for f in fields )
451+ return header , body , footer
452+
453+ def gen_construction (self , header , fields , footer , depth ):
454+ self .emit (f"Ok({ header } {{" , depth )
397455 for field in fields :
398456 name = rust_field (field .name )
399457 self .emit (f"{ name } : Foldable::fold({ name } , folder)?," , depth + 1 )
400- self .emit (" })" , depth )
458+ self .emit (f"}} { footer } )" , depth )
401459
402460
403461class FoldModuleVisitor (TypeInfoEmitVisitor ):
@@ -514,33 +572,36 @@ def visitType(self, type, depth=0):
514572 self .visit (type .value , type .name , depth )
515573
516574 def visitSum (self , sum , name , depth ):
517- enumname = get_rust_type (name )
575+ rustname = enumname = get_rust_type (name )
518576 if sum .attributes :
519- enumname += "Kind"
577+ rustname = enumname + "Kind"
520578
521- self .emit (f"impl NamedNode for ast::{ enumname } {{" , depth )
579+ self .emit (f"impl NamedNode for ast::{ rustname } {{" , depth )
522580 self .emit (f"const NAME: &'static str = { json .dumps (name )} ;" , depth + 1 )
523581 self .emit ("}" , depth )
524- self .emit (f"impl Node for ast::{ enumname } {{" , depth )
582+ self .emit (f"impl Node for ast::{ rustname } {{" , depth )
525583 self .emit (
526584 "fn ast_to_object(self, _vm: &VirtualMachine) -> PyObjectRef {" , depth + 1
527585 )
528586 self .emit ("match self {" , depth + 2 )
529587 for variant in sum .types :
530- self .constructor_to_object (variant , enumname , depth + 3 )
588+ self .constructor_to_object (variant , enumname , rustname , depth + 3 )
531589 self .emit ("}" , depth + 2 )
532590 self .emit ("}" , depth + 1 )
533591 self .emit (
534592 "fn ast_from_object(_vm: &VirtualMachine, _object: PyObjectRef) -> PyResult<Self> {" ,
535593 depth + 1 ,
536594 )
537- self .gen_sum_fromobj (sum , name , enumname , depth + 2 )
595+ self .gen_sum_fromobj (sum , name , enumname , rustname , depth + 2 )
538596 self .emit ("}" , depth + 1 )
539597 self .emit ("}" , depth )
540598
541- def constructor_to_object (self , cons , enumname , depth ):
542- fields_pattern = self .make_pattern (cons .fields )
543- self .emit (f"ast::{ enumname } ::{ cons .name } {{ { fields_pattern } }} => {{" , depth )
599+ def constructor_to_object (self , cons , enumname , rustname , depth ):
600+ self .emit (f"ast::{ rustname } ::{ cons .name } " , depth )
601+ if cons .fields :
602+ fields_pattern = self .make_pattern (cons .fields )
603+ self .emit (f"( ast::{ enumname } { cons .name } {{ { fields_pattern } }} )" , depth )
604+ self .emit (" => {" , depth )
544605 self .make_node (cons .name , cons .fields , depth + 1 )
545606 self .emit ("}" , depth )
546607
@@ -586,15 +647,20 @@ def make_node(self, variant, fields, depth):
586647 def make_pattern (self , fields ):
587648 return "," .join (rust_field (f .name ) for f in fields )
588649
589- def gen_sum_fromobj (self , sum , sumname , enumname , depth ):
650+ def gen_sum_fromobj (self , sum , sumname , enumname , rustname , depth ):
590651 if sum .attributes :
591652 self .extract_location (sumname , depth )
592653
593654 self .emit ("let _cls = _object.class();" , depth )
594655 self .emit ("Ok(" , depth )
595656 for cons in sum .types :
596657 self .emit (f"if _cls.is(Node{ cons .name } ::static_type()) {{" , depth )
597- self .gen_construction (f"{ enumname } ::{ cons .name } " , cons , sumname , depth + 1 )
658+ if cons .fields :
659+ self .emit (f"ast::{ rustname } ::{ cons .name } (ast::{ enumname } { cons .name } {{" , depth + 1 )
660+ self .gen_construction_fields (cons , sumname , depth + 1 )
661+ self .emit ("})" , depth + 1 )
662+ else :
663+ self .emit (f"ast::{ rustname } ::{ cons .name } " , depth + 1 )
598664 self .emit ("} else" , depth )
599665
600666 self .emit ("{" , depth )
@@ -610,13 +676,16 @@ def gen_product_fromobj(self, product, prodname, structname, depth):
610676 self .gen_construction (structname , product , prodname , depth + 1 )
611677 self .emit (")" , depth )
612678
613- def gen_construction (self , cons_path , cons , name , depth ):
614- self .emit (f"ast::{ cons_path } {{" , depth )
679+ def gen_construction_fields (self , cons , name , depth ):
615680 for field in cons .fields :
616681 self .emit (
617682 f"{ rust_field (field .name )} : { self .decode_field (field , name )} ," ,
618683 depth + 1 ,
619684 )
685+
686+ def gen_construction (self , cons_path , cons , name , depth ):
687+ self .emit (f"ast::{ cons_path } {{" , depth )
688+ self .gen_construction_fields (cons , name , depth + 1 )
620689 self .emit ("}" , depth )
621690
622691 def extract_location (self , typename , depth ):
0 commit comments