diff --git a/derive/src/dialect.rs b/derive/src/dialect.rs index 9873e4f7b5..9066bf9645 100644 --- a/derive/src/dialect.rs +++ b/derive/src/dialect.rs @@ -120,24 +120,20 @@ impl Parse for DeriveDialectInput { /// Entry point for the `derive_dialect!` macro pub(crate) fn derive_dialect(input: DeriveDialectInput) -> proc_macro::TokenStream { - let err = |msg: String| { - Error::new(proc_macro2::Span::call_site(), msg) - .to_compile_error() - .into() - }; + match derive_dialect_inner(input) { + Ok(tokens) => tokens.into(), + Err(e) => e.to_compile_error().into(), + } +} - let source = match read_dialect_mod_file() { - Ok(s) => s, - Err(e) => return err(format!("Failed to read dialect/mod.rs: {e}")), - }; - let file: File = match syn::parse_str(&source) { - Ok(f) => f, - Err(e) => return err(format!("Failed to parse source: {e}")), - }; - let methods = match extract_dialect_methods(&file) { - Ok(m) => m, - Err(e) => return e.to_compile_error().into(), - }; +fn derive_dialect_inner(input: DeriveDialectInput) -> syn::Result { + let call_site = proc_macro2::Span::call_site(); + + let source = read_dialect_mod_file() + .map_err(|e| Error::new(call_site, format!("Failed to read dialect/mod.rs: {e}")))?; + let file: File = syn::parse_str::(&source) + .map_err(|e| Error::new(call_site, format!("Failed to parse source: {e}")))?; + let methods = extract_dialect_methods(&file)?; // Validate overrides let bool_names: HashSet<_> = methods @@ -147,20 +143,23 @@ pub(crate) fn derive_dialect(input: DeriveDialectInput) -> proc_macro::TokenStre .collect(); for (key, value) in &input.overrides { let key_str = key.to_string(); - let err = |msg| Error::new(key.span(), msg).to_compile_error().into(); match value { Override::Bool(_) if !bool_names.contains(&key_str) => { - return err(format!("Unknown boolean method `{key_str}`")); + return Err(Error::new( + key.span(), + format!("Unknown boolean method `{key_str}`"), + )); } Override::Char(_) | Override::None if key_str != "identifier_quote_style" => { - return err(format!( - "Char/None only valid for `identifier_quote_style`, not `{key_str}`" + return Err(Error::new( + key.span(), + format!("Char/None only valid for `identifier_quote_style`, not `{key_str}`"), )); } _ => {} } } - generate_derived_dialect(&input, &methods).into() + Ok(generate_derived_dialect(&input, &methods)) } /// Generate the complete derived `Dialect` implementation @@ -258,11 +257,59 @@ fn extract_param_names(sig: &Signature) -> Vec<&Ident> { } /// Read the `dialect/mod.rs` file that contains the Dialect trait. +/// +/// Searches for the file in the following order: +/// 1. `$CARGO_MANIFEST_DIR/src/dialect/mod.rs` - works when the macro is +/// invoked from within the `sqlparser` crate itself (e.g. in tests). +/// 2. `/../src/dialect/mod.rs` - works when +/// `sqlparser_derive` lives in a workspace alongside the main crate +/// (the standard `derive/` layout). +/// 3. Sibling directories of the compiled `sqlparser_derive` crate in the +/// Cargo registry - works when an external crate uses `derive_dialect!` +/// via a registry dependency. fn read_dialect_mod_file() -> Result { - let manifest_dir = - std::env::var("CARGO_MANIFEST_DIR").map_err(|_| "CARGO_MANIFEST_DIR not set")?; - let path = std::path::Path::new(&manifest_dir).join("src/dialect/mod.rs"); - std::fs::read_to_string(&path).map_err(|e| format!("Failed to read {}: {e}", path.display())) + use std::path::{Path, PathBuf}; + + const DERIVE_CRATE_DIR: &str = env!("CARGO_MANIFEST_DIR"); + let derive_dir = Path::new(DERIVE_CRATE_DIR); + let mut candidates: Vec = Vec::new(); + + // The crate being compiled (eg: within sqlparser). + if let Ok(manifest_dir) = std::env::var("CARGO_MANIFEST_DIR") { + candidates.push(Path::new(&manifest_dir).join("src/dialect/mod.rs")); + } + // Workspace layout: the main crate is the parent of `derive/`. + candidates.push(derive_dir.join("../src/dialect/mod.rs")); + + // Cargo registry: look for sibling `sqlparser-*` directories (prefer newest). + if let Some(parent) = derive_dir.parent() { + if let Ok(entries) = std::fs::read_dir(parent) { + let mut siblings: Vec<_> = entries + .filter_map(|e| e.ok()) + .filter(|e| { + let name = e.file_name(); + let name = name.to_string_lossy(); + name.starts_with("sqlparser-") && !name.starts_with("sqlparser-derive") + }) + .collect(); + siblings.sort_by(|a, b| b.file_name().cmp(&a.file_name())); + candidates.extend( + siblings + .into_iter() + .map(|e| e.path().join("src/dialect/mod.rs")), + ); + } + } + for path in &candidates { + if let Ok(content) = std::fs::read_to_string(path) { + return Ok(content); + } + } + Err(format!( + "Could not find `sqlparser` dialect/mod.rs file. \ + Searched in $CARGO_MANIFEST_DIR/src/dialect/mod.rs and \ + the `sqlparser_derive` crate at {DERIVE_CRATE_DIR}" + )) } /// Extract all methods from the `Dialect` trait (excluding `dialect` for TypeId)