Skip to content

Commit 77e405f

Browse files
author
Alexander Beedie
committed
Fix derive_dialect! proc macro for use from external crates
1 parent 2ea773a commit 77e405f

File tree

1 file changed

+69
-26
lines changed

1 file changed

+69
-26
lines changed

derive/src/dialect.rs

Lines changed: 69 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -120,24 +120,20 @@ impl Parse for DeriveDialectInput {
120120

121121
/// Entry point for the `derive_dialect!` macro
122122
pub(crate) fn derive_dialect(input: DeriveDialectInput) -> proc_macro::TokenStream {
123-
let err = |msg: String| {
124-
Error::new(proc_macro2::Span::call_site(), msg)
125-
.to_compile_error()
126-
.into()
127-
};
123+
match derive_dialect_inner(input) {
124+
Ok(tokens) => tokens.into(),
125+
Err(e) => e.to_compile_error().into(),
126+
}
127+
}
128128

129-
let source = match read_dialect_mod_file() {
130-
Ok(s) => s,
131-
Err(e) => return err(format!("Failed to read dialect/mod.rs: {e}")),
132-
};
133-
let file: File = match syn::parse_str(&source) {
134-
Ok(f) => f,
135-
Err(e) => return err(format!("Failed to parse source: {e}")),
136-
};
137-
let methods = match extract_dialect_methods(&file) {
138-
Ok(m) => m,
139-
Err(e) => return e.to_compile_error().into(),
140-
};
129+
fn derive_dialect_inner(input: DeriveDialectInput) -> syn::Result<TokenStream> {
130+
let call_site = proc_macro2::Span::call_site();
131+
132+
let source = read_dialect_mod_file()
133+
.map_err(|e| Error::new(call_site, format!("Failed to read dialect/mod.rs: {e}")))?;
134+
let file: File = syn::parse_str::<File>(&source)
135+
.map_err(|e| Error::new(call_site, format!("Failed to parse source: {e}")))?;
136+
let methods = extract_dialect_methods(&file)?;
141137

142138
// Validate overrides
143139
let bool_names: HashSet<_> = methods
@@ -147,20 +143,23 @@ pub(crate) fn derive_dialect(input: DeriveDialectInput) -> proc_macro::TokenStre
147143
.collect();
148144
for (key, value) in &input.overrides {
149145
let key_str = key.to_string();
150-
let err = |msg| Error::new(key.span(), msg).to_compile_error().into();
151146
match value {
152147
Override::Bool(_) if !bool_names.contains(&key_str) => {
153-
return err(format!("Unknown boolean method `{key_str}`"));
148+
return Err(Error::new(
149+
key.span(),
150+
format!("Unknown boolean method `{key_str}`"),
151+
));
154152
}
155153
Override::Char(_) | Override::None if key_str != "identifier_quote_style" => {
156-
return err(format!(
157-
"Char/None only valid for `identifier_quote_style`, not `{key_str}`"
154+
return Err(Error::new(
155+
key.span(),
156+
format!("Char/None only valid for `identifier_quote_style`, not `{key_str}`"),
158157
));
159158
}
160159
_ => {}
161160
}
162161
}
163-
generate_derived_dialect(&input, &methods).into()
162+
Ok(generate_derived_dialect(&input, &methods))
164163
}
165164

166165
/// Generate the complete derived `Dialect` implementation
@@ -258,11 +257,55 @@ fn extract_param_names(sig: &Signature) -> Vec<&Ident> {
258257
}
259258

260259
/// Read the `dialect/mod.rs` file that contains the Dialect trait.
260+
///
261+
/// Searches for the file in the following order:
262+
/// 1. `$CARGO_MANIFEST_DIR/src/dialect/mod.rs` - works when the macro is
263+
/// invoked from within the `sqlparser` crate itself (e.g. in tests).
264+
/// 2. `<sqlparser_derive dir>/../src/dialect/mod.rs` - works when
265+
/// `sqlparser_derive` lives in a workspace alongside the main crate
266+
/// (the standard `derive/` layout).
267+
/// 3. Sibling directories of the compiled `sqlparser_derive` crate in the
268+
/// Cargo registry - works when an external crate uses `derive_dialect!`
269+
/// via a registry dependency.
261270
fn read_dialect_mod_file() -> Result<String, String> {
262-
let manifest_dir =
263-
std::env::var("CARGO_MANIFEST_DIR").map_err(|_| "CARGO_MANIFEST_DIR not set")?;
264-
let path = std::path::Path::new(&manifest_dir).join("src/dialect/mod.rs");
265-
std::fs::read_to_string(&path).map_err(|e| format!("Failed to read {}: {e}", path.display()))
271+
use std::path::{Path, PathBuf};
272+
273+
const DERIVE_CRATE_DIR: &str = env!("CARGO_MANIFEST_DIR");
274+
let derive_dir = Path::new(DERIVE_CRATE_DIR);
275+
let mut candidates: Vec<PathBuf> = Vec::new();
276+
277+
// The crate being compiled (eg: within sqlparser).
278+
if let Ok(manifest_dir) = std::env::var("CARGO_MANIFEST_DIR") {
279+
candidates.push(Path::new(&manifest_dir).join("src/dialect/mod.rs"));
280+
}
281+
// Workspace layout: the main crate is the parent of `derive/`.
282+
candidates.push(derive_dir.join("../src/dialect/mod.rs"));
283+
284+
// Cargo registry: look for sibling `sqlparser-*` directories (prefer newest).
285+
if let Some(parent) = derive_dir.parent() {
286+
if let Ok(entries) = std::fs::read_dir(parent) {
287+
let mut siblings: Vec<_> = entries
288+
.filter_map(|e| e.ok())
289+
.filter(|e| {
290+
let name = e.file_name();
291+
let name = name.to_string_lossy();
292+
name.starts_with("sqlparser-") && !name.starts_with("sqlparser-derive")
293+
})
294+
.collect();
295+
siblings.sort_by(|a, b| b.file_name().cmp(&a.file_name()));
296+
candidates.extend(siblings.into_iter().map(|e| e.path().join("src/dialect/mod.rs")));
297+
}
298+
}
299+
for path in &candidates {
300+
if let Ok(content) = std::fs::read_to_string(path) {
301+
return Ok(content);
302+
}
303+
}
304+
Err(format!(
305+
"Could not find `sqlparser` dialect/mod.rs file. \
306+
Searched in $CARGO_MANIFEST_DIR/src/dialect/mod.rs and \
307+
the `sqlparser_derive` crate at {DERIVE_CRATE_DIR}"
308+
))
266309
}
267310

268311
/// Extract all methods from the `Dialect` trait (excluding `dialect` for TypeId)

0 commit comments

Comments
 (0)