Skip to content

Commit 947fb53

Browse files
authored
Field accessor and utilities (#20)
* Apply is-macro to Constant and ast nodes
1 parent 2baad9e commit 947fb53

File tree

6 files changed

+117
-18
lines changed

6 files changed

+117
-18
lines changed

Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ anyhow = "1.0.45"
2525
cfg-if = "1.0"
2626
insta = "1.14.0"
2727
itertools = "0.10.3"
28+
is-macro = "0.2.2"
2829
log = "0.4.16"
2930
num-complex = "0.4.0"
3031
num-bigint = "0.4.3"

ast/Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,4 +19,5 @@ visitor = []
1919
rustpython-parser-core = { workspace = true }
2020
rustpython-literal = { workspace = true, optional = true }
2121

22+
is-macro = { workspace = true }
2223
num-bigint = { workspace = true }

ast/asdl_rs.py

Lines changed: 22 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
import sys
77
import json
88
import textwrap
9+
import re
910

1011
from argparse import ArgumentParser
1112
from pathlib import Path
@@ -16,26 +17,34 @@
1617
TABSIZE = 4
1718
AUTO_GEN_MESSAGE = "// File automatically generated by {}.\n\n"
1819

19-
builtin_type_mapping = {
20+
BUILTIN_TYPE_NAMES = {
2021
"identifier": "Identifier",
2122
"string": "String",
2223
"int": "Int",
2324
"constant": "Constant",
2425
}
25-
assert builtin_type_mapping.keys() == asdl.builtin_types
26+
assert BUILTIN_TYPE_NAMES.keys() == asdl.builtin_types
2627

27-
builtin_int_mapping = {
28+
BUILTIN_INT_NAMES = {
2829
"simple": "bool",
2930
"is_async": "bool",
3031
}
3132

33+
RUST_KEYWORDS = {"if", "while", "for", "return", "match", "try", "await", "yield"}
34+
35+
36+
def rust_field_name(name):
37+
name = rust_type_name(name)
38+
return re.sub(r"(?<!^)(?=[A-Z])", "_", name).lower()
39+
40+
3241
def rust_type_name(name):
3342
"""Return a string for the C name of the type.
3443
3544
This function special cases the default types provided by asdl.
3645
"""
3746
if name in asdl.builtin_types:
38-
builtin = builtin_type_mapping[name]
47+
builtin = BUILTIN_TYPE_NAMES[name]
3948
return builtin
4049
elif name.islower():
4150
return "".join(part.capitalize() for part in name.split("_"))
@@ -271,6 +280,7 @@ def emit_attrs(self, depth):
271280
def simple_sum(self, sum, name, depth):
272281
rust_name = rust_type_name(name)
273282
self.emit_attrs(depth)
283+
self.emit("#[derive(is_macro::Is)]", depth)
274284
self.emit(f"pub enum {rust_name} {{", depth)
275285
for variant in sum.types:
276286
self.emit(f"{variant.name},", depth + 1)
@@ -291,8 +301,15 @@ def sum_with_constructors(self, sum, name, depth):
291301

292302
generics, generics_applied = self.apply_generics(name, "U = ()", "U")
293303
self.emit_attrs(depth)
304+
self.emit("#[derive(is_macro::Is)]", depth)
294305
self.emit(f"pub enum {rust_name}{suffix}{generics} {{", depth)
306+
needs_escape = any(rust_field_name(t.name) in RUST_KEYWORDS for t in sum.types)
295307
for t in sum.types:
308+
if needs_escape:
309+
self.emit(
310+
f'#[is(name = "{rust_field_name(t.name)}_{rust_name.lower()}")]',
311+
depth + 1,
312+
)
296313
if t.fields:
297314
(t_generics_applied,) = self.apply_generics(t.name, "U")
298315
self.emit(
@@ -361,7 +378,7 @@ def visitField(self, field, parent, vis, depth, constructor=None):
361378
if field.seq:
362379
typ = f"Vec<{typ}>"
363380
if typ == "Int":
364-
typ = builtin_int_mapping.get(field.name, typ)
381+
typ = BUILTIN_INT_NAMES.get(field.name, typ)
365382
name = rust_field(field.name)
366383
self.emit(f"{vis}{name}: {typ},", depth)
367384

ast/src/builtin.rs

Lines changed: 27 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -114,7 +114,7 @@ impl std::cmp::PartialEq<usize> for Int {
114114
}
115115
}
116116

117-
#[derive(Clone, Debug, PartialEq)]
117+
#[derive(Clone, Debug, PartialEq, is_macro::Is)]
118118
pub enum Constant {
119119
None,
120120
Bool(bool),
@@ -127,6 +127,21 @@ pub enum Constant {
127127
Ellipsis,
128128
}
129129

130+
impl Constant {
131+
pub fn is_true(self) -> bool {
132+
self.bool().map_or(false, |b| b)
133+
}
134+
pub fn is_false(self) -> bool {
135+
self.bool().map_or(false, |b| !b)
136+
}
137+
pub fn complex(self) -> Option<(f64, f64)> {
138+
match self {
139+
Constant::Complex { real, imag } => Some((real, imag)),
140+
_ => None,
141+
}
142+
}
143+
}
144+
130145
impl From<String> for Constant {
131146
fn from(s: String) -> Constant {
132147
Self::Str(s)
@@ -247,3 +262,14 @@ impl<T, U> std::ops::Deref for Attributed<T, U> {
247262
&self.node
248263
}
249264
}
265+
266+
#[cfg(test)]
267+
mod tests {
268+
use super::*;
269+
#[test]
270+
fn test_is_macro() {
271+
let none = Constant::None;
272+
assert!(none.is_none());
273+
assert!(!none.is_bool());
274+
}
275+
}

ast/src/gen/generic.rs

Lines changed: 65 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ impl<U> From<ModFunctionType<U>> for Mod<U> {
4646
}
4747
}
4848

49-
#[derive(Clone, Debug, PartialEq)]
49+
#[derive(Clone, Debug, PartialEq, is_macro::Is)]
5050
pub enum Mod<U = ()> {
5151
Module(ModModule<U>),
5252
Interactive(ModInteractive<U>),
@@ -366,34 +366,61 @@ impl<U> From<StmtExpr<U>> for StmtKind<U> {
366366
}
367367
}
368368

369-
#[derive(Clone, Debug, PartialEq)]
369+
#[derive(Clone, Debug, PartialEq, is_macro::Is)]
370370
pub enum StmtKind<U = ()> {
371+
#[is(name = "function_def_stmt")]
371372
FunctionDef(StmtFunctionDef<U>),
373+
#[is(name = "async_function_def_stmt")]
372374
AsyncFunctionDef(StmtAsyncFunctionDef<U>),
375+
#[is(name = "class_def_stmt")]
373376
ClassDef(StmtClassDef<U>),
377+
#[is(name = "return_stmt")]
374378
Return(StmtReturn<U>),
379+
#[is(name = "delete_stmt")]
375380
Delete(StmtDelete<U>),
381+
#[is(name = "assign_stmt")]
376382
Assign(StmtAssign<U>),
383+
#[is(name = "aug_assign_stmt")]
377384
AugAssign(StmtAugAssign<U>),
385+
#[is(name = "ann_assign_stmt")]
378386
AnnAssign(StmtAnnAssign<U>),
387+
#[is(name = "for_stmt")]
379388
For(StmtFor<U>),
389+
#[is(name = "async_for_stmt")]
380390
AsyncFor(StmtAsyncFor<U>),
391+
#[is(name = "while_stmt")]
381392
While(StmtWhile<U>),
393+
#[is(name = "if_stmt")]
382394
If(StmtIf<U>),
395+
#[is(name = "with_stmt")]
383396
With(StmtWith<U>),
397+
#[is(name = "async_with_stmt")]
384398
AsyncWith(StmtAsyncWith<U>),
399+
#[is(name = "match_stmt")]
385400
Match(StmtMatch<U>),
401+
#[is(name = "raise_stmt")]
386402
Raise(StmtRaise<U>),
403+
#[is(name = "try_stmt")]
387404
Try(StmtTry<U>),
405+
#[is(name = "try_star_stmt")]
388406
TryStar(StmtTryStar<U>),
407+
#[is(name = "assert_stmt")]
389408
Assert(StmtAssert<U>),
409+
#[is(name = "import_stmt")]
390410
Import(StmtImport<U>),
411+
#[is(name = "import_from_stmt")]
391412
ImportFrom(StmtImportFrom<U>),
413+
#[is(name = "global_stmt")]
392414
Global(StmtGlobal),
415+
#[is(name = "nonlocal_stmt")]
393416
Nonlocal(StmtNonlocal),
417+
#[is(name = "expr_stmt")]
394418
Expr(StmtExpr<U>),
419+
#[is(name = "pass_stmt")]
395420
Pass,
421+
#[is(name = "break_stmt")]
396422
Break,
423+
#[is(name = "continue_stmt")]
397424
Continue,
398425
}
399426
pub type Stmt<U = ()> = Attributed<StmtKind<U>, U>;
@@ -726,52 +753,79 @@ impl<U> From<ExprSlice<U>> for ExprKind<U> {
726753
}
727754
}
728755

729-
#[derive(Clone, Debug, PartialEq)]
756+
#[derive(Clone, Debug, PartialEq, is_macro::Is)]
730757
pub enum ExprKind<U = ()> {
758+
#[is(name = "bool_op_expr")]
731759
BoolOp(ExprBoolOp<U>),
760+
#[is(name = "named_expr_expr")]
732761
NamedExpr(ExprNamedExpr<U>),
762+
#[is(name = "bin_op_expr")]
733763
BinOp(ExprBinOp<U>),
764+
#[is(name = "unary_op_expr")]
734765
UnaryOp(ExprUnaryOp<U>),
766+
#[is(name = "lambda_expr")]
735767
Lambda(ExprLambda<U>),
768+
#[is(name = "if_exp_expr")]
736769
IfExp(ExprIfExp<U>),
770+
#[is(name = "dict_expr")]
737771
Dict(ExprDict<U>),
772+
#[is(name = "set_expr")]
738773
Set(ExprSet<U>),
774+
#[is(name = "list_comp_expr")]
739775
ListComp(ExprListComp<U>),
776+
#[is(name = "set_comp_expr")]
740777
SetComp(ExprSetComp<U>),
778+
#[is(name = "dict_comp_expr")]
741779
DictComp(ExprDictComp<U>),
780+
#[is(name = "generator_exp_expr")]
742781
GeneratorExp(ExprGeneratorExp<U>),
782+
#[is(name = "await_expr")]
743783
Await(ExprAwait<U>),
784+
#[is(name = "yield_expr")]
744785
Yield(ExprYield<U>),
786+
#[is(name = "yield_from_expr")]
745787
YieldFrom(ExprYieldFrom<U>),
788+
#[is(name = "compare_expr")]
746789
Compare(ExprCompare<U>),
790+
#[is(name = "call_expr")]
747791
Call(ExprCall<U>),
792+
#[is(name = "formatted_value_expr")]
748793
FormattedValue(ExprFormattedValue<U>),
794+
#[is(name = "joined_str_expr")]
749795
JoinedStr(ExprJoinedStr<U>),
796+
#[is(name = "constant_expr")]
750797
Constant(ExprConstant),
798+
#[is(name = "attribute_expr")]
751799
Attribute(ExprAttribute<U>),
800+
#[is(name = "subscript_expr")]
752801
Subscript(ExprSubscript<U>),
802+
#[is(name = "starred_expr")]
753803
Starred(ExprStarred<U>),
804+
#[is(name = "name_expr")]
754805
Name(ExprName),
806+
#[is(name = "list_expr")]
755807
List(ExprList<U>),
808+
#[is(name = "tuple_expr")]
756809
Tuple(ExprTuple<U>),
810+
#[is(name = "slice_expr")]
757811
Slice(ExprSlice<U>),
758812
}
759813
pub type Expr<U = ()> = Attributed<ExprKind<U>, U>;
760814

761-
#[derive(Clone, Debug, PartialEq)]
815+
#[derive(Clone, Debug, PartialEq, is_macro::Is)]
762816
pub enum ExprContext {
763817
Load,
764818
Store,
765819
Del,
766820
}
767821

768-
#[derive(Clone, Debug, PartialEq)]
822+
#[derive(Clone, Debug, PartialEq, is_macro::Is)]
769823
pub enum Boolop {
770824
And,
771825
Or,
772826
}
773827

774-
#[derive(Clone, Debug, PartialEq)]
828+
#[derive(Clone, Debug, PartialEq, is_macro::Is)]
775829
pub enum Operator {
776830
Add,
777831
Sub,
@@ -788,15 +842,15 @@ pub enum Operator {
788842
FloorDiv,
789843
}
790844

791-
#[derive(Clone, Debug, PartialEq)]
845+
#[derive(Clone, Debug, PartialEq, is_macro::Is)]
792846
pub enum Unaryop {
793847
Invert,
794848
Not,
795849
UAdd,
796850
USub,
797851
}
798852

799-
#[derive(Clone, Debug, PartialEq)]
853+
#[derive(Clone, Debug, PartialEq, is_macro::Is)]
800854
pub enum Cmpop {
801855
Eq,
802856
NotEq,
@@ -831,7 +885,7 @@ impl<U> From<ExcepthandlerExceptHandler<U>> for ExcepthandlerKind<U> {
831885
}
832886
}
833887

834-
#[derive(Clone, Debug, PartialEq)]
888+
#[derive(Clone, Debug, PartialEq, is_macro::Is)]
835889
pub enum ExcepthandlerKind<U = ()> {
836890
ExceptHandler(ExcepthandlerExceptHandler<U>),
837891
}
@@ -977,7 +1031,7 @@ impl<U> From<PatternMatchOr<U>> for PatternKind<U> {
9771031
}
9781032
}
9791033

980-
#[derive(Clone, Debug, PartialEq)]
1034+
#[derive(Clone, Debug, PartialEq, is_macro::Is)]
9811035
pub enum PatternKind<U = ()> {
9821036
MatchValue(PatternMatchValue<U>),
9831037
MatchSingleton(PatternMatchSingleton),
@@ -1002,7 +1056,7 @@ impl From<TypeIgnoreTypeIgnore> for TypeIgnore {
10021056
}
10031057
}
10041058

1005-
#[derive(Clone, Debug, PartialEq)]
1059+
#[derive(Clone, Debug, PartialEq, is_macro::Is)]
10061060
pub enum TypeIgnore {
10071061
TypeIgnore(TypeIgnoreTypeIgnore),
10081062
}

ast/src/impls.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ use crate::{Constant, ExprKind};
22

33
impl<U> ExprKind<U> {
44
/// Returns a short name for the node suitable for use in error messages.
5-
pub fn name(&self) -> &'static str {
5+
pub fn python_name(&self) -> &'static str {
66
match self {
77
ExprKind::BoolOp { .. } | ExprKind::BinOp { .. } | ExprKind::UnaryOp { .. } => {
88
"operator"

0 commit comments

Comments
 (0)