@@ -937,6 +937,111 @@ def emit_located_impl(self, info):
937937 )
938938
939939
940+ class ToPyo3AstVisitor (EmitVisitor ):
941+ """Visitor to generate type-defs for AST."""
942+
943+ def __init__ (self , namespace , * args , ** kw ):
944+ super ().__init__ (* args , ** kw )
945+ self .namespace = namespace
946+
947+ @property
948+ def generics (self ):
949+ if self .namespace == "ranged" :
950+ return "<TextRange>"
951+ elif self .namespace == "located" :
952+ return "<SourceRange>"
953+ else :
954+ assert False , self .namespace
955+
956+ def visitModule (self , mod ):
957+ for dfn in mod .dfns :
958+ self .visit (dfn )
959+
960+ def visitType (self , type , depth = 0 ):
961+ self .visit (type .value , type .name , depth )
962+
963+ def visitProduct (self , product , name , depth = 0 ):
964+ rust_name = rust_type_name (name )
965+ self .emit_to_pyo3_with_fields (product , rust_name )
966+
967+ def visitSum (self , sum , name , depth = 0 ):
968+ rust_name = rust_type_name (name )
969+ simple = is_simple (sum )
970+ if is_simple (sum ):
971+ return
972+
973+ self .emit (
974+ f"""
975+ impl ToPyo3Ast for crate::generic::{ rust_name } { self .generics } {{
976+ #[inline]
977+ fn to_pyo3_ast(&self, { "_" if simple else "" } py: Python) -> PyResult<Py<PyAny>> {{
978+ let instance = match &self {{
979+ """ ,
980+ 0 ,
981+ )
982+ for cons in sum .types :
983+ self .emit (
984+ f"""crate::{ rust_name } ::{ cons .name } (cons) => cons.to_pyo3_ast(py)?,""" ,
985+ depth ,
986+ )
987+ self .emit (
988+ """
989+ };
990+ Ok(instance)
991+ }
992+ }
993+ """ ,
994+ 0 ,
995+ )
996+
997+ for cons in sum .types :
998+ self .visit (cons , rust_name , depth )
999+
1000+ def visitConstructor (self , cons , parent , depth ):
1001+ self .emit_to_pyo3_with_fields (cons , f"{ parent } { cons .name } " )
1002+
1003+ def emit_to_pyo3_with_fields (self , cons , name ):
1004+ if cons .fields :
1005+ self .emit (
1006+ f"""
1007+ impl ToPyo3Ast for crate::{ name } { self .generics } {{
1008+ #[inline]
1009+ fn to_pyo3_ast(&self, py: Python) -> PyResult<Py<PyAny>> {{
1010+ let cache = Self::py_type_cache().get().unwrap();
1011+ let instance = cache.0.call1(py, (
1012+ """ ,
1013+ 0 ,
1014+ )
1015+ for field in cons .fields :
1016+ self .emit (
1017+ f"self.{ rust_field (field .name )} .to_pyo3_ast(py)?," ,
1018+ 3 ,
1019+ )
1020+ self .emit (
1021+ """
1022+ ))?;
1023+ Ok(instance)
1024+ }
1025+ }
1026+ """ ,
1027+ 0 ,
1028+ )
1029+ else :
1030+ self .emit (
1031+ f"""
1032+ impl ToPyo3Ast for crate::{ name } { self .generics } {{
1033+ #[inline]
1034+ fn to_pyo3_ast(&self, py: Python) -> PyResult<Py<PyAny>> {{
1035+ let cache = Self::py_type_cache().get().unwrap();
1036+ let instance = cache.0.call0(py)?;
1037+ Ok(instance)
1038+ }}
1039+ }}
1040+ """ ,
1041+ 0 ,
1042+ )
1043+
1044+
9401045class StdlibClassDefVisitor (EmitVisitor ):
9411046 def visitModule (self , mod ):
9421047 for dfn in mod .dfns :
@@ -1271,6 +1376,82 @@ def write_located_def(mod, type_info, f):
12711376 LocatedDefVisitor (f , type_info ).visit (mod )
12721377
12731378
1379+ def write_pyo3_node (type_info , f ):
1380+ def write (info : TypeInfo ):
1381+ rust_name = info .rust_sum_name
1382+ if info .is_simple :
1383+ generics = ""
1384+ else :
1385+ generics = "<R>"
1386+
1387+ f .write (
1388+ textwrap .dedent (
1389+ f"""
1390+ impl{ generics } Pyo3Node for crate::generic::{ rust_name } { generics } {{
1391+ #[inline]
1392+ fn py_type_cache() -> &'static OnceCell<(Py<PyAny>, Py<PyAny>)> {{
1393+ static PY_TYPE: OnceCell<(Py<PyAny>, Py<PyAny>)> = OnceCell::new();
1394+ &PY_TYPE
1395+ }}
1396+ }}
1397+ """
1398+ ),
1399+ )
1400+
1401+ for info in type_info .values ():
1402+ write (info )
1403+
1404+
1405+ def write_to_pyo3 (mod , type_info , f ):
1406+ write_pyo3_node (type_info , f )
1407+ write_to_pyo3_simple (type_info , f )
1408+
1409+ for namespace in ("ranged" , "located" ):
1410+ ToPyo3AstVisitor (namespace , f , type_info ).visit (mod )
1411+
1412+ f .write (
1413+ """
1414+ pub fn init(py: Python) -> PyResult<()> {
1415+ let ast_module = PyModule::import(py, "_ast")?;
1416+ """
1417+ )
1418+
1419+ for info in type_info .values ():
1420+ rust_name = info .rust_sum_name
1421+ f .write (f"cache_py_type::<crate::generic::{ rust_name } >(ast_module)?;\n " )
1422+ f .write ("Ok(())\n }" )
1423+
1424+
1425+ def write_to_pyo3_simple (type_info , f ):
1426+ for type_info in type_info .values ():
1427+ if not type_info .is_sum :
1428+ continue
1429+ if not type_info .is_simple :
1430+ continue
1431+
1432+ rust_name = type_info .rust_sum_name
1433+ f .write (
1434+ f"""
1435+ impl ToPyo3Ast for crate::generic::{ rust_name } {{
1436+ #[inline]
1437+ fn to_pyo3_ast(&self, _py: Python) -> PyResult<Py<PyAny>> {{
1438+ let cell = match &self {{
1439+ """ ,
1440+ )
1441+ for cons in type_info .type .value .types :
1442+ f .write (
1443+ f"""crate::{ rust_name } ::{ cons .name } => crate::{ rust_name } { cons .name } ::py_type_cache(),""" ,
1444+ )
1445+ f .write (
1446+ """
1447+ };
1448+ Ok(cell.get().unwrap().1.clone())
1449+ }
1450+ }
1451+ """ ,
1452+ )
1453+
1454+
12741455def write_ast_mod (mod , type_info , f ):
12751456 f .write (
12761457 textwrap .dedent (
@@ -1316,6 +1497,7 @@ def main(
13161497 ("ranged" , p (write_ranged_def , mod , type_info )),
13171498 ("located" , p (write_located_def , mod , type_info )),
13181499 ("visitor" , p (write_visitor_def , mod , type_info )),
1500+ ("to_pyo3" , p (write_to_pyo3 , mod , type_info )),
13191501 ]:
13201502 with (ast_dir / f"{ filename } .rs" ).open ("w" ) as f :
13211503 f .write (auto_gen_msg )
0 commit comments