Skip to content
This repository was archived by the owner on Oct 26, 2025. It is now read-only.

Commit 6950ec8

Browse files
committed
Support function argument parameterization
1 parent 49089ab commit 6950ec8

File tree

6 files changed

+209
-59
lines changed

6 files changed

+209
-59
lines changed

CHANGELOG.md

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,17 @@
22

33
All notable changes to this project will be documented in this file
44

5+
## Unreleased
6+
7+
### Added
8+
9+
- added support for parameterizing function arguments
10+
11+
### Changed
12+
13+
- updated to rust 2021 edition
14+
15+
516
## 0.2.2 - 2024-05-06
617

718
### Fixed

Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,4 +17,5 @@ quote = "1"
1717
proc-macro2 = "1.0.43"
1818
syn = { version = "2.0.17", features = ["full", "extra-traits"] }
1919
itertools = "0.13"
20+
strum = { version = "0.26.2", features = ["derive"] }
2021

src/arguments.rs

Lines changed: 101 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -7,20 +7,22 @@
77
use crate::extract::{Extract, ExtractIterator, ExtractMap};
88
use itertools::Itertools;
99
use proc_macro2::{Ident, Span};
10+
use std::fmt::Debug;
1011
use syn::parse::discouraged::Speculative;
1112
use syn::parse::{Parse, ParseStream};
1213
use syn::punctuated::Punctuated;
1314
use syn::spanned::Spanned;
1415
use syn::token::Comma;
15-
use syn::{Expr, GenericParam, Lit, LitStr, Type};
16+
use syn::{Expr, ExprLit, FnArg, GenericParam, Lit, LitStr, Pat, Type};
1617

17-
use crate::params::Param;
18+
use crate::params::{Param, ParamType};
1819

1920
/// The value of an [`Argument`]; everything after the equal sign
2021
#[derive(Clone, Debug)]
2122
pub(crate) enum ArgumentValue {
2223
TypeList(Vec<Type>),
2324
LitList(Vec<Lit>),
25+
ExprList(Vec<Expr>),
2426
Str(String),
2527
}
2628

@@ -45,20 +47,31 @@ fn parse_typelist(input: ParseStream) -> Option<syn::Result<ArgumentValue>> {
4547

4648
/// Parse a bracketed list of literals
4749
fn parse_litlist(input: ParseStream) -> Option<syn::Result<ArgumentValue>> {
48-
let parse = || {
49-
let exprs = input.parse::<syn::ExprArray>()?;
50-
let entries: syn::Result<Vec<Lit>> = exprs
50+
// match on brackets. anything invalid after is an error
51+
return if input.peek(syn::token::Bracket) {
52+
let exprs = input.parse::<syn::ExprArray>().ok()?;
53+
let entries: Option<Vec<Lit>> = exprs
5154
.elems
5255
.iter()
53-
.map(|expr: &Expr| -> syn::Result<Lit> {
54-
return if let Expr::Lit(lit) = expr {
55-
Ok(lit.lit.clone())
56+
.map(|expr: &Expr| -> Option<Lit> {
57+
if let Expr::Lit(lit) = expr {
58+
Some(lit.lit.clone())
5659
} else {
57-
Err(syn::Error::new(expr.span(), "Expression is not a literal"))
58-
};
60+
None
61+
}
5962
})
6063
.collect();
61-
Ok(ArgumentValue::LitList(entries?))
64+
Some(Ok(ArgumentValue::LitList(entries?)))
65+
} else {
66+
None
67+
};
68+
}
69+
70+
fn parse_exprlist(input: ParseStream) -> Option<syn::Result<ArgumentValue>> {
71+
let parse = || {
72+
let exprs = input.parse::<syn::ExprArray>()?;
73+
let entries = exprs.elems.iter().cloned().collect_vec();
74+
Ok(ArgumentValue::ExprList(entries))
6275
};
6376

6477
// match on brackets. anything invalid after is an error
@@ -81,7 +94,7 @@ impl Parse for Argument {
8194
input.parse::<syn::token::Eq>()?;
8295

8396
// iterate over the known parse functions for arguments
84-
[parse_typelist, parse_litlist, parse_str]
97+
[parse_typelist, parse_litlist, parse_exprlist, parse_str]
8598
.iter()
8699
.find_map(|f| {
87100
// fork the buffer, so we can rewind if there isnt a match
@@ -110,6 +123,7 @@ impl Argument {
110123
ArgumentValue::TypeList(_) => "type list",
111124
ArgumentValue::LitList(_) => "const list",
112125
ArgumentValue::Str(_) => "string",
126+
ArgumentValue::ExprList(_) => "expression list",
113127
}
114128
}
115129
}
@@ -145,50 +159,96 @@ impl Extract for ArgumentList {
145159
}
146160

147161
impl ArgumentList {
148-
/// consume a paramlist from the argument list that matches the given generic parameter
149-
/// and return it.
150-
/// Returns an error if there is a type mismatch, or if there is not exactly one match
151-
pub fn consume_paramlist(&mut self, gp: &GenericParam) -> syn::Result<Vec<(Ident, Param)>> {
152-
let (g_ident, g_name) = match gp {
153-
GenericParam::Lifetime(lt) => Err(syn::Error::new(
154-
lt.span(),
155-
"Parameterizing lifetimes is not supported",
156-
)),
157-
GenericParam::Type(t) => Ok((&t.ident, "type")),
158-
GenericParam::Const(c) => Ok((&c.ident, "const")),
159-
}?;
162+
fn consume_paramlist(
163+
&mut self,
164+
ident: &Ident,
165+
ty: &ParamType,
166+
) -> syn::Result<Vec<(Ident, Param)>> {
160167
self.extract_map(|arg| -> Option<syn::Result<Vec<(Ident, Param)>>> {
161-
return if &arg.ident == g_ident {
162-
match (&arg.value, gp) {
163-
(ArgumentValue::TypeList(tl), GenericParam::Type(_)) => Some(
168+
return if &arg.ident == ident {
169+
match (&arg.value, ty) {
170+
(ArgumentValue::TypeList(tl), ParamType::GenericType) => Some(
164171
Ok(tl.iter()
165-
.map(|ty| (arg.ident.clone(), Param::Type(ty.clone())))
172+
.map(|ty| (arg.ident.clone(), Param::GenericType(ty.clone())))
166173
.collect()),
167174
),
168-
(ArgumentValue::LitList(ll), GenericParam::Const(_)) => Some(
175+
(ArgumentValue::LitList(ll), ParamType::GenericConst) => Some(
169176
Ok(ll.iter()
170-
.map(|lit| (arg.ident.clone(), Param::Lit(lit.clone())))
177+
.map(|lit| (arg.ident.clone(), Param::GenericConst(lit.clone())))
178+
.collect()),
179+
),
180+
(ArgumentValue::LitList(ll), ParamType::FnArg) => Some(
181+
Ok(ll.iter()
182+
.map(|lit| (arg.ident.clone(), Param::FnArg(
183+
Expr::Lit (ExprLit{ attrs: vec![], lit: lit.clone() }))))
184+
.collect()),
185+
),
186+
(ArgumentValue::ExprList(el), ParamType::FnArg) => Some(
187+
Ok(el.iter()
188+
.map(|e| (arg.ident.clone(), Param::FnArg(e.clone())))
171189
.collect()),
172190
),
173191
(ArgumentValue::TypeList(_), _) | (ArgumentValue::LitList(_), _) => Some(Err(syn::Error::new(
174192
arg.ident.span(),
175-
format!("Mismatched parameterization: Expected {} list but found {}", g_name, arg.short_type()),
193+
format!("Mismatched parameterization: Expected {} list but found {}", ty, arg.short_type()),
176194
))),
177195
/* fall through, in case theres a generic argument named for example "fmt". there probably shouldn't be though */
178-
(_, _) => None }
196+
(_, _) => None
197+
}
179198
} else {
180199
None
181200
}
182201
})
183-
.at_most_one()
184-
.map_err(|_| {
185-
// more than one match
186-
syn::Error::new(
187-
Span::call_site(),
188-
format!("Multiple {g_name} parameterizations provided for `{g_ident}`"),
189-
)
190-
})?
191-
.ok_or(syn::Error::new(Span::call_site(), format!("No {g_name} parameterization provided for `{g_ident}`")))?
202+
.at_most_one()
203+
.map_err(|_| {
204+
// more than one match
205+
syn::Error::new(
206+
Span::call_site(),
207+
format!("Multiple {ty} parameterizations provided for `{ident}`"),
208+
)
209+
})?
210+
.ok_or(syn::Error::new(Span::call_site(), format!("No {ty} parameterization provided for `{ident}`")))?
192211
// no matches
193212
}
213+
/// consume a paramlist from the argument list that matches the given generic parameter
214+
/// and return it.
215+
/// Returns an error if there is a type mismatch, or if there is not exactly one match
216+
pub fn consume_generic_paramlist(
217+
&mut self,
218+
gp: &GenericParam,
219+
) -> syn::Result<Vec<(Ident, Param)>> {
220+
let (ident, ty) = match gp {
221+
GenericParam::Lifetime(lt) => Err(syn::Error::new(
222+
lt.span(),
223+
"Parameterizing lifetimes is not supported",
224+
)),
225+
GenericParam::Type(t) => Ok((&t.ident, ParamType::GenericType)),
226+
GenericParam::Const(c) => Ok((&c.ident, ParamType::GenericConst)),
227+
}?;
228+
self.consume_paramlist(ident, &ty)
229+
}
230+
231+
/// consume a paramlist from the argument list that matches the given function argument and return it.
232+
/// Returns an error if there is not exactly one match
233+
pub fn consume_arg_paramlist(&mut self, arg: &FnArg) -> syn::Result<Vec<(Ident, Param)>> {
234+
let pat = match arg {
235+
FnArg::Receiver(_) => {
236+
return Err(syn::Error::new(
237+
arg.span(),
238+
"self arguments are not supported",
239+
))
240+
}
241+
FnArg::Typed(pat) => *(pat.clone().pat),
242+
};
243+
let ident = match pat {
244+
Pat::Ident(pat_ident) => pat_ident.ident,
245+
_ => {
246+
return Err(syn::Error::new(
247+
pat.span(),
248+
"function arguments must be an identity pattern",
249+
))
250+
}
251+
};
252+
self.consume_paramlist(&ident, &ParamType::FnArg)
253+
}
194254
}

src/lib.rs

Lines changed: 27 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,14 @@ fn parameterize_impl(mut args: ArgumentList, mut inner: ItemFn) -> syn::Result<T
5353
.generics
5454
.params
5555
.iter()
56-
.map(|gp| args.consume_paramlist(gp))
56+
.map(|gp| args.consume_generic_paramlist(gp))
57+
.collect::<syn::Result<_>>()?;
58+
59+
let arg_lists: Vec<_> = inner
60+
.sig
61+
.inputs
62+
.iter()
63+
.map(|ident| args.consume_arg_paramlist(ident))
5764
.collect::<syn::Result<_>>()?;
5865

5966
// Consume format string argument
@@ -80,13 +87,29 @@ fn parameterize_impl(mut args: ArgumentList, mut inner: ItemFn) -> syn::Result<T
8087
// iterate over them, and map them to wrapper functions
8188
let (wrapper_idents, wrappers): (Vec<_>, Vec<_>) = param_lists
8289
.iter()
90+
.chain(arg_lists.iter())
8391
.multi_cartesian_product()
8492
.map(|params| {
85-
let param_values = params.iter().map(|(_, p)| p).collect_vec();
93+
let generic_param_values = params
94+
.iter()
95+
.filter_map(|(_, p)| match p {
96+
Param::FnArg(_) => None,
97+
_ => Some(p),
98+
})
99+
.collect_vec();
100+
101+
let arg_values = params
102+
.iter()
103+
.filter_map(|(_, p)| match p {
104+
Param::FnArg(arg) => Some(arg),
105+
_ => None,
106+
})
107+
.collect_vec();
86108

87109
// let fn_ident = format_ident!("{}_{}", inner.sig.ident, param_values.iter().join("_"));
88110
let fn_ident = format_params(&fmt_string, &inner_ident, params);
89-
let fn_body: Expr = syn::parse_quote!(#inner_ident::<#(#param_values,)*>());
111+
let fn_body: Expr =
112+
syn::parse_quote!(#inner_ident::<#(#generic_param_values,)*>(#(#arg_values,)*));
90113
let fn_doc = format!(" Wrapper for {}", fn_body.to_token_stream());
91114
let mut func: ItemFn = syn::parse_quote! {
92115
#[allow(non_snake_case)]
@@ -196,8 +219,5 @@ pub fn parameterize(args: TokenStream, input: TokenStream) -> TokenStream {
196219
let inner = parse_macro_input!(input as syn::ItemFn);
197220
let args = parse_macro_input!(args as ArgumentList);
198221

199-
match parameterize_impl(args, inner) {
200-
Ok(output) => output,
201-
Err(err) => err.to_compile_error().into(),
202-
}
222+
parameterize_impl(args, inner).unwrap_or_else(|err| err.to_compile_error().into())
203223
}

src/params.rs

Lines changed: 35 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -7,19 +7,35 @@
77
use proc_macro2::TokenStream;
88
use quote::ToTokens;
99
use std::fmt::{Display, Formatter};
10+
use std::hash::{DefaultHasher, Hash, Hasher};
11+
use strum::EnumDiscriminants;
1012
use syn::{Expr, Lit, Type};
1113

12-
#[derive(Clone, Debug)]
14+
/// A Param is a value we are parameterizing over. It can be a type or a literal
15+
#[derive(Clone, Debug, EnumDiscriminants)]
16+
#[strum_discriminants(name(ParamType))]
1317
pub(crate) enum Param {
14-
Type(Type),
15-
Lit(Lit),
18+
GenericType(Type),
19+
GenericConst(Lit),
20+
FnArg(Expr),
21+
}
22+
23+
impl Display for ParamType {
24+
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
25+
match self {
26+
ParamType::GenericType => f.write_str("type generic"),
27+
ParamType::GenericConst => f.write_str("const generic"),
28+
ParamType::FnArg => f.write_str("argument"),
29+
}
30+
}
1631
}
1732

1833
impl ToTokens for Param {
1934
fn to_tokens(&self, tokens: &mut TokenStream) {
2035
match self {
21-
Param::Type(ty) => ty.to_tokens(tokens),
22-
Param::Lit(lit) => lit.to_tokens(tokens),
36+
Param::GenericType(ty) => ty.to_tokens(tokens),
37+
Param::GenericConst(lit) => lit.to_tokens(tokens),
38+
Param::FnArg(expr) => expr.to_tokens(tokens),
2339
}
2440
}
2541
}
@@ -30,10 +46,19 @@ impl Display for Param {
3046
}
3147
}
3248

49+
fn calculate_hash<T: Hash>(t: &T) -> String {
50+
let mut s = DefaultHasher::new();
51+
t.hash(&mut s);
52+
let hex = format!("{:X}", s.finish());
53+
return hex[0..4].to_string();
54+
}
55+
3356
impl Param {
3457
fn ident_safe(&self) -> String {
3558
fn lit_ident_safe(lit: &Lit) -> String {
36-
lit.to_token_stream().to_string().replace(".", "p")
59+
let lit = lit.to_token_stream().to_string().replace(".", "p");
60+
let lit = lit.replace("\"", "");
61+
lit
3762
}
3863
fn type_ident_safe(ty: &Type) -> String {
3964
match ty {
@@ -64,8 +89,10 @@ impl Param {
6489
}
6590
}
6691
match self {
67-
Param::Type(ty) => type_ident_safe(ty),
68-
Param::Lit(lit) => lit_ident_safe(lit),
92+
Param::GenericType(ty) => type_ident_safe(ty),
93+
Param::GenericConst(lit) => lit_ident_safe(lit),
94+
Param::FnArg(Expr::Lit(lit)) => lit_ident_safe(&lit.lit),
95+
Param::FnArg(expr) => format!("expr{}", calculate_hash(expr)),
6996
}
7097
}
7198
}

0 commit comments

Comments
 (0)