@@ -496,15 +496,29 @@ def visitModule(self, mod, depth):
496496 self.emit("pub trait Fold<U> {", depth)
497497 self.emit("type TargetU;", depth + 1)
498498 self.emit("type Error;", depth + 1)
499+ self.emit("type UserContext;", depth + 1)
499500 self.emit(
500501 """
501- fn map_user (&mut self, user: U) -> Result< Self::TargetU, Self::Error> ;
502+ fn will_map_user (&mut self, user: & U) -> Self::UserContext ;
502503 #[cfg(feature = "all-nodes-with-ranges")]
503- fn map_user_cfg (&mut self, user: U) -> Result< Self::TargetU, Self::Error> {
504- self.map_user (user)
504+ fn will_map_user_cfg (&mut self, user: & U) -> Self::UserContext {
505+ self.will_map_user (user)
505506 }
506507 #[cfg(not(feature = "all-nodes-with-ranges"))]
507- fn map_user_cfg(&mut self, _user: crate::EmptyRange<U>) -> Result<crate::EmptyRange<Self::TargetU>, Self::Error> {
508+ fn will_map_user_cfg(&mut self, user: &crate::EmptyRange<U>) -> crate::EmptyRange<Self::TargetU> {
509+ crate::EmptyRange::default()
510+ }
511+ fn map_user(&mut self, user: U, context: Self::UserContext) -> Result<Self::TargetU, Self::Error>;
512+ #[cfg(feature = "all-nodes-with-ranges")]
513+ fn map_user_cfg(&mut self, user: U, context: Self::UserContext) -> Result<Self::TargetU, Self::Error> {
514+ self.map_user(user, context)
515+ }
516+ #[cfg(not(feature = "all-nodes-with-ranges"))]
517+ fn map_user_cfg(
518+ &mut self,
519+ _user: crate::EmptyRange<U>,
520+ _context: crate::EmptyRange<Self::TargetU>,
521+ ) -> Result<crate::EmptyRange<Self::TargetU>, Self::Error> {
508522 Ok(crate::EmptyRange::default())
509523 }
510524 """,
@@ -532,17 +546,32 @@ def visitType(self, type, depth):
532546 self.emit(f"fold_{name}(self, node)", depth + 1)
533547 self.emit("}", depth)
534548
549+ if isinstance(type.value, asdl.Sum) and not is_simple(type.value):
550+ for cons in type.value.types:
551+ self.visit(cons, type, depth)
552+
553+ def visitConstructor(self, cons, type, depth):
554+ apply_u, apply_target_u = self.apply_generics(type.name, "U", "Self::TargetU")
555+ enum_name = rust_type_name(type.name)
556+ func_name = f"fold_{type.name}_{rust_field_name(cons.name)}"
557+ self.emit(
558+ f"fn {func_name}(&mut self, node: {enum_name}{cons.name}{apply_u}) -> Result<{enum_name}{cons.name}{apply_target_u}, Self::Error> {{",
559+ depth,
560+ )
561+ self.emit(f"{func_name}(self, node)", depth + 1)
562+ self.emit("}", depth)
563+
535564
536565class FoldImplVisitor(EmitVisitor):
537566 def visitModule(self, mod, depth):
538567 for dfn in mod.dfns:
539568 self.visit(dfn, depth)
540569
541570 def visitType(self, type, depth=0):
542- self.visit(type.value, type.name , depth)
571+ self.visit(type.value, type, depth)
543572
544- def visitSum(self, sum, name , depth):
545- type_info = self.type_info[ name]
573+ def visitSum(self, sum, type , depth):
574+ name = type. name
546575 apply_t, apply_u, apply_target_u = self.apply_generics(
547576 name, "T", "U", "F::TargetU"
548577 )
@@ -568,27 +597,69 @@ def visitSum(self, sum, name, depth):
568597 self.emit("Ok(node) }", depth + 1)
569598 return
570599
571- self.emit("match node {", depth + 1)
600+ self.emit("let folded = match node {", depth + 1)
572601 for cons in sum.types:
573- fields_pattern = self.make_pattern(enum_name, cons.name, cons.fields)
574602 self.emit(
575- f"{fields_pattern[0]} {{ {fields_pattern[1]}}} {fields_pattern[2]} => {{ ",
576- depth + 2 ,
603+ f"{enum_name}::{cons.name}(cons) => {enum_name}::{cons.name}(Foldable::fold(cons, folder)?), ",
604+ depth + 1 ,
577605 )
578606
579- map_user_suffix = "" if type_info.has_attributes else "_cfg"
580- self.emit(
581- f"let range = folder.map_user{map_user_suffix}(range)?;", depth + 3
582- )
607+ self.emit("};", depth + 1)
608+ self.emit("Ok(folded)", depth + 1)
609+ self.emit("}", depth)
583610
584- self.gen_construction(
585- fields_pattern[0], cons.fields, fields_pattern[2], depth + 3
586- )
587- self.emit("}", depth + 2)
611+ for cons in sum.types:
612+ self.visit(cons, type, depth)
613+
614+ def visitConstructor(self, cons, type, depth):
615+ apply_t, apply_u, apply_target_u = self.apply_generics(
616+ type.name, "T", "U", "F::TargetU"
617+ )
618+ enum_name = rust_type_name(type.name)
619+
620+ cons_type_name = f"{enum_name}{cons.name}"
621+
622+ self.emit(
623+ f"impl<T, U> Foldable<T, U> for {cons_type_name}{apply_t} {{", depth
624+ )
625+ self.emit(f"type Mapped = {cons_type_name}{apply_u};", depth + 1)
626+ self.emit(
627+ "fn fold<F: Fold<T, TargetU = U> + ?Sized>(self, folder: &mut F) -> Result<Self::Mapped, F::Error> {",
628+ depth + 1,
629+ )
630+ self.emit(
631+ f"folder.fold_{type.name}_{rust_field_name(cons.name)}(self)", depth + 2
632+ )
588633 self.emit("}", depth + 1)
589634 self.emit("}", depth)
590635
591- def visitProduct(self, product, name, depth):
636+ self.emit(
637+ f"pub fn fold_{type.name}_{rust_field_name(cons.name)}<U, F: Fold<U> + ?Sized>(#[allow(unused)] folder: &mut F, node: {cons_type_name}{apply_u}) -> Result<{enum_name}{cons.name}{apply_target_u}, F::Error> {{",
638+ depth,
639+ )
640+
641+ type_info = self.type_info[type.name]
642+
643+ fields_pattern = self.make_pattern(cons.fields)
644+
645+ map_user_suffix = "" if type_info.has_attributes else "_cfg"
646+ self.emit(
647+ f"""
648+ let {cons_type_name} {{ {fields_pattern} }} = node;
649+ let context = folder.will_map_user{map_user_suffix}(&range);
650+ """,
651+ depth + 3,
652+ )
653+ self.fold_fields(cons.fields, depth + 3)
654+ self.emit(
655+ f"let range = folder.map_user{map_user_suffix}(range, context)?;",
656+ depth + 3,
657+ )
658+ self.composite_fields(f"{cons_type_name}", cons.fields, depth + 3)
659+ self.emit("}", depth + 2)
660+
661+ def visitProduct(self, product, type, depth):
662+ name = type.name
592663 apply_t, apply_u, apply_target_u = self.apply_generics(
593664 name, "T", "U", "F::TargetU"
594665 )
@@ -610,41 +681,47 @@ def visitProduct(self, product, name, depth):
610681 depth,
611682 )
612683
613- fields_pattern = self.make_pattern(struct_name, struct_name, product.fields)
614- self.emit(f"let {struct_name} {{ {fields_pattern[1] } }} = node;", depth + 1)
684+ fields_pattern = self.make_pattern(product.fields)
685+ self.emit(f"let {struct_name} {{ {fields_pattern} }} = node;", depth + 1)
615686
616687 map_user_suffix = "" if has_attributes else "_cfg"
617- self.emit(f"let range = folder.map_user{map_user_suffix}(range)?;", depth + 3)
618688
619- self.gen_construction(struct_name, product.fields, "", depth + 1)
689+ self.emit(
690+ f"let context = folder.will_map_user{map_user_suffix}(&range);", depth + 3
691+ )
692+ self.fold_fields(product.fields, depth + 1)
693+ self.emit(
694+ f"let range = folder.map_user{map_user_suffix}(range, context)?;", depth + 3
695+ )
696+ self.composite_fields(struct_name, product.fields, depth + 1)
620697
621698 self.emit("}", depth)
622699
623- def make_pattern(self, rust_name, fieldname: str, fields):
624- header = f"{rust_name}::{fieldname}({rust_name}{fieldname}"
625- footer = ")"
626-
700+ def make_pattern(self, fields):
627701 body = ",".join(rust_field(f.name) for f in fields)
628702 if body:
629703 body += ","
630704 body += "range"
631705
632- return header, body, footer
706+ return body
633707
634- def gen_construction(self, header, fields, footer, depth):
708+ def fold_fields(self, fields, depth):
709+ for field in fields:
710+ name = rust_field(field.name)
711+ self.emit(f"let {name} = Foldable::fold({name}, folder)?;", depth + 1)
712+
713+ def composite_fields(self, header, fields, depth):
635714 self.emit(f"Ok({header} {{", depth)
636715 for field in fields:
637716 name = rust_field(field.name)
638- self.emit(f"{name}: Foldable::fold({name}, folder)? ,", depth + 1)
717+ self.emit(f"{name},", depth + 1)
639718 self.emit("range,", depth + 1)
640-
641- self.emit(f"}}{footer})", depth)
719+ self.emit(f"}})", depth)
642720
643721
644722class FoldModuleVisitor(EmitVisitor):
645723 def visitModule(self, mod):
646724 depth = 0
647- self.emit("use crate::fold_helpers::Foldable;", depth)
648725 FoldTraitDefVisitor(self.file, self.type_info).visit(mod, depth)
649726 FoldImplVisitor(self.file, self.type_info).visit(mod, depth)
650727
0 commit comments