Skip to content

Commit c43241b

Browse files
committed
feat(query): tighten procedure overload resolution
1 parent 67548ed commit c43241b

File tree

3 files changed

+503
-32
lines changed

3 files changed

+503
-32
lines changed

src/query/management/src/procedure/procedure_mgr.rs

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -201,4 +201,25 @@ impl ProcedureMgr {
201201

202202
Ok(procedure_infos)
203203
}
204+
205+
#[fastrace::trace]
206+
pub async fn list_procedures_by_name(
207+
&self,
208+
name: &str,
209+
) -> Result<Vec<ProcedureInfo>, MetaError> {
210+
debug!(name = (name); "SchemaApi: {}", func_name!());
211+
let ident = ProcedureNameIdent::new(&self.tenant, ProcedureIdentity::new(name, ""));
212+
let dir = DirName::new_with_level(ident, 1);
213+
214+
let name_id_metas = self.kv_api.list_id_value(&dir).await?;
215+
let procedure_infos = name_id_metas
216+
.map(|(k, id, seq_meta)| ProcedureInfo {
217+
ident: ProcedureIdIdent::new(&self.tenant, *id),
218+
name_ident: k,
219+
meta: seq_meta.data,
220+
})
221+
.collect::<Vec<_>>();
222+
223+
Ok(procedure_infos)
224+
}
204225
}

src/query/sql/src/planner/binder/ddl/procedure.rs

Lines changed: 162 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -18,19 +18,27 @@ use databend_common_ast::ast::CreateProcedureStmt;
1818
use databend_common_ast::ast::DescProcedureStmt;
1919
use databend_common_ast::ast::DropProcedureStmt;
2020
use databend_common_ast::ast::ExecuteImmediateStmt;
21+
use databend_common_ast::ast::Expr;
2122
use databend_common_ast::ast::ProcedureIdentity as AstProcedureIdentity;
2223
use databend_common_ast::ast::ProcedureLanguage;
2324
use databend_common_ast::ast::ProcedureType;
2425
use databend_common_ast::ast::ShowOptions;
26+
use databend_common_ast::ast::TypeName;
27+
use databend_common_ast::parser::expr::type_name as parse_type_name_ast;
2528
use databend_common_ast::parser::run_parser;
2629
use databend_common_ast::parser::script::script_block_or_stmt;
2730
use databend_common_ast::parser::script::ScriptBlockOrStmt;
2831
use databend_common_ast::parser::tokenize_sql;
32+
use databend_common_ast::parser::Dialect;
2933
use databend_common_ast::parser::ParseMode;
3034
use databend_common_exception::ErrorCode;
3135
use databend_common_exception::Result;
36+
use databend_common_expression::type_check::common_super_type;
3237
use databend_common_expression::types::DataType;
3338
use databend_common_expression::Scalar;
39+
use databend_common_functions::BUILTIN_FUNCTIONS;
40+
use databend_common_meta_app::principal::procedure::ProcedureInfo;
41+
use databend_common_meta_app::principal::GetProcedureReply;
3442
use databend_common_meta_app::principal::GetProcedureReq;
3543
use databend_common_meta_app::principal::ProcedureIdentity;
3644
use databend_common_meta_app::principal::ProcedureMeta;
@@ -180,39 +188,84 @@ impl Binder {
180188
&[],
181189
true,
182190
)?;
183-
let mut arg_types = vec![];
191+
let mut arg_types = Vec::with_capacity(arguments.len());
184192
for argument in arguments {
185193
let box (arg, mut arg_type) = type_checker.resolve(argument)?;
186194
if let ScalarExpr::SubqueryExpr(subquery) = &arg {
187195
if subquery.typ == SubqueryType::Scalar && !arg.data_type()?.is_nullable() {
188196
arg_type = arg_type.wrap_nullable();
189197
}
190198
}
191-
arg_types.push(arg_type.to_string());
199+
arg_types.push(arg_type);
192200
}
193-
let name = name.to_string();
194-
let procedure_ident = ProcedureIdentity::new(name, arg_types.join(","));
201+
202+
let name_str = name.to_string();
203+
let procedure_api = UserApiProvider::instance().procedure_api(&tenant);
204+
205+
// Try exact match first
206+
let arg_type_strings: Vec<String> = arg_types.iter().map(|t| t.to_string()).collect();
207+
let procedure_ident = ProcedureIdentity::new(name_str.clone(), arg_type_strings.join(","));
195208
let req = GetProcedureReq {
196209
inner: ProcedureNameIdent::new(tenant.clone(), procedure_ident.clone()),
197210
};
198-
199-
let procedure = UserApiProvider::instance()
200-
.procedure_api(&tenant)
201-
.get_procedure(&req)
202-
.await?;
203-
if let Some(procedure) = procedure {
204-
Ok(Plan::CallProcedure(Box::new(CallProcedurePlan {
211+
if let Some(procedure) = procedure_api.get_procedure(&req).await? {
212+
return Ok(Plan::CallProcedure(Box::new(CallProcedurePlan {
205213
procedure_id: procedure.id,
206214
script: procedure.procedure_meta.script,
207215
arg_names: procedure.procedure_meta.arg_names,
208216
args: arguments.clone(),
209-
})))
210-
} else {
211-
Err(ErrorCode::UnknownProcedure(format!(
217+
})));
218+
}
219+
220+
// Exact match failed, try implicit cast resolution.
221+
let candidates = procedure_api.list_procedures_by_name(&name_str).await?;
222+
let mut same_arity_candidates = Vec::new();
223+
for candidate in candidates {
224+
let arg_defs = parse_procedure_signature(&candidate.name_ident.procedure_name().args)?;
225+
if arg_defs.len() == arg_types.len() {
226+
same_arity_candidates.push(candidate);
227+
}
228+
}
229+
230+
let has_explicit_cast = arguments
231+
.iter()
232+
.any(|expr| matches!(expr, Expr::Cast { .. } | Expr::TryCast { .. }));
233+
234+
// Multiple overloads plus lack of explicit casts means we cannot disambiguate.
235+
if same_arity_candidates.len() > 1 && !has_explicit_cast {
236+
return Err(ErrorCode::UnknownProcedure(format!(
212237
"Unknown procedure {}",
213238
procedure_ident
214-
)))
239+
)));
215240
}
241+
242+
let allow_implicit_cast = same_arity_candidates.len() == 1 && !has_explicit_cast;
243+
let (procedure, casts_to_apply) = resolve_procedure_candidate(
244+
&procedure_ident,
245+
&arg_types,
246+
same_arity_candidates,
247+
allow_implicit_cast,
248+
)?;
249+
250+
let args = arguments
251+
.iter()
252+
.zip(casts_to_apply.into_iter())
253+
.map(|(expr, cast)| match cast {
254+
Some(target_type) => Expr::Cast {
255+
span: expr.span(),
256+
expr: Box::new(expr.clone()),
257+
target_type,
258+
pg_style: false,
259+
},
260+
None => expr.clone(),
261+
})
262+
.collect();
263+
Ok(Plan::CallProcedure(Box::new(CallProcedurePlan {
264+
procedure_id: procedure.id,
265+
script: procedure.procedure_meta.script,
266+
arg_names: procedure.procedure_meta.arg_names,
267+
args,
268+
})))
216269
}
217270

218271
fn procedure_meta(
@@ -271,7 +324,6 @@ fn generate_procedure_name_ident(
271324
.map(|type_name| resolve_type_name(type_name, true).map(|t| DataType::from(&t)))
272325
.collect::<Result<Vec<_>, _>>()?;
273326

274-
// Convert normalized DataType back to string for storage
275327
let args_type_str = args_data_type
276328
.iter()
277329
.map(|dt| dt.to_string())
@@ -283,3 +335,97 @@ fn generate_procedure_name_ident(
283335
ProcedureIdentity::new(name.name.clone(), args_type_str),
284336
))
285337
}
338+
339+
/// Find the first procedure overload whose signature is compatible with the
340+
/// actual argument types, optionally allowing implicit casts. Returns the
341+
/// procedure metadata plus the exact casts we need to inject.
342+
fn resolve_procedure_candidate(
343+
procedure_ident: &ProcedureIdentity,
344+
arg_types: &[DataType],
345+
candidates: Vec<ProcedureInfo>,
346+
allow_implicit_cast: bool,
347+
) -> Result<(GetProcedureReply, Vec<Option<TypeName>>)> {
348+
for candidate in candidates {
349+
let arg_defs = parse_procedure_signature(&candidate.name_ident.procedure_name().args)?;
350+
if arg_defs.len() != arg_types.len() {
351+
continue;
352+
}
353+
354+
let mut casts = Vec::with_capacity(arg_types.len());
355+
let mut compatible = true;
356+
for (actual, target_ast) in arg_types.iter().zip(arg_defs.iter()) {
357+
let target = DataType::from(&resolve_type_name(target_ast, true)?);
358+
if actual == &target {
359+
casts.push(None);
360+
continue;
361+
}
362+
363+
if allow_implicit_cast
364+
&& common_super_type(
365+
actual.clone(),
366+
target.clone(),
367+
&BUILTIN_FUNCTIONS.default_cast_rules,
368+
)
369+
.is_some_and(|common| common == target)
370+
{
371+
casts.push(Some(target_ast.clone()));
372+
} else {
373+
compatible = false;
374+
break;
375+
}
376+
}
377+
378+
if compatible {
379+
return Ok((
380+
GetProcedureReply {
381+
id: *candidate.ident.procedure_id(),
382+
procedure_meta: candidate.meta.clone(),
383+
},
384+
casts,
385+
));
386+
}
387+
}
388+
389+
Err(ErrorCode::UnknownProcedure(format!(
390+
"Unknown procedure {}",
391+
procedure_ident
392+
)))
393+
}
394+
395+
fn parse_procedure_signature(arg_str: &str) -> Result<Vec<TypeName>> {
396+
if arg_str.is_empty() {
397+
return Ok(vec![]);
398+
}
399+
400+
let mut segments = Vec::new();
401+
let mut depth = 0i32;
402+
let mut start = 0usize;
403+
for (idx, ch) in arg_str.char_indices() {
404+
match ch {
405+
'(' => depth += 1,
406+
')' => depth -= 1,
407+
',' if depth == 0 => {
408+
// Only split on commas at depth 0 so nested args (like DECIMAL) stay intact.
409+
segments.push(arg_str[start..idx].trim());
410+
start = idx + 1;
411+
}
412+
_ => {}
413+
}
414+
}
415+
segments.push(arg_str[start..].trim());
416+
417+
segments
418+
.into_iter()
419+
.map(|segment| {
420+
let tokens = tokenize_sql(segment)?;
421+
run_parser(
422+
&tokens,
423+
Dialect::default(),
424+
ParseMode::Default,
425+
false,
426+
parse_type_name_ast,
427+
)
428+
.map_err(|e| ErrorCode::SyntaxException(e.to_string()))
429+
})
430+
.collect()
431+
}

0 commit comments

Comments
 (0)