Skip to content

Commit a39617e

Browse files
committed
feat(query): tighten procedure overload resolution
1 parent 67548ed commit a39617e

File tree

3 files changed

+465
-32
lines changed

3 files changed

+465
-32
lines changed

src/query/management/src/procedure/procedure_mgr.rs

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -201,4 +201,25 @@ impl ProcedureMgr {
201201

202202
Ok(procedure_infos)
203203
}
204+
205+
#[fastrace::trace]
206+
pub async fn list_procedures_by_name(
207+
&self,
208+
name: &str,
209+
) -> Result<Vec<ProcedureInfo>, MetaError> {
210+
debug!(name = (name); "SchemaApi: {}", func_name!());
211+
let ident = ProcedureNameIdent::new(&self.tenant, ProcedureIdentity::new(name, ""));
212+
let dir = DirName::new_with_level(ident, 1);
213+
214+
let name_id_metas = self.kv_api.list_id_value(&dir).await?;
215+
let procedure_infos = name_id_metas
216+
.map(|(k, id, seq_meta)| ProcedureInfo {
217+
ident: ProcedureIdIdent::new(&self.tenant, *id),
218+
name_ident: k,
219+
meta: seq_meta.data,
220+
})
221+
.collect::<Vec<_>>();
222+
223+
Ok(procedure_infos)
224+
}
204225
}

src/query/sql/src/planner/binder/ddl/procedure.rs

Lines changed: 155 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -18,19 +18,27 @@ use databend_common_ast::ast::CreateProcedureStmt;
1818
use databend_common_ast::ast::DescProcedureStmt;
1919
use databend_common_ast::ast::DropProcedureStmt;
2020
use databend_common_ast::ast::ExecuteImmediateStmt;
21+
use databend_common_ast::ast::Expr;
2122
use databend_common_ast::ast::ProcedureIdentity as AstProcedureIdentity;
2223
use databend_common_ast::ast::ProcedureLanguage;
2324
use databend_common_ast::ast::ProcedureType;
2425
use databend_common_ast::ast::ShowOptions;
26+
use databend_common_ast::ast::TypeName;
27+
use databend_common_ast::parser::expr::type_name as parse_type_name_ast;
2528
use databend_common_ast::parser::run_parser;
2629
use databend_common_ast::parser::script::script_block_or_stmt;
2730
use databend_common_ast::parser::script::ScriptBlockOrStmt;
2831
use databend_common_ast::parser::tokenize_sql;
32+
use databend_common_ast::parser::Dialect;
2933
use databend_common_ast::parser::ParseMode;
3034
use databend_common_exception::ErrorCode;
3135
use databend_common_exception::Result;
36+
use databend_common_expression::type_check::common_super_type;
3237
use databend_common_expression::types::DataType;
3338
use databend_common_expression::Scalar;
39+
use databend_common_functions::BUILTIN_FUNCTIONS;
40+
use databend_common_meta_app::principal::procedure::ProcedureInfo;
41+
use databend_common_meta_app::principal::GetProcedureReply;
3442
use databend_common_meta_app::principal::GetProcedureReq;
3543
use databend_common_meta_app::principal::ProcedureIdentity;
3644
use databend_common_meta_app::principal::ProcedureMeta;
@@ -180,39 +188,77 @@ impl Binder {
180188
&[],
181189
true,
182190
)?;
183-
let mut arg_types = vec![];
191+
let mut arg_types = Vec::with_capacity(arguments.len());
184192
for argument in arguments {
185193
let box (arg, mut arg_type) = type_checker.resolve(argument)?;
186194
if let ScalarExpr::SubqueryExpr(subquery) = &arg {
187195
if subquery.typ == SubqueryType::Scalar && !arg.data_type()?.is_nullable() {
188196
arg_type = arg_type.wrap_nullable();
189197
}
190198
}
191-
arg_types.push(arg_type.to_string());
199+
arg_types.push(arg_type);
192200
}
193-
let name = name.to_string();
194-
let procedure_ident = ProcedureIdentity::new(name, arg_types.join(","));
201+
202+
let name_str = name.to_string();
203+
let procedure_api = UserApiProvider::instance().procedure_api(&tenant);
204+
205+
// Try exact match first
206+
let arg_type_strings: Vec<String> = arg_types.iter().map(|t| t.to_string()).collect();
207+
let procedure_ident = ProcedureIdentity::new(name_str.clone(), arg_type_strings.join(","));
195208
let req = GetProcedureReq {
196209
inner: ProcedureNameIdent::new(tenant.clone(), procedure_ident.clone()),
197210
};
198-
199-
let procedure = UserApiProvider::instance()
200-
.procedure_api(&tenant)
201-
.get_procedure(&req)
202-
.await?;
203-
if let Some(procedure) = procedure {
204-
Ok(Plan::CallProcedure(Box::new(CallProcedurePlan {
211+
if let Some(procedure) = procedure_api.get_procedure(&req).await? {
212+
return Ok(Plan::CallProcedure(Box::new(CallProcedurePlan {
205213
procedure_id: procedure.id,
206214
script: procedure.procedure_meta.script,
207215
arg_names: procedure.procedure_meta.arg_names,
208216
args: arguments.clone(),
209-
})))
210-
} else {
211-
Err(ErrorCode::UnknownProcedure(format!(
217+
})));
218+
}
219+
220+
// Exact match failed, try implicit cast resolution.
221+
let candidates = procedure_api.list_procedures_by_name(&name_str).await?;
222+
223+
let has_explicit_cast = arguments
224+
.iter()
225+
.any(|expr| matches!(expr, Expr::Cast { .. } | Expr::TryCast { .. }));
226+
227+
// Multiple overloads plus lack of explicit casts means we cannot disambiguate.
228+
if candidates.len() > 1 && !has_explicit_cast {
229+
return Err(ErrorCode::UnknownProcedure(format!(
212230
"Unknown procedure {}",
213231
procedure_ident
214-
)))
232+
)));
215233
}
234+
235+
let allow_implicit_cast = candidates.len() == 1 && !has_explicit_cast;
236+
let (procedure, casts_to_apply) = resolve_procedure_candidate(
237+
&procedure_ident,
238+
&arg_types,
239+
candidates,
240+
allow_implicit_cast,
241+
)?;
242+
243+
let args = arguments
244+
.iter()
245+
.zip(casts_to_apply.into_iter())
246+
.map(|(expr, cast)| match cast {
247+
Some(target_type) => Expr::Cast {
248+
span: expr.span(),
249+
expr: Box::new(expr.clone()),
250+
target_type,
251+
pg_style: false,
252+
},
253+
None => expr.clone(),
254+
})
255+
.collect();
256+
Ok(Plan::CallProcedure(Box::new(CallProcedurePlan {
257+
procedure_id: procedure.id,
258+
script: procedure.procedure_meta.script,
259+
arg_names: procedure.procedure_meta.arg_names,
260+
args,
261+
})))
216262
}
217263

218264
fn procedure_meta(
@@ -271,7 +317,6 @@ fn generate_procedure_name_ident(
271317
.map(|type_name| resolve_type_name(type_name, true).map(|t| DataType::from(&t)))
272318
.collect::<Result<Vec<_>, _>>()?;
273319

274-
// Convert normalized DataType back to string for storage
275320
let args_type_str = args_data_type
276321
.iter()
277322
.map(|dt| dt.to_string())
@@ -283,3 +328,97 @@ fn generate_procedure_name_ident(
283328
ProcedureIdentity::new(name.name.clone(), args_type_str),
284329
))
285330
}
331+
332+
/// Find the first procedure overload whose signature is compatible with the
333+
/// actual argument types, optionally allowing implicit casts. Returns the
334+
/// procedure metadata plus the exact casts we need to inject.
335+
fn resolve_procedure_candidate(
336+
procedure_ident: &ProcedureIdentity,
337+
arg_types: &[DataType],
338+
candidates: Vec<ProcedureInfo>,
339+
allow_implicit_cast: bool,
340+
) -> Result<(GetProcedureReply, Vec<Option<TypeName>>)> {
341+
for candidate in candidates {
342+
let arg_defs = parse_procedure_signature(&candidate.name_ident.procedure_name().args)?;
343+
if arg_defs.len() != arg_types.len() {
344+
continue;
345+
}
346+
347+
let mut casts = Vec::with_capacity(arg_types.len());
348+
let mut compatible = true;
349+
for (actual, target_ast) in arg_types.iter().zip(arg_defs.iter()) {
350+
let target = DataType::from(&resolve_type_name(target_ast, true)?);
351+
if actual == &target {
352+
casts.push(None);
353+
continue;
354+
}
355+
356+
if allow_implicit_cast
357+
&& common_super_type(
358+
actual.clone(),
359+
target.clone(),
360+
&BUILTIN_FUNCTIONS.default_cast_rules,
361+
)
362+
.is_some_and(|common| common == target)
363+
{
364+
casts.push(Some(target_ast.clone()));
365+
} else {
366+
compatible = false;
367+
break;
368+
}
369+
}
370+
371+
if compatible {
372+
return Ok((
373+
GetProcedureReply {
374+
id: *candidate.ident.procedure_id(),
375+
procedure_meta: candidate.meta.clone(),
376+
},
377+
casts,
378+
));
379+
}
380+
}
381+
382+
Err(ErrorCode::UnknownProcedure(format!(
383+
"Unknown procedure {}",
384+
procedure_ident
385+
)))
386+
}
387+
388+
fn parse_procedure_signature(arg_str: &str) -> Result<Vec<TypeName>> {
389+
if arg_str.is_empty() {
390+
return Ok(vec![]);
391+
}
392+
393+
let mut segments = Vec::new();
394+
let mut depth = 0i32;
395+
let mut start = 0usize;
396+
for (idx, ch) in arg_str.char_indices() {
397+
match ch {
398+
'(' => depth += 1,
399+
')' => depth -= 1,
400+
',' if depth == 0 => {
401+
// Only split on commas at depth 0 so nested args (like DECIMAL) stay intact.
402+
segments.push(arg_str[start..idx].trim());
403+
start = idx + 1;
404+
}
405+
_ => {}
406+
}
407+
}
408+
segments.push(arg_str[start..].trim());
409+
410+
segments
411+
.into_iter()
412+
.map(|segment| {
413+
let tokens = tokenize_sql(segment)?;
414+
run_parser(
415+
&tokens,
416+
Dialect::default(),
417+
ParseMode::Default,
418+
false,
419+
parse_type_name_ast,
420+
)
421+
.map_err(|e| ErrorCode::SyntaxException(e.to_string()))
422+
})
423+
.collect()
424+
}

0 commit comments

Comments
 (0)