From 66849b054884abaaeb3dadb3fb0ad6b660f77539 Mon Sep 17 00:00:00 2001 From: Cursor Agent Date: Wed, 24 Sep 2025 11:19:57 +0000 Subject: [PATCH 01/15] Checkpoint before follow-up message Co-authored-by: contact --- sqlx-core/src/odbc/connection/executor.rs | 4 +- sqlx-core/src/odbc/connection/mod.rs | 459 ++++++++++++- sqlx-core/src/odbc/connection/worker.rs | 781 ---------------------- sqlx-core/src/odbc/transaction.rs | 6 +- sqlx-rt/src/rt_async_std.rs | 1 + sqlx-rt/src/rt_tokio.rs | 1 + 6 files changed, 448 insertions(+), 804 deletions(-) delete mode 100644 sqlx-core/src/odbc/connection/worker.rs diff --git a/sqlx-core/src/odbc/connection/executor.rs b/sqlx-core/src/odbc/connection/executor.rs index 45a32185f8..b953ef0969 100644 --- a/sqlx-core/src/odbc/connection/executor.rs +++ b/sqlx-core/src/odbc/connection/executor.rs @@ -24,7 +24,7 @@ impl<'c> Executor<'c> for &'c mut OdbcConnection { let sql = query.sql().to_string(); let args = query.take_arguments(); Box::pin(try_stream! { - let rx = self.worker.execute_stream(&sql, args).await?; + let rx = self.execute_stream(&sql, args).await?; while let Ok(item) = rx.recv_async().await { r#yield!(item?); } @@ -60,7 +60,7 @@ impl<'c> Executor<'c> for &'c mut OdbcConnection { 'c: 'e, { Box::pin(async move { - let (_, columns, parameters) = self.worker.prepare(sql).await?; + let (_, columns, parameters) = self.prepare(sql).await?; Ok(OdbcStatement { sql: Cow::Borrowed(sql), columns, diff --git a/sqlx-core/src/odbc/connection/mod.rs b/sqlx-core/src/odbc/connection/mod.rs index fc9751bae0..3ce6bf7029 100644 --- a/sqlx-core/src/odbc/connection/mod.rs +++ b/sqlx-core/src/odbc/connection/mod.rs @@ -1,43 +1,133 @@ use crate::connection::{Connection, LogSettings}; use crate::error::Error; -use crate::odbc::{Odbc, OdbcConnectOptions}; +use crate::odbc::{Odbc, OdbcArgumentValue, OdbcArguments, OdbcColumn, OdbcConnectOptions, OdbcQueryResult, OdbcRow, OdbcTypeInfo}; use crate::transaction::Transaction; +use either::Either; +use flume::SendError; use futures_core::future::BoxFuture; use futures_util::future; +use odbc_api::handles::StatementImpl; +use odbc_api::{Cursor, CursorRow, IntoParameter, Nullable, Preallocated, ResultSetMetadata}; +use sqlx_rt::spawn_blocking; +use std::sync::{Arc, Mutex}; mod executor; -mod worker; - -pub(crate) use worker::ConnectionWorker; /// A connection to an ODBC-accessible database. /// -/// ODBC uses a blocking C API, so we run all calls on a dedicated background thread -/// and communicate over channels to provide async access. +/// ODBC uses a blocking C API, so we offload blocking calls to the runtime's blocking +/// thread-pool via `spawn_blocking` and synchronize access with a mutex. #[derive(Debug)] pub struct OdbcConnection { - pub(crate) worker: ConnectionWorker, + pub(crate) inner: Arc>>, pub(crate) log_settings: LogSettings, } impl OdbcConnection { pub(crate) async fn establish(options: &OdbcConnectOptions) -> Result { - let worker = ConnectionWorker::establish(options.clone()).await?; + let conn = spawn_blocking({ + let options = options.clone(); + move || establish_connection(&options) + }) + .await + .map_err(|_| Error::WorkerCrashed)??; + Ok(Self { - worker, + inner: Arc::new(Mutex::new(conn)), log_settings: LogSettings::default(), }) } /// Returns the name of the actual Database Management System (DBMS) this /// connection is talking to as reported by the ODBC driver. - /// - /// This calls the underlying ODBC API `SQL_DBMS_NAME` via - /// `odbc_api::Connection::database_management_system_name`. - /// - /// See: https://docs.rs/odbc-api/19.0.1/odbc_api/struct.Connection.html#method.database_management_system_name pub async fn dbms_name(&mut self) -> Result { - self.worker.get_dbms_name().await + let inner = self.inner.clone(); + spawn_blocking(move || { + let conn = inner.lock().unwrap(); + conn.database_management_system_name() + .map_err(|e| Error::Protocol(format!("Failed to get DBMS name: {}", e))) + }) + .await + .map_err(|_| Error::WorkerCrashed)? + } + + pub(crate) async fn ping_blocking(&mut self) -> Result<(), Error> { + let inner = self.inner.clone(); + spawn_blocking(move || { + let conn = inner.lock().unwrap(); + let res = conn.execute("SELECT 1", (), None); + match res { + Ok(_) => Ok(()), + Err(e) => Err(Error::Protocol(format!("Ping failed: {}", e))), + } + }) + .await + .map_err(|_| Error::WorkerCrashed)? + } + + pub(crate) async fn begin_blocking(&mut self) -> Result<(), Error> { + let inner = self.inner.clone(); + spawn_blocking(move || { + let conn = inner.lock().unwrap(); + conn.set_autocommit(false) + .map_err(|e| Error::Protocol(format!("Failed to begin transaction: {}", e))) + }) + .await + .map_err(|_| Error::WorkerCrashed)? + } + + pub(crate) async fn commit_blocking(&mut self) -> Result<(), Error> { + let inner = self.inner.clone(); + spawn_blocking(move || { + let conn = inner.lock().unwrap(); + conn.commit() + .and_then(|_| conn.set_autocommit(true)) + .map_err(|e| Error::Protocol(format!("Failed to commit transaction: {}", e))) + }) + .await + .map_err(|_| Error::WorkerCrashed)? + } + + pub(crate) async fn rollback_blocking(&mut self) -> Result<(), Error> { + let inner = self.inner.clone(); + spawn_blocking(move || { + let conn = inner.lock().unwrap(); + conn.rollback() + .and_then(|_| conn.set_autocommit(true)) + .map_err(|e| Error::Protocol(format!("Failed to rollback transaction: {}", e))) + }) + .await + .map_err(|_| Error::WorkerCrashed)? + } + + pub(crate) async fn execute_stream( + &mut self, + sql: &str, + args: Option, + ) -> Result, Error>>, Error> { + let (tx, rx) = flume::bounded(64); + let inner = self.inner.clone(); + let sql = sql.to_string(); + spawn_blocking(move || { + let mut guard = inner.lock().unwrap(); + if let Err(e) = execute_sql(&mut guard, &sql, args, &tx) { + let _ = send_stream_result(&tx, Err(e)); + } + }) + .await + .map_err(|_| Error::WorkerCrashed)?; + Ok(rx) + } + + pub(crate) async fn prepare( + &mut self, + sql: &str, + ) -> Result<(u64, Vec, usize), Error> { + let inner = self.inner.clone(); + let sql = sql.to_string(); + spawn_blocking(move || do_prepare(&mut inner.lock().unwrap(), sql.into())) + .await + .map_err(|_| Error::WorkerCrashed)? } } @@ -46,8 +136,12 @@ impl Connection for OdbcConnection { type Options = OdbcConnectOptions; - fn close(mut self) -> BoxFuture<'static, Result<(), Error>> { - Box::pin(async move { self.worker.shutdown().await }) + fn close(self) -> BoxFuture<'static, Result<(), Error>> { + Box::pin(async move { + // Drop connection by moving Arc and letting it fall out of scope. + drop(self); + Ok(()) + }) } fn close_hard(self) -> BoxFuture<'static, Result<(), Error>> { @@ -55,7 +149,7 @@ impl Connection for OdbcConnection { } fn ping(&mut self) -> BoxFuture<'_, Result<(), Error>> { - Box::pin(self.worker.ping()) + Box::pin(self.ping_blocking()) } fn begin(&mut self) -> BoxFuture<'_, Result, Error>> @@ -75,3 +169,332 @@ impl Connection for OdbcConnection { false } } + +// --- Blocking helpers --- + +fn establish_connection(options: &OdbcConnectOptions) -> Result, Error> { + let env = odbc_api::environment().map_err(|e| Error::Configuration(e.to_string().into()))?; + let conn = env + .connect_with_connection_string(options.connection_string(), Default::default()) + .map_err(|e| Error::Configuration(e.to_string().into()))?; + Ok(conn) +} + +type ExecuteResult = Result, Error>; +type ExecuteSender = flume::Sender; + +fn send_stream_result( + tx: &ExecuteSender, + result: ExecuteResult, +) -> Result<(), SendError> { + tx.send(result) +} + +fn execute_sql( + conn: &mut odbc_api::Connection<'static>, + sql: &str, + args: Option, + tx: &ExecuteSender, +) -> Result<(), Error> { + let params = prepare_parameters(args); + let stmt = &mut conn.preallocate().map_err(Error::from)?; + + if let Some(mut cursor) = stmt.execute(sql, ¶ms[..])? { + handle_cursor(&mut cursor, tx); + return Ok(()); + } + + let affected = extract_rows_affected(stmt); + let _ = send_done(tx, affected); + Ok(()) +} + +fn extract_rows_affected(stmt: &mut Preallocated>) -> u64 { + let count_opt = match stmt.row_count() { + Ok(count_opt) => count_opt, + Err(e) => { + log::warn!("Failed to get ODBC row count: {}", e); + return 0; + } + }; + + let count = match count_opt { + Some(count) => count, + None => { + log::debug!("ODBC row count is not available"); + return 0; + } + }; + + let affected = match u64::try_from(count) { + Ok(count) => count, + Err(e) => { + log::warn!("Failed to convert ODBC row count to u64: {}", e); + return 0; + } + }; + affected +} + +fn prepare_parameters( + args: Option, +) -> Vec> { + let args = args.map(|a| a.values).unwrap_or_default(); + args.into_iter().map(to_param).collect() +} + +fn to_param(arg: OdbcArgumentValue) -> Box { + match arg { + OdbcArgumentValue::Int(i) => Box::new(i.into_parameter()), + OdbcArgumentValue::Float(f) => Box::new(f.into_parameter()), + OdbcArgumentValue::Text(s) => Box::new(s.into_parameter()), + OdbcArgumentValue::Bytes(b) => Box::new(b.into_parameter()), + OdbcArgumentValue::Null => Box::new(Option::::None.into_parameter()), + } +} + +fn handle_cursor(cursor: &mut C, tx: &ExecuteSender) +where + C: Cursor + ResultSetMetadata, +{ + let columns = collect_columns(cursor); + + match stream_rows(cursor, &columns, tx) { + Ok(true) => { + let _ = send_done(tx, 0); + } + Ok(false) => {} + Err(e) => { + send_error(tx, e); + } + } +} + +fn send_done(tx: &ExecuteSender, rows_affected: u64) -> Result<(), SendError> { + send_stream_result(tx, Ok(Either::Left(OdbcQueryResult { rows_affected }))) +} + +fn send_error(tx: &ExecuteSender, error: Error) { + if let Err(e) = send_stream_result(tx, Err(error)) { + log::error!("Failed to send error from ODBC blocking task: {}", e); + } +} + +fn send_row(tx: &ExecuteSender, row: OdbcRow) -> Result<(), SendError> { + send_stream_result(tx, Ok(Either::Right(row))) +} + +fn collect_columns(cursor: &mut C) -> Vec +where + C: ResultSetMetadata, +{ + let count = cursor.num_result_cols().unwrap_or(0); + (1..=count).map(|i| create_column(cursor, i as u16)).collect() +} + +fn create_column(cursor: &mut C, index: u16) -> OdbcColumn +where + C: ResultSetMetadata, +{ + let mut cd = odbc_api::ColumnDescription::default(); + let _ = cursor.describe_col(index, &mut cd); + + OdbcColumn { + name: decode_column_name(cd.name, index), + type_info: OdbcTypeInfo::new(cd.data_type), + ordinal: usize::from(index.checked_sub(1).unwrap()), + } +} + +fn decode_column_name(name_bytes: Vec, index: u16) -> String { + String::from_utf8(name_bytes).unwrap_or_else(|_| format!("col{}", index - 1)) +} + +fn stream_rows(cursor: &mut C, columns: &[OdbcColumn], tx: &ExecuteSender) -> Result +where + C: Cursor, +{ + let mut receiver_open = true; + let mut row_count = 0; + + while let Some(mut row) = cursor.next_row()? { + let values = collect_row_values(&mut row, columns)?; + let row_data = OdbcRow { + columns: columns.to_vec(), + values: values.into_iter().map(|(_, value)| value).collect(), + }; + + if send_row(tx, row_data).is_err() { + receiver_open = false; + break; + } + row_count += 1; + } + + let _ = row_count; // kept for potential logging + Ok(receiver_open) +} + +fn collect_row_values( + row: &mut CursorRow<'_>, + columns: &[OdbcColumn], +) -> Result, Error> { + columns + .iter() + .enumerate() + .map(|(i, column)| collect_column_value(row, i, column)) + .collect() +} + +fn collect_column_value( + row: &mut CursorRow<'_>, + index: usize, + column: &OdbcColumn, +) -> Result<(OdbcTypeInfo, crate::odbc::OdbcValue), Error> { + use odbc_api::DataType; + + let col_idx = (index + 1) as u16; + let type_info = column.type_info.clone(); + let data_type = type_info.data_type(); + + let value = match data_type { + DataType::TinyInt + | DataType::SmallInt + | DataType::Integer + | DataType::BigInt + | DataType::Bit => extract_int(row, col_idx, &type_info)?, + + DataType::Real => extract_float::(row, col_idx, &type_info)?, + DataType::Float { .. } | DataType::Double => extract_float::(row, col_idx, &type_info)?, + + DataType::Char { .. } + | DataType::Varchar { .. } + | DataType::LongVarchar { .. } + | DataType::WChar { .. } + | DataType::WVarchar { .. } + | DataType::WLongVarchar { .. } + | DataType::Date + | DataType::Time { .. } + | DataType::Timestamp { .. } + | DataType::Decimal { .. } + | DataType::Numeric { .. } => extract_text(row, col_idx, &type_info)?, + + DataType::Binary { .. } | DataType::Varbinary { .. } | DataType::LongVarbinary { .. } => { + extract_binary(row, col_idx, &type_info)? + } + + DataType::Unknown | DataType::Other { .. } => match extract_text(row, col_idx, &type_info) { + Ok(v) => v, + Err(_) => extract_binary(row, col_idx, &type_info)?, + }, + }; + + Ok((type_info, value)) +} + +fn extract_int( + row: &mut CursorRow<'_>, + col_idx: u16, + type_info: &OdbcTypeInfo, +) -> Result { + let mut nullable = Nullable::::null(); + row.get_data(col_idx, &mut nullable)?; + + let (is_null, int) = match nullable.into_opt() { + None => (true, None), + Some(v) => (false, Some(v)), + }; + + Ok(crate::odbc::OdbcValue { + type_info: type_info.clone(), + is_null, + text: None, + blob: None, + int, + float: None, + }) +} + +fn extract_float( + row: &mut CursorRow<'_>, + col_idx: u16, + type_info: &OdbcTypeInfo, +) -> Result +where + T: Into + Default, + odbc_api::Nullable: odbc_api::parameter::CElement + odbc_api::handles::CDataMut, +{ + let mut nullable = Nullable::::null(); + row.get_data(col_idx, &mut nullable)?; + + let (is_null, float) = match nullable.into_opt() { + None => (true, None), + Some(v) => (false, Some(v.into())), + }; + + Ok(crate::odbc::OdbcValue { + type_info: type_info.clone(), + is_null, + text: None, + blob: None, + int: None, + float, + }) +} + +fn extract_text( + row: &mut CursorRow<'_>, + col_idx: u16, + type_info: &OdbcTypeInfo, +) -> Result { + let mut buf = Vec::new(); + let is_some = row.get_text(col_idx, &mut buf)?; + + let (is_null, text) = if !is_some { + (true, None) + } else { + match String::from_utf8(buf) { + Ok(s) => (false, Some(s)), + Err(e) => return Err(Error::Decode(e.into())), + } + }; + + Ok(crate::odbc::OdbcValue { + type_info: type_info.clone(), + is_null, + text, + blob: None, + int: None, + float: None, + }) +} + +fn extract_binary( + row: &mut CursorRow<'_>, + col_idx: u16, + type_info: &OdbcTypeInfo, +) -> Result { + let mut buf = Vec::new(); + let is_some = row.get_binary(col_idx, &mut buf)?; + + let (is_null, blob) = if !is_some { (true, None) } else { (false, Some(buf)) }; + + Ok(crate::odbc::OdbcValue { + type_info: type_info.clone(), + is_null, + text: None, + blob, + int: None, + float: None, + }) +} + +fn do_prepare( + conn: &mut odbc_api::Connection<'static>, + sql: Box, +) -> Result<(u64, Vec, usize), Error> { + let mut prepared = conn.prepare(&sql)?; + let columns = collect_columns(&mut prepared); + let params = usize::from(prepared.num_params().unwrap_or(0)); + Ok((0, columns, params)) +} diff --git a/sqlx-core/src/odbc/connection/worker.rs b/sqlx-core/src/odbc/connection/worker.rs deleted file mode 100644 index 2ed7367946..0000000000 --- a/sqlx-core/src/odbc/connection/worker.rs +++ /dev/null @@ -1,781 +0,0 @@ -use std::collections::hash_map::Entry; -use std::collections::HashMap; -use std::thread; - -use flume::{SendError, TrySendError}; -use futures_channel::oneshot; - -use crate::error::Error; -use crate::odbc::{ - OdbcArgumentValue, OdbcArguments, OdbcColumn, OdbcConnectOptions, OdbcQueryResult, OdbcRow, - OdbcTypeInfo, -}; -#[allow(unused_imports)] -use crate::row::Row as SqlxRow; -use either::Either; -#[allow(unused_imports)] -use odbc_api::handles::Statement as OdbcStatementTrait; -use odbc_api::handles::StatementImpl; -use odbc_api::{Cursor, CursorRow, IntoParameter, Nullable, Preallocated, ResultSetMetadata}; - -// Type aliases for commonly used types -type OdbcConnection = odbc_api::Connection<'static>; -type TransactionResult = Result<(), Error>; -type TransactionSender = oneshot::Sender; -type ExecuteResult = Result, Error>; -type ExecuteSender = flume::Sender; -type PrepareResult = Result<(u64, Vec, usize), Error>; -type PrepareSender = oneshot::Sender; - -#[derive(Debug)] -pub(crate) struct ConnectionWorker { - command_tx: flume::Sender, - join_handle: Option>, -} - -#[derive(Debug)] -enum Command { - Ping { - tx: oneshot::Sender<()>, - }, - Shutdown { - tx: oneshot::Sender<()>, - }, - Begin { - tx: TransactionSender, - }, - Commit { - tx: TransactionSender, - }, - Rollback { - tx: TransactionSender, - }, - Execute { - sql: Box, - args: Option, - tx: ExecuteSender, - }, - Prepare { - sql: Box, - tx: PrepareSender, - }, - GetDbmsName { - tx: oneshot::Sender>, - }, -} - -impl Drop for ConnectionWorker { - fn drop(&mut self) { - self.shutdown_sync(); - } -} - -impl ConnectionWorker { - pub async fn establish(options: OdbcConnectOptions) -> Result { - let (command_tx, command_rx) = flume::bounded(64); - let (conn_tx, conn_rx) = oneshot::channel(); - let thread = thread::Builder::new() - .name("sqlx-odbc-conn".into()) - .spawn(move || worker_thread_main(options, command_rx, conn_tx))?; - - conn_rx.await.map_err(|_| Error::WorkerCrashed)??; - Ok(ConnectionWorker { - command_tx, - join_handle: Some(thread), - }) - } - - pub(crate) async fn ping(&mut self) -> Result<(), Error> { - let (tx, rx) = oneshot::channel(); - send_command_and_await(&self.command_tx, Command::Ping { tx }, rx).await - } - - pub(crate) async fn shutdown(&mut self) -> Result<(), Error> { - let (tx, rx) = oneshot::channel(); - send_command_and_await(&self.command_tx, Command::Shutdown { tx }, rx).await - } - - pub(crate) fn shutdown_sync(&mut self) { - // Send shutdown command without waiting for response - // Use try_send to avoid any potential blocking in Drop - - if let Some(join_handle) = self.join_handle.take() { - let (mut tx, _rx) = oneshot::channel(); - while let Err(TrySendError::Full(Command::Shutdown { tx: t })) = - self.command_tx.try_send(Command::Shutdown { tx }) - { - tx = t; - log::warn!("odbc worker thread queue is full, retrying..."); - thread::sleep(std::time::Duration::from_millis(10)); - } - if let Err(e) = join_handle.join() { - let err = e.downcast_ref::(); - log::error!( - "failed to join worker thread while shutting down: {:?}", - err - ); - } - } - } - - pub(crate) async fn begin(&mut self) -> Result<(), Error> { - let (tx, rx) = oneshot::channel(); - send_transaction_command(&self.command_tx, Command::Begin { tx }, rx).await - } - - pub(crate) async fn commit(&mut self) -> Result<(), Error> { - let (tx, rx) = oneshot::channel(); - send_transaction_command(&self.command_tx, Command::Commit { tx }, rx).await - } - - pub(crate) async fn rollback(&mut self) -> Result<(), Error> { - let (tx, rx) = oneshot::channel(); - send_transaction_command(&self.command_tx, Command::Rollback { tx }, rx).await - } - - pub(crate) async fn execute_stream( - &mut self, - sql: &str, - args: Option, - ) -> Result, Error>>, Error> { - let (tx, rx) = flume::bounded(64); - self.command_tx - .send_async(Command::Execute { - sql: sql.into(), - args, - tx, - }) - .await - .map_err(|_| Error::WorkerCrashed)?; - Ok(rx) - } - - pub(crate) async fn prepare( - &mut self, - sql: &str, - ) -> Result<(u64, Vec, usize), Error> { - let (tx, rx) = oneshot::channel(); - send_command_and_await( - &self.command_tx, - Command::Prepare { - sql: sql.into(), - tx, - }, - rx, - ) - .await? - } - - pub(crate) async fn get_dbms_name(&mut self) -> Result { - let (tx, rx) = oneshot::channel(); - send_command_and_await(&self.command_tx, Command::GetDbmsName { tx }, rx).await? - } -} - -// Worker thread implementation -fn worker_thread_main( - options: OdbcConnectOptions, - command_rx: flume::Receiver, - conn_tx: oneshot::Sender>, -) { - // Establish connection - let conn = match establish_connection(&options) { - Ok(conn) => { - log::debug!("ODBC connection established successfully"); - let _ = conn_tx.send(Ok(())); - conn - } - Err(e) => { - let _ = conn_tx.send(Err(e)); - return; - } - }; - - let mut stmt_manager = StatementManager::new(&conn); - - // Process commands - while let Ok(cmd) = command_rx.recv() { - log::trace!("Processing command: {:?}", cmd); - match process_command(cmd, &conn, &mut stmt_manager) { - Ok(CommandControlFlow::Continue) => {} - Ok(CommandControlFlow::Shutdown(shutdown_tx)) => { - log::debug!("Shutting down ODBC worker thread"); - drop(stmt_manager); - drop(conn); - send_oneshot(shutdown_tx, (), "shutdown"); - break; - } - Err(()) => { - log::error!("ODBC worker error while processing command"); - } - } - } - // Channel disconnected or shutdown command received, worker thread exits -} - -fn establish_connection(options: &OdbcConnectOptions) -> Result { - // Get or create the shared ODBC environment - // This ensures thread-safe initialization and prevents concurrent environment creation issues - let env = odbc_api::environment().map_err(|e| Error::Configuration(e.to_string().into()))?; - - let conn = env - .connect_with_connection_string(options.connection_string(), Default::default()) - .map_err(|e| Error::Configuration(e.to_string().into()))?; - - Ok(conn) -} - -/// Statement manager to handle preallocated statements -struct StatementManager<'conn> { - conn: &'conn OdbcConnection, - // Reusable statement for direct execution - direct_stmt: Option>>, - // Cache of prepared statements by SQL hash - prepared_cache: HashMap>>, -} - -impl<'conn> StatementManager<'conn> { - fn new(conn: &'conn OdbcConnection) -> Self { - log::debug!("Creating new statement manager"); - Self { - conn, - direct_stmt: None, - prepared_cache: HashMap::new(), - } - } - - fn get_or_create_direct_stmt( - &mut self, - ) -> Result<&mut Preallocated>, Error> { - if self.direct_stmt.is_none() { - log::debug!("Preallocating ODBC direct statement"); - self.direct_stmt = Some(self.conn.preallocate().map_err(Error::from)?); - } - Ok(self.direct_stmt.as_mut().unwrap()) - } - - fn get_or_create_prepared( - &mut self, - sql: &str, - ) -> Result<&mut odbc_api::Prepared>, Error> { - use std::collections::hash_map::DefaultHasher; - use std::hash::{Hash, Hasher}; - - let mut hasher = DefaultHasher::new(); - sql.hash(&mut hasher); - let sql_hash = hasher.finish(); - - match self.prepared_cache.entry(sql_hash) { - Entry::Vacant(e) => { - log::trace!("Preparing statement for SQL: {}", sql); - let prepared = self.conn.prepare(sql)?; - Ok(e.insert(prepared)) - } - Entry::Occupied(e) => { - log::trace!("Using prepared statement for SQL: {}", sql); - Ok(e.into_mut()) - } - } - } -} -// Helper function to send results through oneshot channels with consistent error handling -fn send_oneshot(tx: oneshot::Sender, result: T, operation: &str) { - if tx.send(result).is_err() { - log::warn!("Failed to send {} result: receiver dropped", operation); - } -} - -fn send_stream_result( - tx: &ExecuteSender, - result: ExecuteResult, -) -> Result<(), SendError> { - tx.send(result) -} - -async fn send_command_and_await( - command_tx: &flume::Sender, - cmd: Command, - rx: oneshot::Receiver, -) -> Result { - command_tx - .send_async(cmd) - .await - .map_err(|_| Error::WorkerCrashed)?; - rx.await.map_err(|_| Error::WorkerCrashed) -} - -async fn send_transaction_command( - command_tx: &flume::Sender, - cmd: Command, - rx: oneshot::Receiver, -) -> Result<(), Error> { - send_command_and_await(command_tx, cmd, rx).await??; - Ok(()) -} - -// Utility functions for transaction operations -fn execute_transaction_operation( - conn: &OdbcConnection, - operation: F, - operation_name: &str, -) -> TransactionResult -where - F: FnOnce(&OdbcConnection) -> Result<(), odbc_api::Error>, -{ - log::trace!("{} odbc transaction", operation_name); - operation(conn) - .map_err(|e| Error::Protocol(format!("Failed to {} transaction: {}", operation_name, e))) -} - -#[derive(Debug)] -enum CommandControlFlow { - Shutdown(oneshot::Sender<()>), - Continue, -} - -type CommandResult = Result; - -// Returns a shutdown tx if the command is a shutdown command -fn process_command<'conn>( - cmd: Command, - conn: &'conn OdbcConnection, - stmt_manager: &mut StatementManager<'conn>, -) -> CommandResult { - match cmd { - Command::Ping { tx } => handle_ping(conn, tx), - Command::Begin { tx } => handle_begin(conn, tx), - Command::Commit { tx } => handle_commit(conn, tx), - Command::Rollback { tx } => handle_rollback(conn, tx), - Command::Shutdown { tx } => Ok(CommandControlFlow::Shutdown(tx)), - Command::Execute { sql, args, tx } => handle_execute(stmt_manager, sql, args, tx), - Command::Prepare { sql, tx } => handle_prepare(stmt_manager, sql, tx), - Command::GetDbmsName { tx } => handle_get_dbms_name(conn, tx), - } -} - -// Command handlers -fn handle_ping(conn: &OdbcConnection, tx: oneshot::Sender<()>) -> CommandResult { - match conn.execute("SELECT 1", (), None) { - Ok(_) => send_oneshot(tx, (), "ping"), - Err(e) => log::error!("Ping failed: {}", e), - } - Ok(CommandControlFlow::Continue) -} - -fn handle_begin(conn: &OdbcConnection, tx: TransactionSender) -> CommandResult { - let result = execute_transaction_operation(conn, |c| c.set_autocommit(false), "begin"); - send_oneshot(tx, result, "begin transaction"); - Ok(CommandControlFlow::Continue) -} - -fn handle_commit(conn: &OdbcConnection, tx: TransactionSender) -> CommandResult { - let result = execute_transaction_operation( - conn, - |c| c.commit().and_then(|_| c.set_autocommit(true)), - "commit", - ); - send_oneshot(tx, result, "commit transaction"); - Ok(CommandControlFlow::Continue) -} - -fn handle_rollback(conn: &OdbcConnection, tx: TransactionSender) -> CommandResult { - let result = execute_transaction_operation( - conn, - |c| c.rollback().and_then(|_| c.set_autocommit(true)), - "rollback", - ); - send_oneshot(tx, result, "rollback transaction"); - Ok(CommandControlFlow::Continue) -} -fn handle_prepare<'conn>( - stmt_manager: &mut StatementManager<'conn>, - sql: Box, - tx: PrepareSender, -) -> CommandResult { - let result = do_prepare(stmt_manager, sql); - send_oneshot(tx, result, "prepare"); - Ok(CommandControlFlow::Continue) -} - -fn do_prepare<'conn>(stmt_manager: &mut StatementManager<'conn>, sql: Box) -> PrepareResult { - log::trace!("Preparing statement: {}", sql); - // Use the statement manager to get or create the prepared statement - let prepared = stmt_manager.get_or_create_prepared(&sql)?; - let columns = collect_columns(prepared); - let params = usize::from(prepared.num_params().unwrap_or(0)); - log::debug!( - "Prepared statement with {} columns and {} parameters", - columns.len(), - params - ); - Ok((0, columns, params)) -} - -fn handle_get_dbms_name( - conn: &OdbcConnection, - tx: oneshot::Sender>, -) -> CommandResult { - log::debug!("Getting DBMS name"); - let result = conn - .database_management_system_name() - .map_err(|e| Error::Protocol(format!("Failed to get DBMS name: {}", e))); - send_oneshot(tx, result, "DBMS name"); - Ok(CommandControlFlow::Continue) -} - -fn handle_execute<'conn>( - stmt_manager: &mut StatementManager<'conn>, - sql: Box, - args: Option, - tx: ExecuteSender, -) -> CommandResult { - log::debug!( - "Executing SQL: {}", - sql.chars().take(100).collect::() - ); - - let result = execute_sql(stmt_manager, &sql, args, &tx); - with_result_send_error(result, &tx, |_| {}); - Ok(CommandControlFlow::Continue) -} - -// SQL execution functions -fn execute_sql<'conn>( - stmt_manager: &mut StatementManager<'conn>, - sql: &str, - args: Option, - tx: &ExecuteSender, -) -> Result<(), Error> { - let params = prepare_parameters(args); - let stmt = stmt_manager.get_or_create_direct_stmt()?; - log::trace!("Starting execution of SQL: {}", sql); - - // Execute and handle result immediately to avoid borrowing conflicts - if let Some(mut cursor) = stmt.execute(sql, ¶ms[..])? { - handle_cursor(&mut cursor, tx); - return Ok(()); - } - - // Execution completed without result set, get affected rows - let affected = extract_rows_affected(stmt); - let _ = send_done(tx, affected); - Ok(()) -} - -fn extract_rows_affected(stmt: &mut Preallocated>) -> u64 { - let count_opt = match stmt.row_count() { - Ok(count_opt) => count_opt, - Err(e) => { - log::warn!("Failed to get ODBC row count: {}", e); - return 0; - } - }; - - let count = match count_opt { - Some(count) => count, - None => { - log::debug!("ODBC row count is not available"); - return 0; - } - }; - - let affected = match u64::try_from(count) { - Ok(count) => count, - Err(e) => { - log::warn!("Failed to convert ODBC row count to u64: {}", e); - return 0; - } - }; - log::trace!("ODBC statement affected {} rows", affected); - affected -} - -fn prepare_parameters( - args: Option, -) -> Vec> { - let args = args.map(|a| a.values).unwrap_or_default(); - args.into_iter().map(to_param).collect() -} - -fn to_param(arg: OdbcArgumentValue) -> Box { - match arg { - OdbcArgumentValue::Int(i) => Box::new(i.into_parameter()), - OdbcArgumentValue::Float(f) => Box::new(f.into_parameter()), - OdbcArgumentValue::Text(s) => Box::new(s.into_parameter()), - OdbcArgumentValue::Bytes(b) => Box::new(b.into_parameter()), - OdbcArgumentValue::Null => Box::new(Option::::None.into_parameter()), - } -} - -fn handle_cursor(cursor: &mut C, tx: &ExecuteSender) -where - C: Cursor + ResultSetMetadata, -{ - let columns = collect_columns(cursor); - log::trace!("Processing ODBC result set with {} columns", columns.len()); - - match stream_rows(cursor, &columns, tx) { - Ok(true) => { - log::trace!("Successfully streamed all rows"); - let _ = send_done(tx, 0); - } - Ok(false) => { - log::trace!("Row streaming stopped early (receiver closed)"); - } - Err(e) => { - send_error(tx, e); - } - } -} - -// Unified result sending functions -fn send_done(tx: &ExecuteSender, rows_affected: u64) -> Result<(), SendError> { - send_stream_result(tx, Ok(Either::Left(OdbcQueryResult { rows_affected }))) -} - -fn with_result_send_error( - result: Result, - tx: &ExecuteSender, - handler: impl FnOnce(T), -) { - match result { - Ok(result) => handler(result), - Err(error) => send_error(tx, error), - } -} - -fn send_error(tx: &ExecuteSender, error: Error) { - if let Err(e) = send_stream_result(tx, Err(error)) { - log::error!("Failed to send error from ODBC worker thread: {}", e); - } -} - -fn send_row(tx: &ExecuteSender, row: OdbcRow) -> Result<(), SendError> { - send_stream_result(tx, Ok(Either::Right(row))) -} - -// Metadata and row processing -fn collect_columns(cursor: &mut C) -> Vec -where - C: ResultSetMetadata, -{ - let count = cursor.num_result_cols().unwrap_or(0); - - (1..=count) - .map(|i| create_column(cursor, i as u16)) - .collect() -} - -fn create_column(cursor: &mut C, index: u16) -> OdbcColumn -where - C: ResultSetMetadata, -{ - let mut cd = odbc_api::ColumnDescription::default(); - let _ = cursor.describe_col(index, &mut cd); - - OdbcColumn { - name: decode_column_name(cd.name, index), - type_info: OdbcTypeInfo::new(cd.data_type), - ordinal: usize::from(index.checked_sub(1).unwrap()), - } -} - -fn decode_column_name(name_bytes: Vec, index: u16) -> String { - String::from_utf8(name_bytes).unwrap_or_else(|_| format!("col{}", index - 1)) -} - -fn stream_rows(cursor: &mut C, columns: &[OdbcColumn], tx: &ExecuteSender) -> Result -where - C: Cursor, -{ - let mut receiver_open = true; - let mut row_count = 0; - - while let Some(mut row) = cursor.next_row()? { - let values = collect_row_values(&mut row, columns)?; - let row_data = OdbcRow { - columns: columns.to_vec(), - values: values.into_iter().map(|(_, value)| value).collect(), - }; - - if send_row(tx, row_data).is_err() { - log::debug!("Receiver closed after {} rows", row_count); - receiver_open = false; - break; - } - row_count += 1; - } - - if receiver_open { - log::debug!("Streamed {} rows successfully", row_count); - } - Ok(receiver_open) -} - -fn collect_row_values( - row: &mut CursorRow<'_>, - columns: &[OdbcColumn], -) -> Result, Error> { - columns - .iter() - .enumerate() - .map(|(i, column)| collect_column_value(row, i, column)) - .collect() -} - -fn collect_column_value( - row: &mut CursorRow<'_>, - index: usize, - column: &OdbcColumn, -) -> Result<(OdbcTypeInfo, crate::odbc::OdbcValue), Error> { - use odbc_api::DataType; - - let col_idx = (index + 1) as u16; - let type_info = column.type_info.clone(); - let data_type = type_info.data_type(); - - // Extract value based on data type - let value = match data_type { - // Integer types - DataType::TinyInt - | DataType::SmallInt - | DataType::Integer - | DataType::BigInt - | DataType::Bit => extract_int(row, col_idx, &type_info)?, - - // Floating point types - DataType::Real => extract_float::(row, col_idx, &type_info)?, - DataType::Float { .. } | DataType::Double => { - extract_float::(row, col_idx, &type_info)? - } - - // String types - DataType::Char { .. } - | DataType::Varchar { .. } - | DataType::LongVarchar { .. } - | DataType::WChar { .. } - | DataType::WVarchar { .. } - | DataType::WLongVarchar { .. } - | DataType::Date - | DataType::Time { .. } - | DataType::Timestamp { .. } - | DataType::Decimal { .. } - | DataType::Numeric { .. } => extract_text(row, col_idx, &type_info)?, - - // Binary types - DataType::Binary { .. } | DataType::Varbinary { .. } | DataType::LongVarbinary { .. } => { - extract_binary(row, col_idx, &type_info)? - } - - // Unknown types - try text first, fall back to binary - DataType::Unknown | DataType::Other { .. } => { - match extract_text(row, col_idx, &type_info) { - Ok(v) => v, - Err(_) => extract_binary(row, col_idx, &type_info)?, - } - } - }; - - Ok((type_info, value)) -} - -fn extract_int( - row: &mut CursorRow<'_>, - col_idx: u16, - type_info: &OdbcTypeInfo, -) -> Result { - let mut nullable = Nullable::::null(); - row.get_data(col_idx, &mut nullable)?; - - let (is_null, int) = match nullable.into_opt() { - None => (true, None), - Some(v) => (false, Some(v)), - }; - - Ok(crate::odbc::OdbcValue { - type_info: type_info.clone(), - is_null, - text: None, - blob: None, - int, - float: None, - }) -} - -fn extract_float( - row: &mut CursorRow<'_>, - col_idx: u16, - type_info: &OdbcTypeInfo, -) -> Result -where - T: Into + Default, - odbc_api::Nullable: odbc_api::parameter::CElement + odbc_api::handles::CDataMut, -{ - let mut nullable = Nullable::::null(); - row.get_data(col_idx, &mut nullable)?; - - let (is_null, float) = match nullable.into_opt() { - None => (true, None), - Some(v) => (false, Some(v.into())), - }; - - Ok(crate::odbc::OdbcValue { - type_info: type_info.clone(), - is_null, - text: None, - blob: None, - int: None, - float, - }) -} - -fn extract_text( - row: &mut CursorRow<'_>, - col_idx: u16, - type_info: &OdbcTypeInfo, -) -> Result { - let mut buf = Vec::new(); - let is_some = row.get_text(col_idx, &mut buf)?; - - let (is_null, text) = if !is_some { - (true, None) - } else { - match String::from_utf8(buf) { - Ok(s) => (false, Some(s)), - Err(e) => return Err(Error::Decode(e.into())), - } - }; - - Ok(crate::odbc::OdbcValue { - type_info: type_info.clone(), - is_null, - text, - blob: None, - int: None, - float: None, - }) -} - -fn extract_binary( - row: &mut CursorRow<'_>, - col_idx: u16, - type_info: &OdbcTypeInfo, -) -> Result { - let mut buf = Vec::new(); - let is_some = row.get_binary(col_idx, &mut buf)?; - - let (is_null, blob) = if !is_some { - (true, None) - } else { - (false, Some(buf)) - }; - - Ok(crate::odbc::OdbcValue { - type_info: type_info.clone(), - is_null, - text: None, - blob, - int: None, - float: None, - }) -} diff --git a/sqlx-core/src/odbc/transaction.rs b/sqlx-core/src/odbc/transaction.rs index 2556c16784..aa57d73859 100644 --- a/sqlx-core/src/odbc/transaction.rs +++ b/sqlx-core/src/odbc/transaction.rs @@ -11,19 +11,19 @@ impl TransactionManager for OdbcTransactionManager { fn begin( conn: &mut ::Connection, ) -> BoxFuture<'_, Result<(), Error>> { - Box::pin(async move { conn.worker.begin().await }) + Box::pin(async move { conn.begin_blocking().await }) } fn commit( conn: &mut ::Connection, ) -> BoxFuture<'_, Result<(), Error>> { - Box::pin(async move { conn.worker.commit().await }) + Box::pin(async move { conn.commit_blocking().await }) } fn rollback( conn: &mut ::Connection, ) -> BoxFuture<'_, Result<(), Error>> { - Box::pin(async move { conn.worker.rollback().await }) + Box::pin(async move { conn.rollback_blocking().await }) } fn start_rollback(_conn: &mut ::Connection) { diff --git a/sqlx-rt/src/rt_async_std.rs b/sqlx-rt/src/rt_async_std.rs index aeca8541ab..2abf9cd2f2 100644 --- a/sqlx-rt/src/rt_async_std.rs +++ b/sqlx-rt/src/rt_async_std.rs @@ -2,6 +2,7 @@ pub use async_std::{ self, fs, future::timeout, io::prelude::ReadExt as AsyncReadExt, io::prelude::WriteExt as AsyncWriteExt, io::Read as AsyncRead, io::Write as AsyncWrite, net::TcpStream, sync::Mutex as AsyncMutex, task::sleep, task::spawn, task::yield_now, + task::spawn_blocking, }; #[cfg(unix)] diff --git a/sqlx-rt/src/rt_tokio.rs b/sqlx-rt/src/rt_tokio.rs index b1d3bc8149..5822deef57 100644 --- a/sqlx-rt/src/rt_tokio.rs +++ b/sqlx-rt/src/rt_tokio.rs @@ -1,6 +1,7 @@ pub use tokio::{ self, fs, io::AsyncRead, io::AsyncReadExt, io::AsyncWrite, io::AsyncWriteExt, io::ReadBuf, net::TcpStream, runtime::Handle, sync::Mutex as AsyncMutex, task::spawn, task::yield_now, + task::spawn_blocking, time::sleep, time::timeout, }; From 0fb39afc70ee6e123d76471325f54919c58c3087 Mon Sep 17 00:00:00 2001 From: Cursor Agent Date: Wed, 24 Sep 2025 11:37:35 +0000 Subject: [PATCH 02/15] Refactor: Move ODBC blocking helpers to inner module This change reorganizes the ODBC connection implementation by moving blocking helper functions into a new `inner.rs` file. This improves code organization and maintainability. Additionally, the `prepare` method in `OdbcConnection` is renamed to `prepare_metadata` to better reflect its functionality. Co-authored-by: contact --- sqlx-core/src/any/connection/executor.rs | 10 +- sqlx-core/src/odbc/blocking.rs | 11 + sqlx-core/src/odbc/connection/executor.rs | 2 +- sqlx-core/src/odbc/connection/inner.rs | 313 +++++++++++++++++ sqlx-core/src/odbc/connection/mod.rs | 392 ++-------------------- sqlx-core/src/odbc/mod.rs | 1 + 6 files changed, 360 insertions(+), 369 deletions(-) create mode 100644 sqlx-core/src/odbc/blocking.rs create mode 100644 sqlx-core/src/odbc/connection/inner.rs diff --git a/sqlx-core/src/any/connection/executor.rs b/sqlx-core/src/any/connection/executor.rs index d49d23e543..9bd5ce06e9 100644 --- a/sqlx-core/src/any/connection/executor.rs +++ b/sqlx-core/src/any/connection/executor.rs @@ -128,7 +128,15 @@ impl<'c> Executor<'c> for &'c mut AnyConnection { AnyConnectionKind::Mssql(conn) => conn.prepare(sql).await.map(Into::into)?, #[cfg(feature = "odbc")] - AnyConnectionKind::Odbc(conn) => conn.prepare(sql).await.map(Into::into)?, + AnyConnectionKind::Odbc(conn) => { + let (_, columns, parameters) = conn.prepare_metadata(sql).await?; + crate::odbc::OdbcStatement { + sql: sql.into(), + columns, + parameters, + } + .into() + } }) }) } diff --git a/sqlx-core/src/odbc/blocking.rs b/sqlx-core/src/odbc/blocking.rs new file mode 100644 index 0000000000..165e201a27 --- /dev/null +++ b/sqlx-core/src/odbc/blocking.rs @@ -0,0 +1,11 @@ +use crate::error::Error; +use sqlx_rt::spawn_blocking; + +pub async fn run_blocking(f: F) -> Result +where + R: Send + 'static, + F: FnOnce() -> Result + Send + 'static, +{ + let res = spawn_blocking(f).await.map_err(|_| Error::WorkerCrashed)?; + res +} diff --git a/sqlx-core/src/odbc/connection/executor.rs b/sqlx-core/src/odbc/connection/executor.rs index b953ef0969..0010a2b8e3 100644 --- a/sqlx-core/src/odbc/connection/executor.rs +++ b/sqlx-core/src/odbc/connection/executor.rs @@ -60,7 +60,7 @@ impl<'c> Executor<'c> for &'c mut OdbcConnection { 'c: 'e, { Box::pin(async move { - let (_, columns, parameters) = self.prepare(sql).await?; + let (_, columns, parameters) = self.prepare_metadata(sql).await?; Ok(OdbcStatement { sql: Cow::Borrowed(sql), columns, diff --git a/sqlx-core/src/odbc/connection/inner.rs b/sqlx-core/src/odbc/connection/inner.rs new file mode 100644 index 0000000000..0eaa7ade03 --- /dev/null +++ b/sqlx-core/src/odbc/connection/inner.rs @@ -0,0 +1,313 @@ +use crate::error::Error; +use crate::odbc::{OdbcArgumentValue, OdbcArguments, OdbcColumn, OdbcQueryResult, OdbcRow, OdbcTypeInfo}; +use either::Either; +use flume::{SendError, Sender}; +use odbc_api::handles::StatementImpl; +use odbc_api::{Cursor, CursorRow, IntoParameter, Nullable, Preallocated, ResultSetMetadata}; + +pub type OdbcConn = odbc_api::Connection<'static>; +pub type ExecuteResult = Result, Error>; +pub type ExecuteSender = Sender; + +pub fn establish_connection(options: &crate::odbc::OdbcConnectOptions) -> Result { + let env = odbc_api::environment().map_err(|e| Error::Configuration(e.to_string().into()))?; + let conn = env + .connect_with_connection_string(options.connection_string(), Default::default()) + .map_err(|e| Error::Configuration(e.to_string().into()))?; + Ok(conn) +} + +pub fn execute_sql( + conn: &mut OdbcConn, + sql: &str, + args: Option, + tx: &ExecuteSender, +) -> Result<(), Error> { + let params = prepare_parameters(args); + let stmt = &mut conn.preallocate().map_err(Error::from)?; + + if let Some(mut cursor) = stmt.execute(sql, ¶ms[..])? { + handle_cursor(&mut cursor, tx); + return Ok(()); + } + + let affected = extract_rows_affected(stmt); + let _ = send_done(tx, affected); + Ok(()) +} + +fn extract_rows_affected(stmt: &mut Preallocated>) -> u64 { + let count_opt = match stmt.row_count() { + Ok(count_opt) => count_opt, + Err(_) => { + return 0; + } + }; + + let count = match count_opt { + Some(count) => count, + None => { + return 0; + } + }; + + u64::try_from(count).unwrap_or_default() +} + +fn prepare_parameters( + args: Option, +) -> Vec> { + let args = args.map(|a| a.values).unwrap_or_default(); + args.into_iter().map(to_param).collect() +} + +fn to_param(arg: OdbcArgumentValue) -> Box { + match arg { + OdbcArgumentValue::Int(i) => Box::new(i.into_parameter()), + OdbcArgumentValue::Float(f) => Box::new(f.into_parameter()), + OdbcArgumentValue::Text(s) => Box::new(s.into_parameter()), + OdbcArgumentValue::Bytes(b) => Box::new(b.into_parameter()), + OdbcArgumentValue::Null => Box::new(Option::::None.into_parameter()), + } +} + +fn handle_cursor(cursor: &mut C, tx: &ExecuteSender) +where + C: Cursor + ResultSetMetadata, +{ + let columns = collect_columns(cursor); + + match stream_rows(cursor, &columns, tx) { + Ok(true) => { + let _ = send_done(tx, 0); + } + Ok(false) => {} + Err(e) => { + send_error(tx, e); + } + } +} + +fn send_done(tx: &ExecuteSender, rows_affected: u64) -> Result<(), SendError> { + tx.send(Ok(Either::Left(OdbcQueryResult { rows_affected }))) +} + +fn send_error(tx: &ExecuteSender, error: Error) { + let _ = tx.send(Err(error)); +} + +fn send_row(tx: &ExecuteSender, row: OdbcRow) -> Result<(), SendError> { + tx.send(Ok(Either::Right(row))) +} + +fn collect_columns(cursor: &mut C) -> Vec +where + C: ResultSetMetadata, +{ + let count = cursor.num_result_cols().unwrap_or(0); + (1..=count).map(|i| create_column(cursor, i as u16)).collect() +} + +fn create_column(cursor: &mut C, index: u16) -> OdbcColumn +where + C: ResultSetMetadata, +{ + let mut cd = odbc_api::ColumnDescription::default(); + let _ = cursor.describe_col(index, &mut cd); + + OdbcColumn { + name: decode_column_name(cd.name, index), + type_info: OdbcTypeInfo::new(cd.data_type), + ordinal: usize::from(index.checked_sub(1).unwrap()), + } +} + +fn decode_column_name(name_bytes: Vec, index: u16) -> String { + String::from_utf8(name_bytes).unwrap_or_else(|_| format!("col{}", index - 1)) +} + +fn stream_rows(cursor: &mut C, columns: &[OdbcColumn], tx: &ExecuteSender) -> Result +where + C: Cursor, +{ + let mut receiver_open = true; + + while let Some(mut row) = cursor.next_row()? { + let values = collect_row_values(&mut row, columns)?; + let row_data = OdbcRow { + columns: columns.to_vec(), + values: values.into_iter().map(|(_, value)| value).collect(), + }; + + if send_row(tx, row_data).is_err() { + receiver_open = false; + break; + } + } + Ok(receiver_open) +} + +fn collect_row_values( + row: &mut CursorRow<'_>, + columns: &[OdbcColumn], +) -> Result, Error> { + columns + .iter() + .enumerate() + .map(|(i, column)| collect_column_value(row, i, column)) + .collect() +} + +fn collect_column_value( + row: &mut CursorRow<'_>, + index: usize, + column: &OdbcColumn, +) -> Result<(OdbcTypeInfo, crate::odbc::OdbcValue), Error> { + use odbc_api::DataType; + + let col_idx = (index + 1) as u16; + let type_info = column.type_info.clone(); + let data_type = type_info.data_type(); + + let value = match data_type { + DataType::TinyInt + | DataType::SmallInt + | DataType::Integer + | DataType::BigInt + | DataType::Bit => extract_int(row, col_idx, &type_info)?, + + DataType::Real => extract_float::(row, col_idx, &type_info)?, + DataType::Float { .. } | DataType::Double => extract_float::(row, col_idx, &type_info)?, + + DataType::Char { .. } + | DataType::Varchar { .. } + | DataType::LongVarchar { .. } + | DataType::WChar { .. } + | DataType::WVarchar { .. } + | DataType::WLongVarchar { .. } + | DataType::Date + | DataType::Time { .. } + | DataType::Timestamp { .. } + | DataType::Decimal { .. } + | DataType::Numeric { .. } => extract_text(row, col_idx, &type_info)?, + + DataType::Binary { .. } | DataType::Varbinary { .. } | DataType::LongVarbinary { .. } => { + extract_binary(row, col_idx, &type_info)? + } + + DataType::Unknown | DataType::Other { .. } => match extract_text(row, col_idx, &type_info) { + Ok(v) => v, + Err(_) => extract_binary(row, col_idx, &type_info)?, + }, + }; + + Ok((type_info, value)) +} + +fn extract_int( + row: &mut CursorRow<'_>, + col_idx: u16, + type_info: &OdbcTypeInfo, +) -> Result { + let mut nullable = Nullable::::null(); + row.get_data(col_idx, &mut nullable)?; + + let (is_null, int) = match nullable.into_opt() { + None => (true, None), + Some(v) => (false, Some(v)), + }; + + Ok(crate::odbc::OdbcValue { + type_info: type_info.clone(), + is_null, + text: None, + blob: None, + int, + float: None, + }) +} + +fn extract_float( + row: &mut CursorRow<'_>, + col_idx: u16, + type_info: &OdbcTypeInfo, +) -> Result +where + T: Into + Default, + odbc_api::Nullable: odbc_api::parameter::CElement + odbc_api::handles::CDataMut, +{ + let mut nullable = Nullable::::null(); + row.get_data(col_idx, &mut nullable)?; + + let (is_null, float) = match nullable.into_opt() { + None => (true, None), + Some(v) => (false, Some(v.into())), + }; + + Ok(crate::odbc::OdbcValue { + type_info: type_info.clone(), + is_null, + text: None, + blob: None, + int: None, + float, + }) +} + +fn extract_text( + row: &mut CursorRow<'_>, + col_idx: u16, + type_info: &OdbcTypeInfo, +) -> Result { + let mut buf = Vec::new(); + let is_some = row.get_text(col_idx, &mut buf)?; + + let (is_null, text) = if !is_some { + (true, None) + } else { + match String::from_utf8(buf) { + Ok(s) => (false, Some(s)), + Err(e) => return Err(Error::Decode(e.into())), + } + }; + + Ok(crate::odbc::OdbcValue { + type_info: type_info.clone(), + is_null, + text, + blob: None, + int: None, + float: None, + }) +} + +fn extract_binary( + row: &mut CursorRow<'_>, + col_idx: u16, + type_info: &OdbcTypeInfo, +) -> Result { + let mut buf = Vec::new(); + let is_some = row.get_binary(col_idx, &mut buf)?; + + let (is_null, blob) = if !is_some { (true, None) } else { (false, Some(buf)) }; + + Ok(crate::odbc::OdbcValue { + type_info: type_info.clone(), + is_null, + text: None, + blob, + int: None, + float: None, + }) +} + +pub fn do_prepare( + conn: &mut OdbcConn, + sql: Box, +) -> Result<(u64, Vec, usize), Error> { + let mut prepared = conn.prepare(&sql)?; + let columns = collect_columns(&mut prepared); + let params = usize::from(prepared.num_params().unwrap_or(0)); + Ok((0, columns, params)) +} + diff --git a/sqlx-core/src/odbc/connection/mod.rs b/sqlx-core/src/odbc/connection/mod.rs index 3ce6bf7029..63f13654f3 100644 --- a/sqlx-core/src/odbc/connection/mod.rs +++ b/sqlx-core/src/odbc/connection/mod.rs @@ -1,14 +1,14 @@ use crate::connection::{Connection, LogSettings}; use crate::error::Error; -use crate::odbc::{Odbc, OdbcArgumentValue, OdbcArguments, OdbcColumn, OdbcConnectOptions, OdbcQueryResult, OdbcRow, OdbcTypeInfo}; +use crate::odbc::{Odbc, OdbcArguments, OdbcColumn, OdbcConnectOptions, OdbcQueryResult, OdbcRow}; use crate::transaction::Transaction; use either::Either; -use flume::SendError; +use crate::odbc::blocking::run_blocking; +mod inner; +use inner::{do_prepare, establish_connection, execute_sql, OdbcConn}; use futures_core::future::BoxFuture; use futures_util::future; -use odbc_api::handles::StatementImpl; -use odbc_api::{Cursor, CursorRow, IntoParameter, Nullable, Preallocated, ResultSetMetadata}; -use sqlx_rt::spawn_blocking; +// no direct spawn_blocking here; use run_blocking helper use std::sync::{Arc, Mutex}; mod executor; @@ -19,18 +19,16 @@ mod executor; /// thread-pool via `spawn_blocking` and synchronize access with a mutex. #[derive(Debug)] pub struct OdbcConnection { - pub(crate) inner: Arc>>, + pub(crate) inner: Arc>, pub(crate) log_settings: LogSettings, } impl OdbcConnection { pub(crate) async fn establish(options: &OdbcConnectOptions) -> Result { - let conn = spawn_blocking({ + let conn = run_blocking({ let options = options.clone(); move || establish_connection(&options) - }) - .await - .map_err(|_| Error::WorkerCrashed)??; + }).await?; Ok(Self { inner: Arc::new(Mutex::new(conn)), @@ -42,62 +40,52 @@ impl OdbcConnection { /// connection is talking to as reported by the ODBC driver. pub async fn dbms_name(&mut self) -> Result { let inner = self.inner.clone(); - spawn_blocking(move || { + run_blocking(move || { let conn = inner.lock().unwrap(); conn.database_management_system_name() .map_err(|e| Error::Protocol(format!("Failed to get DBMS name: {}", e))) - }) - .await - .map_err(|_| Error::WorkerCrashed)? + }).await } pub(crate) async fn ping_blocking(&mut self) -> Result<(), Error> { let inner = self.inner.clone(); - spawn_blocking(move || { + run_blocking(move || { let conn = inner.lock().unwrap(); let res = conn.execute("SELECT 1", (), None); match res { Ok(_) => Ok(()), Err(e) => Err(Error::Protocol(format!("Ping failed: {}", e))), } - }) - .await - .map_err(|_| Error::WorkerCrashed)? + }).await } pub(crate) async fn begin_blocking(&mut self) -> Result<(), Error> { let inner = self.inner.clone(); - spawn_blocking(move || { + run_blocking(move || { let conn = inner.lock().unwrap(); conn.set_autocommit(false) .map_err(|e| Error::Protocol(format!("Failed to begin transaction: {}", e))) - }) - .await - .map_err(|_| Error::WorkerCrashed)? + }).await } pub(crate) async fn commit_blocking(&mut self) -> Result<(), Error> { let inner = self.inner.clone(); - spawn_blocking(move || { + run_blocking(move || { let conn = inner.lock().unwrap(); conn.commit() .and_then(|_| conn.set_autocommit(true)) .map_err(|e| Error::Protocol(format!("Failed to commit transaction: {}", e))) - }) - .await - .map_err(|_| Error::WorkerCrashed)? + }).await } pub(crate) async fn rollback_blocking(&mut self) -> Result<(), Error> { let inner = self.inner.clone(); - spawn_blocking(move || { + run_blocking(move || { let conn = inner.lock().unwrap(); conn.rollback() .and_then(|_| conn.set_autocommit(true)) .map_err(|e| Error::Protocol(format!("Failed to rollback transaction: {}", e))) - }) - .await - .map_err(|_| Error::WorkerCrashed)? + }).await } pub(crate) async fn execute_stream( @@ -108,26 +96,23 @@ impl OdbcConnection { let (tx, rx) = flume::bounded(64); let inner = self.inner.clone(); let sql = sql.to_string(); - spawn_blocking(move || { + run_blocking(move || { let mut guard = inner.lock().unwrap(); if let Err(e) = execute_sql(&mut guard, &sql, args, &tx) { - let _ = send_stream_result(&tx, Err(e)); + let _ = tx.send(Err(e)); } - }) - .await - .map_err(|_| Error::WorkerCrashed)?; + Ok(()) + }).await?; Ok(rx) } - pub(crate) async fn prepare( + pub(crate) async fn prepare_metadata( &mut self, sql: &str, ) -> Result<(u64, Vec, usize), Error> { let inner = self.inner.clone(); let sql = sql.to_string(); - spawn_blocking(move || do_prepare(&mut inner.lock().unwrap(), sql.into())) - .await - .map_err(|_| Error::WorkerCrashed)? + run_blocking(move || do_prepare(&mut inner.lock().unwrap(), sql.into())).await } } @@ -170,331 +155,4 @@ impl Connection for OdbcConnection { } } -// --- Blocking helpers --- - -fn establish_connection(options: &OdbcConnectOptions) -> Result, Error> { - let env = odbc_api::environment().map_err(|e| Error::Configuration(e.to_string().into()))?; - let conn = env - .connect_with_connection_string(options.connection_string(), Default::default()) - .map_err(|e| Error::Configuration(e.to_string().into()))?; - Ok(conn) -} - -type ExecuteResult = Result, Error>; -type ExecuteSender = flume::Sender; - -fn send_stream_result( - tx: &ExecuteSender, - result: ExecuteResult, -) -> Result<(), SendError> { - tx.send(result) -} - -fn execute_sql( - conn: &mut odbc_api::Connection<'static>, - sql: &str, - args: Option, - tx: &ExecuteSender, -) -> Result<(), Error> { - let params = prepare_parameters(args); - let stmt = &mut conn.preallocate().map_err(Error::from)?; - - if let Some(mut cursor) = stmt.execute(sql, ¶ms[..])? { - handle_cursor(&mut cursor, tx); - return Ok(()); - } - - let affected = extract_rows_affected(stmt); - let _ = send_done(tx, affected); - Ok(()) -} - -fn extract_rows_affected(stmt: &mut Preallocated>) -> u64 { - let count_opt = match stmt.row_count() { - Ok(count_opt) => count_opt, - Err(e) => { - log::warn!("Failed to get ODBC row count: {}", e); - return 0; - } - }; - - let count = match count_opt { - Some(count) => count, - None => { - log::debug!("ODBC row count is not available"); - return 0; - } - }; - - let affected = match u64::try_from(count) { - Ok(count) => count, - Err(e) => { - log::warn!("Failed to convert ODBC row count to u64: {}", e); - return 0; - } - }; - affected -} - -fn prepare_parameters( - args: Option, -) -> Vec> { - let args = args.map(|a| a.values).unwrap_or_default(); - args.into_iter().map(to_param).collect() -} - -fn to_param(arg: OdbcArgumentValue) -> Box { - match arg { - OdbcArgumentValue::Int(i) => Box::new(i.into_parameter()), - OdbcArgumentValue::Float(f) => Box::new(f.into_parameter()), - OdbcArgumentValue::Text(s) => Box::new(s.into_parameter()), - OdbcArgumentValue::Bytes(b) => Box::new(b.into_parameter()), - OdbcArgumentValue::Null => Box::new(Option::::None.into_parameter()), - } -} - -fn handle_cursor(cursor: &mut C, tx: &ExecuteSender) -where - C: Cursor + ResultSetMetadata, -{ - let columns = collect_columns(cursor); - - match stream_rows(cursor, &columns, tx) { - Ok(true) => { - let _ = send_done(tx, 0); - } - Ok(false) => {} - Err(e) => { - send_error(tx, e); - } - } -} - -fn send_done(tx: &ExecuteSender, rows_affected: u64) -> Result<(), SendError> { - send_stream_result(tx, Ok(Either::Left(OdbcQueryResult { rows_affected }))) -} - -fn send_error(tx: &ExecuteSender, error: Error) { - if let Err(e) = send_stream_result(tx, Err(error)) { - log::error!("Failed to send error from ODBC blocking task: {}", e); - } -} - -fn send_row(tx: &ExecuteSender, row: OdbcRow) -> Result<(), SendError> { - send_stream_result(tx, Ok(Either::Right(row))) -} - -fn collect_columns(cursor: &mut C) -> Vec -where - C: ResultSetMetadata, -{ - let count = cursor.num_result_cols().unwrap_or(0); - (1..=count).map(|i| create_column(cursor, i as u16)).collect() -} - -fn create_column(cursor: &mut C, index: u16) -> OdbcColumn -where - C: ResultSetMetadata, -{ - let mut cd = odbc_api::ColumnDescription::default(); - let _ = cursor.describe_col(index, &mut cd); - - OdbcColumn { - name: decode_column_name(cd.name, index), - type_info: OdbcTypeInfo::new(cd.data_type), - ordinal: usize::from(index.checked_sub(1).unwrap()), - } -} - -fn decode_column_name(name_bytes: Vec, index: u16) -> String { - String::from_utf8(name_bytes).unwrap_or_else(|_| format!("col{}", index - 1)) -} - -fn stream_rows(cursor: &mut C, columns: &[OdbcColumn], tx: &ExecuteSender) -> Result -where - C: Cursor, -{ - let mut receiver_open = true; - let mut row_count = 0; - - while let Some(mut row) = cursor.next_row()? { - let values = collect_row_values(&mut row, columns)?; - let row_data = OdbcRow { - columns: columns.to_vec(), - values: values.into_iter().map(|(_, value)| value).collect(), - }; - - if send_row(tx, row_data).is_err() { - receiver_open = false; - break; - } - row_count += 1; - } - - let _ = row_count; // kept for potential logging - Ok(receiver_open) -} - -fn collect_row_values( - row: &mut CursorRow<'_>, - columns: &[OdbcColumn], -) -> Result, Error> { - columns - .iter() - .enumerate() - .map(|(i, column)| collect_column_value(row, i, column)) - .collect() -} - -fn collect_column_value( - row: &mut CursorRow<'_>, - index: usize, - column: &OdbcColumn, -) -> Result<(OdbcTypeInfo, crate::odbc::OdbcValue), Error> { - use odbc_api::DataType; - - let col_idx = (index + 1) as u16; - let type_info = column.type_info.clone(); - let data_type = type_info.data_type(); - - let value = match data_type { - DataType::TinyInt - | DataType::SmallInt - | DataType::Integer - | DataType::BigInt - | DataType::Bit => extract_int(row, col_idx, &type_info)?, - - DataType::Real => extract_float::(row, col_idx, &type_info)?, - DataType::Float { .. } | DataType::Double => extract_float::(row, col_idx, &type_info)?, - - DataType::Char { .. } - | DataType::Varchar { .. } - | DataType::LongVarchar { .. } - | DataType::WChar { .. } - | DataType::WVarchar { .. } - | DataType::WLongVarchar { .. } - | DataType::Date - | DataType::Time { .. } - | DataType::Timestamp { .. } - | DataType::Decimal { .. } - | DataType::Numeric { .. } => extract_text(row, col_idx, &type_info)?, - - DataType::Binary { .. } | DataType::Varbinary { .. } | DataType::LongVarbinary { .. } => { - extract_binary(row, col_idx, &type_info)? - } - - DataType::Unknown | DataType::Other { .. } => match extract_text(row, col_idx, &type_info) { - Ok(v) => v, - Err(_) => extract_binary(row, col_idx, &type_info)?, - }, - }; - - Ok((type_info, value)) -} - -fn extract_int( - row: &mut CursorRow<'_>, - col_idx: u16, - type_info: &OdbcTypeInfo, -) -> Result { - let mut nullable = Nullable::::null(); - row.get_data(col_idx, &mut nullable)?; - - let (is_null, int) = match nullable.into_opt() { - None => (true, None), - Some(v) => (false, Some(v)), - }; - - Ok(crate::odbc::OdbcValue { - type_info: type_info.clone(), - is_null, - text: None, - blob: None, - int, - float: None, - }) -} - -fn extract_float( - row: &mut CursorRow<'_>, - col_idx: u16, - type_info: &OdbcTypeInfo, -) -> Result -where - T: Into + Default, - odbc_api::Nullable: odbc_api::parameter::CElement + odbc_api::handles::CDataMut, -{ - let mut nullable = Nullable::::null(); - row.get_data(col_idx, &mut nullable)?; - - let (is_null, float) = match nullable.into_opt() { - None => (true, None), - Some(v) => (false, Some(v.into())), - }; - - Ok(crate::odbc::OdbcValue { - type_info: type_info.clone(), - is_null, - text: None, - blob: None, - int: None, - float, - }) -} - -fn extract_text( - row: &mut CursorRow<'_>, - col_idx: u16, - type_info: &OdbcTypeInfo, -) -> Result { - let mut buf = Vec::new(); - let is_some = row.get_text(col_idx, &mut buf)?; - - let (is_null, text) = if !is_some { - (true, None) - } else { - match String::from_utf8(buf) { - Ok(s) => (false, Some(s)), - Err(e) => return Err(Error::Decode(e.into())), - } - }; - - Ok(crate::odbc::OdbcValue { - type_info: type_info.clone(), - is_null, - text, - blob: None, - int: None, - float: None, - }) -} - -fn extract_binary( - row: &mut CursorRow<'_>, - col_idx: u16, - type_info: &OdbcTypeInfo, -) -> Result { - let mut buf = Vec::new(); - let is_some = row.get_binary(col_idx, &mut buf)?; - - let (is_null, blob) = if !is_some { (true, None) } else { (false, Some(buf)) }; - - Ok(crate::odbc::OdbcValue { - type_info: type_info.clone(), - is_null, - text: None, - blob, - int: None, - float: None, - }) -} - -fn do_prepare( - conn: &mut odbc_api::Connection<'static>, - sql: Box, -) -> Result<(u64, Vec, usize), Error> { - let mut prepared = conn.prepare(&sql)?; - let columns = collect_columns(&mut prepared); - let params = usize::from(prepared.num_params().unwrap_or(0)); - Ok((0, columns, params)) -} +// moved helpers to connection/inner.rs diff --git a/sqlx-core/src/odbc/mod.rs b/sqlx-core/src/odbc/mod.rs index da41adb1e9..65c795bd53 100644 --- a/sqlx-core/src/odbc/mod.rs +++ b/sqlx-core/src/odbc/mod.rs @@ -36,6 +36,7 @@ mod transaction; mod type_info; pub mod types; mod value; +mod blocking; pub use arguments::{OdbcArgumentValue, OdbcArguments}; pub use column::OdbcColumn; From 5014661d2bef137d4702b95a6b3904eb9a71fa07 Mon Sep 17 00:00:00 2001 From: Cursor Agent Date: Wed, 24 Sep 2025 12:25:20 +0000 Subject: [PATCH 03/15] Refactor: Improve async runtime compatibility for ODBC Co-authored-by: contact --- sqlx-core/src/odbc/blocking.rs | 12 +++++++++-- sqlx-core/src/odbc/connection/inner.rs | 29 ++++++++++++++++++-------- sqlx-core/src/odbc/connection/mod.rs | 25 ++++++++++++++-------- sqlx-core/src/odbc/mod.rs | 2 +- sqlx-rt/src/rt_async_std.rs | 4 ++-- sqlx-rt/src/rt_tokio.rs | 5 ++--- 6 files changed, 51 insertions(+), 26 deletions(-) diff --git a/sqlx-core/src/odbc/blocking.rs b/sqlx-core/src/odbc/blocking.rs index 165e201a27..b25b657139 100644 --- a/sqlx-core/src/odbc/blocking.rs +++ b/sqlx-core/src/odbc/blocking.rs @@ -6,6 +6,14 @@ where R: Send + 'static, F: FnOnce() -> Result + Send + 'static, { - let res = spawn_blocking(f).await.map_err(|_| Error::WorkerCrashed)?; - res + #[cfg(feature = "_rt-tokio")] + { + let join_result = spawn_blocking(f).await.map_err(|_| Error::WorkerCrashed)?; + join_result + } + + #[cfg(feature = "_rt-async-std")] + { + spawn_blocking(f).await + } } diff --git a/sqlx-core/src/odbc/connection/inner.rs b/sqlx-core/src/odbc/connection/inner.rs index 0eaa7ade03..9d76385171 100644 --- a/sqlx-core/src/odbc/connection/inner.rs +++ b/sqlx-core/src/odbc/connection/inner.rs @@ -1,5 +1,7 @@ use crate::error::Error; -use crate::odbc::{OdbcArgumentValue, OdbcArguments, OdbcColumn, OdbcQueryResult, OdbcRow, OdbcTypeInfo}; +use crate::odbc::{ + OdbcArgumentValue, OdbcArguments, OdbcColumn, OdbcQueryResult, OdbcRow, OdbcTypeInfo, +}; use either::Either; use flume::{SendError, Sender}; use odbc_api::handles::StatementImpl; @@ -105,7 +107,9 @@ where C: ResultSetMetadata, { let count = cursor.num_result_cols().unwrap_or(0); - (1..=count).map(|i| create_column(cursor, i as u16)).collect() + (1..=count) + .map(|i| create_column(cursor, i as u16)) + .collect() } fn create_column(cursor: &mut C, index: u16) -> OdbcColumn @@ -177,7 +181,9 @@ fn collect_column_value( | DataType::Bit => extract_int(row, col_idx, &type_info)?, DataType::Real => extract_float::(row, col_idx, &type_info)?, - DataType::Float { .. } | DataType::Double => extract_float::(row, col_idx, &type_info)?, + DataType::Float { .. } | DataType::Double => { + extract_float::(row, col_idx, &type_info)? + } DataType::Char { .. } | DataType::Varchar { .. } @@ -195,10 +201,12 @@ fn collect_column_value( extract_binary(row, col_idx, &type_info)? } - DataType::Unknown | DataType::Other { .. } => match extract_text(row, col_idx, &type_info) { - Ok(v) => v, - Err(_) => extract_binary(row, col_idx, &type_info)?, - }, + DataType::Unknown | DataType::Other { .. } => { + match extract_text(row, col_idx, &type_info) { + Ok(v) => v, + Err(_) => extract_binary(row, col_idx, &type_info)?, + } + } }; Ok((type_info, value)) @@ -289,7 +297,11 @@ fn extract_binary( let mut buf = Vec::new(); let is_some = row.get_binary(col_idx, &mut buf)?; - let (is_null, blob) = if !is_some { (true, None) } else { (false, Some(buf)) }; + let (is_null, blob) = if !is_some { + (true, None) + } else { + (false, Some(buf)) + }; Ok(crate::odbc::OdbcValue { type_info: type_info.clone(), @@ -310,4 +322,3 @@ pub fn do_prepare( let params = usize::from(prepared.num_params().unwrap_or(0)); Ok((0, columns, params)) } - diff --git a/sqlx-core/src/odbc/connection/mod.rs b/sqlx-core/src/odbc/connection/mod.rs index 63f13654f3..e317eb3594 100644 --- a/sqlx-core/src/odbc/connection/mod.rs +++ b/sqlx-core/src/odbc/connection/mod.rs @@ -1,13 +1,13 @@ use crate::connection::{Connection, LogSettings}; use crate::error::Error; +use crate::odbc::blocking::run_blocking; use crate::odbc::{Odbc, OdbcArguments, OdbcColumn, OdbcConnectOptions, OdbcQueryResult, OdbcRow}; use crate::transaction::Transaction; use either::Either; -use crate::odbc::blocking::run_blocking; mod inner; -use inner::{do_prepare, establish_connection, execute_sql, OdbcConn}; use futures_core::future::BoxFuture; use futures_util::future; +use inner::{do_prepare, establish_connection, execute_sql, OdbcConn}; // no direct spawn_blocking here; use run_blocking helper use std::sync::{Arc, Mutex}; @@ -28,7 +28,8 @@ impl OdbcConnection { let conn = run_blocking({ let options = options.clone(); move || establish_connection(&options) - }).await?; + }) + .await?; Ok(Self { inner: Arc::new(Mutex::new(conn)), @@ -44,7 +45,8 @@ impl OdbcConnection { let conn = inner.lock().unwrap(); conn.database_management_system_name() .map_err(|e| Error::Protocol(format!("Failed to get DBMS name: {}", e))) - }).await + }) + .await } pub(crate) async fn ping_blocking(&mut self) -> Result<(), Error> { @@ -56,7 +58,8 @@ impl OdbcConnection { Ok(_) => Ok(()), Err(e) => Err(Error::Protocol(format!("Ping failed: {}", e))), } - }).await + }) + .await } pub(crate) async fn begin_blocking(&mut self) -> Result<(), Error> { @@ -65,7 +68,8 @@ impl OdbcConnection { let conn = inner.lock().unwrap(); conn.set_autocommit(false) .map_err(|e| Error::Protocol(format!("Failed to begin transaction: {}", e))) - }).await + }) + .await } pub(crate) async fn commit_blocking(&mut self) -> Result<(), Error> { @@ -75,7 +79,8 @@ impl OdbcConnection { conn.commit() .and_then(|_| conn.set_autocommit(true)) .map_err(|e| Error::Protocol(format!("Failed to commit transaction: {}", e))) - }).await + }) + .await } pub(crate) async fn rollback_blocking(&mut self) -> Result<(), Error> { @@ -85,7 +90,8 @@ impl OdbcConnection { conn.rollback() .and_then(|_| conn.set_autocommit(true)) .map_err(|e| Error::Protocol(format!("Failed to rollback transaction: {}", e))) - }).await + }) + .await } pub(crate) async fn execute_stream( @@ -102,7 +108,8 @@ impl OdbcConnection { let _ = tx.send(Err(e)); } Ok(()) - }).await?; + }) + .await?; Ok(rx) } diff --git a/sqlx-core/src/odbc/mod.rs b/sqlx-core/src/odbc/mod.rs index 65c795bd53..b4b4920857 100644 --- a/sqlx-core/src/odbc/mod.rs +++ b/sqlx-core/src/odbc/mod.rs @@ -24,6 +24,7 @@ use crate::executor::Executor; mod arguments; +mod blocking; mod column; mod connection; mod database; @@ -36,7 +37,6 @@ mod transaction; mod type_info; pub mod types; mod value; -mod blocking; pub use arguments::{OdbcArgumentValue, OdbcArguments}; pub use column::OdbcColumn; diff --git a/sqlx-rt/src/rt_async_std.rs b/sqlx-rt/src/rt_async_std.rs index 2abf9cd2f2..e8ccb49849 100644 --- a/sqlx-rt/src/rt_async_std.rs +++ b/sqlx-rt/src/rt_async_std.rs @@ -1,8 +1,8 @@ pub use async_std::{ self, fs, future::timeout, io::prelude::ReadExt as AsyncReadExt, io::prelude::WriteExt as AsyncWriteExt, io::Read as AsyncRead, io::Write as AsyncWrite, - net::TcpStream, sync::Mutex as AsyncMutex, task::sleep, task::spawn, task::yield_now, - task::spawn_blocking, + net::TcpStream, sync::Mutex as AsyncMutex, task::sleep, task::spawn, task::spawn_blocking, + task::yield_now, }; #[cfg(unix)] diff --git a/sqlx-rt/src/rt_tokio.rs b/sqlx-rt/src/rt_tokio.rs index 5822deef57..855ff6269f 100644 --- a/sqlx-rt/src/rt_tokio.rs +++ b/sqlx-rt/src/rt_tokio.rs @@ -1,8 +1,7 @@ pub use tokio::{ self, fs, io::AsyncRead, io::AsyncReadExt, io::AsyncWrite, io::AsyncWriteExt, io::ReadBuf, - net::TcpStream, runtime::Handle, sync::Mutex as AsyncMutex, task::spawn, task::yield_now, - task::spawn_blocking, - time::sleep, time::timeout, + net::TcpStream, runtime::Handle, sync::Mutex as AsyncMutex, task::spawn, task::spawn_blocking, + task::yield_now, time::sleep, time::timeout, }; #[cfg(unix)] From c2fefbce72f4d59db22e517fc66ae88b17244bba Mon Sep 17 00:00:00 2001 From: Cursor Agent Date: Wed, 24 Sep 2025 12:30:25 +0000 Subject: [PATCH 04/15] Refactor: Use with_conn helper for OdbcConnection methods This change introduces a `with_conn` helper method to abstract away the common pattern of acquiring and releasing the connection lock. It also adds an `odbc_err` helper for consistent error handling. Co-authored-by: contact --- sqlx-core/src/odbc/connection/mod.rs | 78 +++++++++++++++------------- 1 file changed, 42 insertions(+), 36 deletions(-) diff --git a/sqlx-core/src/odbc/connection/mod.rs b/sqlx-core/src/odbc/connection/mod.rs index e317eb3594..424efa0aaa 100644 --- a/sqlx-core/src/odbc/connection/mod.rs +++ b/sqlx-core/src/odbc/connection/mod.rs @@ -24,6 +24,25 @@ pub struct OdbcConnection { } impl OdbcConnection { + #[inline] + async fn with_conn(&self, f: F) -> Result + where + T: Send + 'static, + F: FnOnce(&mut OdbcConn) -> Result + Send + 'static, + { + let inner = self.inner.clone(); + run_blocking(move || { + let mut conn = inner.lock().unwrap(); + f(&mut conn) + }) + .await + } + + #[inline] + fn odbc_err(res: Result, ctx: &'static str) -> Result { + res.map_err(|e| Error::Protocol(format!("{}: {}", ctx, e))) + } + pub(crate) async fn establish(options: &OdbcConnectOptions) -> Result { let conn = run_blocking({ let options = options.clone(); @@ -40,56 +59,44 @@ impl OdbcConnection { /// Returns the name of the actual Database Management System (DBMS) this /// connection is talking to as reported by the ODBC driver. pub async fn dbms_name(&mut self) -> Result { - let inner = self.inner.clone(); - run_blocking(move || { - let conn = inner.lock().unwrap(); - conn.database_management_system_name() - .map_err(|e| Error::Protocol(format!("Failed to get DBMS name: {}", e))) + self.with_conn(|conn| { + Self::odbc_err( + conn.database_management_system_name(), + "Failed to get DBMS name", + ) }) .await } pub(crate) async fn ping_blocking(&mut self) -> Result<(), Error> { - let inner = self.inner.clone(); - run_blocking(move || { - let conn = inner.lock().unwrap(); - let res = conn.execute("SELECT 1", (), None); - match res { - Ok(_) => Ok(()), - Err(e) => Err(Error::Protocol(format!("Ping failed: {}", e))), - } + self.with_conn(|conn| { + Self::odbc_err(conn.execute("SELECT 1", (), None), "Ping failed")?; + Ok(()) }) .await } pub(crate) async fn begin_blocking(&mut self) -> Result<(), Error> { - let inner = self.inner.clone(); - run_blocking(move || { - let conn = inner.lock().unwrap(); - conn.set_autocommit(false) - .map_err(|e| Error::Protocol(format!("Failed to begin transaction: {}", e))) + self.with_conn(|conn| { + Self::odbc_err(conn.set_autocommit(false), "Failed to begin transaction") }) .await } pub(crate) async fn commit_blocking(&mut self) -> Result<(), Error> { - let inner = self.inner.clone(); - run_blocking(move || { - let conn = inner.lock().unwrap(); - conn.commit() - .and_then(|_| conn.set_autocommit(true)) - .map_err(|e| Error::Protocol(format!("Failed to commit transaction: {}", e))) + self.with_conn(|conn| { + Self::odbc_err(conn.commit(), "Failed to commit transaction")?; + Self::odbc_err(conn.set_autocommit(true), "Failed to commit transaction")?; + Ok(()) }) .await } pub(crate) async fn rollback_blocking(&mut self) -> Result<(), Error> { - let inner = self.inner.clone(); - run_blocking(move || { - let conn = inner.lock().unwrap(); - conn.rollback() - .and_then(|_| conn.set_autocommit(true)) - .map_err(|e| Error::Protocol(format!("Failed to rollback transaction: {}", e))) + self.with_conn(|conn| { + Self::odbc_err(conn.rollback(), "Failed to rollback transaction")?; + Self::odbc_err(conn.set_autocommit(true), "Failed to rollback transaction")?; + Ok(()) }) .await } @@ -100,11 +107,10 @@ impl OdbcConnection { args: Option, ) -> Result, Error>>, Error> { let (tx, rx) = flume::bounded(64); - let inner = self.inner.clone(); let sql = sql.to_string(); - run_blocking(move || { - let mut guard = inner.lock().unwrap(); - if let Err(e) = execute_sql(&mut guard, &sql, args, &tx) { + let args_move = args; + self.with_conn(move |conn| { + if let Err(e) = execute_sql(conn, &sql, args_move, &tx) { let _ = tx.send(Err(e)); } Ok(()) @@ -117,9 +123,9 @@ impl OdbcConnection { &mut self, sql: &str, ) -> Result<(u64, Vec, usize), Error> { - let inner = self.inner.clone(); let sql = sql.to_string(); - run_blocking(move || do_prepare(&mut inner.lock().unwrap(), sql.into())).await + self.with_conn(move |conn| do_prepare(conn, sql.into())) + .await } } From 9e71ab3074092913bfe5affc9dbd40a049b92c3b Mon Sep 17 00:00:00 2001 From: Cursor Agent Date: Wed, 24 Sep 2025 12:41:15 +0000 Subject: [PATCH 05/15] Refactor: Use `with_conn_map` for blocking ODBC operations Co-authored-by: contact --- sqlx-core/src/odbc/connection/mod.rs | 44 +++++++++++++++------------- 1 file changed, 24 insertions(+), 20 deletions(-) diff --git a/sqlx-core/src/odbc/connection/mod.rs b/sqlx-core/src/odbc/connection/mod.rs index 424efa0aaa..7a157fdb91 100644 --- a/sqlx-core/src/odbc/connection/mod.rs +++ b/sqlx-core/src/odbc/connection/mod.rs @@ -39,8 +39,18 @@ impl OdbcConnection { } #[inline] - fn odbc_err(res: Result, ctx: &'static str) -> Result { - res.map_err(|e| Error::Protocol(format!("{}: {}", ctx, e))) + async fn with_conn_map(&self, ctx: &'static str, f: F) -> Result + where + T: Send + 'static, + E: std::fmt::Display, + F: FnOnce(&mut OdbcConn) -> Result + Send + 'static, + { + let inner = self.inner.clone(); + run_blocking(move || { + let mut conn = inner.lock().unwrap(); + f(&mut conn).map_err(|e| Error::Protocol(format!("{}: {}", ctx, e))) + }) + .await } pub(crate) async fn establish(options: &OdbcConnectOptions) -> Result { @@ -59,44 +69,38 @@ impl OdbcConnection { /// Returns the name of the actual Database Management System (DBMS) this /// connection is talking to as reported by the ODBC driver. pub async fn dbms_name(&mut self) -> Result { - self.with_conn(|conn| { - Self::odbc_err( - conn.database_management_system_name(), - "Failed to get DBMS name", - ) + self.with_conn_map::<_, _, _>("Failed to get DBMS name", |conn| { + conn.database_management_system_name() }) .await } pub(crate) async fn ping_blocking(&mut self) -> Result<(), Error> { - self.with_conn(|conn| { - Self::odbc_err(conn.execute("SELECT 1", (), None), "Ping failed")?; - Ok(()) + self.with_conn_map::<_, _, _>("Ping failed", |conn| { + conn.execute("SELECT 1", (), None).map(|_| ()) }) .await } pub(crate) async fn begin_blocking(&mut self) -> Result<(), Error> { - self.with_conn(|conn| { - Self::odbc_err(conn.set_autocommit(false), "Failed to begin transaction") + self.with_conn_map::<_, _, _>("Failed to begin transaction", |conn| { + conn.set_autocommit(false) }) .await } pub(crate) async fn commit_blocking(&mut self) -> Result<(), Error> { - self.with_conn(|conn| { - Self::odbc_err(conn.commit(), "Failed to commit transaction")?; - Self::odbc_err(conn.set_autocommit(true), "Failed to commit transaction")?; - Ok(()) + self.with_conn_map::<_, _, _>("Failed to commit transaction", |conn| { + conn.commit()?; + conn.set_autocommit(true) }) .await } pub(crate) async fn rollback_blocking(&mut self) -> Result<(), Error> { - self.with_conn(|conn| { - Self::odbc_err(conn.rollback(), "Failed to rollback transaction")?; - Self::odbc_err(conn.set_autocommit(true), "Failed to rollback transaction")?; - Ok(()) + self.with_conn_map::<_, _, _>("Failed to rollback transaction", |conn| { + conn.rollback()?; + conn.set_autocommit(true) }) .await } From ea15e32175e521188b8d900522568e967bf3e4c7 Mon Sep 17 00:00:00 2001 From: Cursor Agent Date: Wed, 24 Sep 2025 13:37:19 +0000 Subject: [PATCH 06/15] feat: Cache prepared statement metadata for ODBC Co-authored-by: contact --- sqlx-core/src/odbc/connection/inner.rs | 36 +++++++++++++++++++++----- sqlx-core/src/odbc/connection/mod.rs | 14 +++++----- 2 files changed, 36 insertions(+), 14 deletions(-) diff --git a/sqlx-core/src/odbc/connection/inner.rs b/sqlx-core/src/odbc/connection/inner.rs index 9d76385171..6dcf780871 100644 --- a/sqlx-core/src/odbc/connection/inner.rs +++ b/sqlx-core/src/odbc/connection/inner.rs @@ -6,8 +6,13 @@ use either::Either; use flume::{SendError, Sender}; use odbc_api::handles::StatementImpl; use odbc_api::{Cursor, CursorRow, IntoParameter, Nullable, Preallocated, ResultSetMetadata}; +use std::collections::HashMap; -pub type OdbcConn = odbc_api::Connection<'static>; +#[derive(Debug)] +pub struct OdbcConn { + pub conn: odbc_api::Connection<'static>, + pub prepared_meta_cache: HashMap, usize)>, +} pub type ExecuteResult = Result, Error>; pub type ExecuteSender = Sender; @@ -16,17 +21,20 @@ pub fn establish_connection(options: &crate::odbc::OdbcConnectOptions) -> Result let conn = env .connect_with_connection_string(options.connection_string(), Default::default()) .map_err(|e| Error::Configuration(e.to_string().into()))?; - Ok(conn) + Ok(OdbcConn { + conn, + prepared_meta_cache: HashMap::new(), + }) } pub fn execute_sql( - conn: &mut OdbcConn, + inner: &mut OdbcConn, sql: &str, args: Option, tx: &ExecuteSender, ) -> Result<(), Error> { let params = prepare_parameters(args); - let stmt = &mut conn.preallocate().map_err(Error::from)?; + let stmt = &mut inner.conn.preallocate().map_err(Error::from)?; if let Some(mut cursor) = stmt.execute(sql, ¶ms[..])? { handle_cursor(&mut cursor, tx); @@ -314,11 +322,25 @@ fn extract_binary( } pub fn do_prepare( - conn: &mut OdbcConn, + inner: &mut OdbcConn, sql: Box, ) -> Result<(u64, Vec, usize), Error> { - let mut prepared = conn.prepare(&sql)?; + use std::collections::hash_map::DefaultHasher; + use std::hash::{Hash, Hasher}; + + let mut hasher = DefaultHasher::new(); + sql.hash(&mut hasher); + let key = hasher.finish(); + + if let Some((cols, params)) = inner.prepared_meta_cache.get(&key) { + return Ok((key, cols.clone(), *params)); + } + + let mut prepared = inner.conn.prepare(&sql)?; let columns = collect_columns(&mut prepared); let params = usize::from(prepared.num_params().unwrap_or(0)); - Ok((0, columns, params)) + inner + .prepared_meta_cache + .insert(key, (columns.clone(), params)); + Ok((key, columns, params)) } diff --git a/sqlx-core/src/odbc/connection/mod.rs b/sqlx-core/src/odbc/connection/mod.rs index 7a157fdb91..f3999308bd 100644 --- a/sqlx-core/src/odbc/connection/mod.rs +++ b/sqlx-core/src/odbc/connection/mod.rs @@ -70,37 +70,37 @@ impl OdbcConnection { /// connection is talking to as reported by the ODBC driver. pub async fn dbms_name(&mut self) -> Result { self.with_conn_map::<_, _, _>("Failed to get DBMS name", |conn| { - conn.database_management_system_name() + conn.conn.database_management_system_name() }) .await } pub(crate) async fn ping_blocking(&mut self) -> Result<(), Error> { self.with_conn_map::<_, _, _>("Ping failed", |conn| { - conn.execute("SELECT 1", (), None).map(|_| ()) + conn.conn.execute("SELECT 1", (), None).map(|_| ()) }) .await } pub(crate) async fn begin_blocking(&mut self) -> Result<(), Error> { self.with_conn_map::<_, _, _>("Failed to begin transaction", |conn| { - conn.set_autocommit(false) + conn.conn.set_autocommit(false) }) .await } pub(crate) async fn commit_blocking(&mut self) -> Result<(), Error> { self.with_conn_map::<_, _, _>("Failed to commit transaction", |conn| { - conn.commit()?; - conn.set_autocommit(true) + conn.conn.commit()?; + conn.conn.set_autocommit(true) }) .await } pub(crate) async fn rollback_blocking(&mut self) -> Result<(), Error> { self.with_conn_map::<_, _, _>("Failed to rollback transaction", |conn| { - conn.rollback()?; - conn.set_autocommit(true) + conn.conn.rollback()?; + conn.conn.set_autocommit(true) }) .await } From 0be7f70504cb50d91758624fd2adf8636b675976 Mon Sep 17 00:00:00 2001 From: lovasoa Date: Wed, 24 Sep 2025 22:18:08 +0200 Subject: [PATCH 07/15] refactor(odbc): rename inner module and reorganize ODBC connection implementation This commit deletes the `inner.rs` file and moves its contents to a new `odbc_bridge.rs` module, improving code organization. The `OdbcConnection` struct is updated to remove the `log_settings` field, streamlining its definition. --- sqlx-core/src/odbc/connection/mod.rs | 8 +++----- .../src/odbc/connection/{inner.rs => odbc_bridge.rs} | 0 2 files changed, 3 insertions(+), 5 deletions(-) rename sqlx-core/src/odbc/connection/{inner.rs => odbc_bridge.rs} (100%) diff --git a/sqlx-core/src/odbc/connection/mod.rs b/sqlx-core/src/odbc/connection/mod.rs index f3999308bd..05029b525b 100644 --- a/sqlx-core/src/odbc/connection/mod.rs +++ b/sqlx-core/src/odbc/connection/mod.rs @@ -1,13 +1,13 @@ -use crate::connection::{Connection, LogSettings}; +use crate::connection::Connection; use crate::error::Error; use crate::odbc::blocking::run_blocking; use crate::odbc::{Odbc, OdbcArguments, OdbcColumn, OdbcConnectOptions, OdbcQueryResult, OdbcRow}; use crate::transaction::Transaction; use either::Either; -mod inner; +mod odbc_bridge; use futures_core::future::BoxFuture; use futures_util::future; -use inner::{do_prepare, establish_connection, execute_sql, OdbcConn}; +use odbc_bridge::{do_prepare, establish_connection, execute_sql, OdbcConn}; // no direct spawn_blocking here; use run_blocking helper use std::sync::{Arc, Mutex}; @@ -20,7 +20,6 @@ mod executor; #[derive(Debug)] pub struct OdbcConnection { pub(crate) inner: Arc>, - pub(crate) log_settings: LogSettings, } impl OdbcConnection { @@ -62,7 +61,6 @@ impl OdbcConnection { Ok(Self { inner: Arc::new(Mutex::new(conn)), - log_settings: LogSettings::default(), }) } diff --git a/sqlx-core/src/odbc/connection/inner.rs b/sqlx-core/src/odbc/connection/odbc_bridge.rs similarity index 100% rename from sqlx-core/src/odbc/connection/inner.rs rename to sqlx-core/src/odbc/connection/odbc_bridge.rs From 650ef81c77b74fb2c81d829100ccb9d6f9bc9c46 Mon Sep 17 00:00:00 2001 From: lovasoa Date: Wed, 24 Sep 2025 22:18:47 +0200 Subject: [PATCH 08/15] rename inner to conn --- sqlx-core/src/odbc/connection/mod.rs | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/sqlx-core/src/odbc/connection/mod.rs b/sqlx-core/src/odbc/connection/mod.rs index 05029b525b..f723c2b5e1 100644 --- a/sqlx-core/src/odbc/connection/mod.rs +++ b/sqlx-core/src/odbc/connection/mod.rs @@ -19,7 +19,7 @@ mod executor; /// thread-pool via `spawn_blocking` and synchronize access with a mutex. #[derive(Debug)] pub struct OdbcConnection { - pub(crate) inner: Arc>, + pub(crate) conn: Arc>, } impl OdbcConnection { @@ -29,7 +29,7 @@ impl OdbcConnection { T: Send + 'static, F: FnOnce(&mut OdbcConn) -> Result + Send + 'static, { - let inner = self.inner.clone(); + let inner = self.conn.clone(); run_blocking(move || { let mut conn = inner.lock().unwrap(); f(&mut conn) @@ -44,7 +44,7 @@ impl OdbcConnection { E: std::fmt::Display, F: FnOnce(&mut OdbcConn) -> Result + Send + 'static, { - let inner = self.inner.clone(); + let inner = self.conn.clone(); run_blocking(move || { let mut conn = inner.lock().unwrap(); f(&mut conn).map_err(|e| Error::Protocol(format!("{}: {}", ctx, e))) @@ -60,7 +60,7 @@ impl OdbcConnection { .await?; Ok(Self { - inner: Arc::new(Mutex::new(conn)), + conn: Arc::new(Mutex::new(conn)), }) } From 6e90b48938dfc94f6ec63c933f61545efc87ab07 Mon Sep 17 00:00:00 2001 From: lovasoa Date: Thu, 25 Sep 2025 00:19:59 +0200 Subject: [PATCH 09/15] refactor(odbc): restructure OdbcStatement and enhance metadata handling This commit refactors the OdbcStatement structure to encapsulate metadata, including columns and parameters, within a dedicated OdbcStatementMetadata struct. It also updates the OdbcConnection to cache prepared statement metadata, improving performance and reducing redundant metadata retrieval. Additionally, the prepare method is streamlined to utilize the new metadata structure. --- sqlx-core/src/any/connection/executor.rs | 10 +- sqlx-core/src/odbc/connection/executor.rs | 10 +- sqlx-core/src/odbc/connection/mod.rs | 199 ++++++++++++++----- sqlx-core/src/odbc/connection/odbc_bridge.rs | 54 ++--- sqlx-core/src/odbc/mod.rs | 4 +- sqlx-core/src/odbc/statement.rs | 24 ++- tests/odbc/odbc.rs | 50 +++-- 7 files changed, 203 insertions(+), 148 deletions(-) diff --git a/sqlx-core/src/any/connection/executor.rs b/sqlx-core/src/any/connection/executor.rs index 9bd5ce06e9..d49d23e543 100644 --- a/sqlx-core/src/any/connection/executor.rs +++ b/sqlx-core/src/any/connection/executor.rs @@ -128,15 +128,7 @@ impl<'c> Executor<'c> for &'c mut AnyConnection { AnyConnectionKind::Mssql(conn) => conn.prepare(sql).await.map(Into::into)?, #[cfg(feature = "odbc")] - AnyConnectionKind::Odbc(conn) => { - let (_, columns, parameters) = conn.prepare_metadata(sql).await?; - crate::odbc::OdbcStatement { - sql: sql.into(), - columns, - parameters, - } - .into() - } + AnyConnectionKind::Odbc(conn) => conn.prepare(sql).await.map(Into::into)?, }) }) } diff --git a/sqlx-core/src/odbc/connection/executor.rs b/sqlx-core/src/odbc/connection/executor.rs index 0010a2b8e3..48d1d606ef 100644 --- a/sqlx-core/src/odbc/connection/executor.rs +++ b/sqlx-core/src/odbc/connection/executor.rs @@ -6,7 +6,6 @@ use either::Either; use futures_core::future::BoxFuture; use futures_core::stream::BoxStream; use futures_util::TryStreamExt; -use std::borrow::Cow; // run method removed; fetch_many implements streaming directly @@ -59,14 +58,7 @@ impl<'c> Executor<'c> for &'c mut OdbcConnection { where 'c: 'e, { - Box::pin(async move { - let (_, columns, parameters) = self.prepare_metadata(sql).await?; - Ok(OdbcStatement { - sql: Cow::Borrowed(sql), - columns, - parameters, - }) - }) + Box::pin(async move { self.prepare(sql).await }) } #[doc(hidden)] diff --git a/sqlx-core/src/odbc/connection/mod.rs b/sqlx-core/src/odbc/connection/mod.rs index f723c2b5e1..40571ce582 100644 --- a/sqlx-core/src/odbc/connection/mod.rs +++ b/sqlx-core/src/odbc/connection/mod.rs @@ -1,15 +1,48 @@ use crate::connection::Connection; use crate::error::Error; use crate::odbc::blocking::run_blocking; -use crate::odbc::{Odbc, OdbcArguments, OdbcColumn, OdbcConnectOptions, OdbcQueryResult, OdbcRow}; +use crate::odbc::{ + Odbc, OdbcArguments, OdbcColumn, OdbcConnectOptions, OdbcQueryResult, OdbcRow, OdbcTypeInfo, +}; use crate::transaction::Transaction; use either::Either; mod odbc_bridge; use futures_core::future::BoxFuture; use futures_util::future; -use odbc_bridge::{do_prepare, establish_connection, execute_sql, OdbcConn}; +use odbc_bridge::{establish_connection, execute_sql}; // no direct spawn_blocking here; use run_blocking helper -use std::sync::{Arc, Mutex}; +use crate::odbc::{OdbcStatement, OdbcStatementMetadata}; +use odbc_api::ResultSetMetadata; +use std::borrow::Cow; +use std::collections::HashMap; +use std::sync::Arc; + +fn collect_columns( + prepared: &mut odbc_api::Prepared>, +) -> Vec { + let count = prepared.num_result_cols().unwrap_or(0); + (1..=count) + .map(|i| create_column(prepared, i as u16)) + .collect() +} + +fn create_column( + stmt: &mut odbc_api::Prepared>, + index: u16, +) -> OdbcColumn { + let mut cd = odbc_api::ColumnDescription::default(); + let _ = stmt.describe_col(index, &mut cd); + + OdbcColumn { + name: decode_column_name(cd.name, index), + type_info: OdbcTypeInfo::new(cd.data_type), + ordinal: usize::from(index.checked_sub(1).unwrap()), + } +} + +fn decode_column_name(name_bytes: Vec, index: u16) -> String { + String::from_utf8(name_bytes).unwrap_or_else(|_| format!("col{}", index - 1)) +} mod executor; @@ -19,86 +52,88 @@ mod executor; /// thread-pool via `spawn_blocking` and synchronize access with a mutex. #[derive(Debug)] pub struct OdbcConnection { - pub(crate) conn: Arc>, + pub(crate) conn: odbc_api::SharedConnection<'static>, + pub(crate) stmt_cache: HashMap, } impl OdbcConnection { - #[inline] - async fn with_conn(&self, f: F) -> Result - where - T: Send + 'static, - F: FnOnce(&mut OdbcConn) -> Result + Send + 'static, - { - let inner = self.conn.clone(); - run_blocking(move || { - let mut conn = inner.lock().unwrap(); - f(&mut conn) - }) - .await - } - - #[inline] - async fn with_conn_map(&self, ctx: &'static str, f: F) -> Result - where - T: Send + 'static, - E: std::fmt::Display, - F: FnOnce(&mut OdbcConn) -> Result + Send + 'static, - { - let inner = self.conn.clone(); - run_blocking(move || { - let mut conn = inner.lock().unwrap(); - f(&mut conn).map_err(|e| Error::Protocol(format!("{}: {}", ctx, e))) - }) - .await - } - pub(crate) async fn establish(options: &OdbcConnectOptions) -> Result { - let conn = run_blocking({ + let shared_conn = run_blocking({ let options = options.clone(); - move || establish_connection(&options) + move || { + let conn = establish_connection(&options)?; + let shared_conn = odbc_api::SharedConnection::new(std::sync::Mutex::new(conn)); + Ok::<_, Error>(shared_conn) + } }) .await?; Ok(Self { - conn: Arc::new(Mutex::new(conn)), + conn: shared_conn, + stmt_cache: HashMap::new(), }) } /// Returns the name of the actual Database Management System (DBMS) this /// connection is talking to as reported by the ODBC driver. pub async fn dbms_name(&mut self) -> Result { - self.with_conn_map::<_, _, _>("Failed to get DBMS name", |conn| { - conn.conn.database_management_system_name() + let conn = Arc::clone(&self.conn); + run_blocking(move || { + let conn_guard = conn + .lock() + .map_err(|_| Error::Protocol("Failed to lock connection".into()))?; + conn_guard + .database_management_system_name() + .map_err(Error::from) }) .await } pub(crate) async fn ping_blocking(&mut self) -> Result<(), Error> { - self.with_conn_map::<_, _, _>("Ping failed", |conn| { - conn.conn.execute("SELECT 1", (), None).map(|_| ()) + let conn = Arc::clone(&self.conn); + run_blocking(move || { + let conn_guard = conn + .lock() + .map_err(|_| Error::Protocol("Failed to lock connection".into()))?; + conn_guard + .execute("SELECT 1", (), None) + .map_err(Error::from) + .map(|_| ()) }) .await } pub(crate) async fn begin_blocking(&mut self) -> Result<(), Error> { - self.with_conn_map::<_, _, _>("Failed to begin transaction", |conn| { - conn.conn.set_autocommit(false) + let conn = Arc::clone(&self.conn); + run_blocking(move || { + let conn_guard = conn + .lock() + .map_err(|_| Error::Protocol("Failed to lock connection".into()))?; + conn_guard.set_autocommit(false).map_err(Error::from) }) .await } pub(crate) async fn commit_blocking(&mut self) -> Result<(), Error> { - self.with_conn_map::<_, _, _>("Failed to commit transaction", |conn| { - conn.conn.commit()?; - conn.conn.set_autocommit(true) + let conn = Arc::clone(&self.conn); + run_blocking(move || { + let conn_guard = conn + .lock() + .map_err(|_| Error::Protocol("Failed to lock connection".into()))?; + conn_guard.commit()?; + conn_guard.set_autocommit(true).map_err(Error::from) }) .await } pub(crate) async fn rollback_blocking(&mut self) -> Result<(), Error> { - self.with_conn_map::<_, _, _>("Failed to rollback transaction", |conn| { - conn.conn.rollback()?; - conn.conn.set_autocommit(true) + let conn = Arc::clone(&self.conn); + run_blocking(move || { + let conn_guard = conn + .lock() + .map_err(|_| Error::Protocol("Failed to lock connection".into()))?; + conn_guard.rollback()?; + conn_guard.set_autocommit(true).map_err(Error::from) }) .await } @@ -111,13 +146,19 @@ impl OdbcConnection { let (tx, rx) = flume::bounded(64); let sql = sql.to_string(); let args_move = args; - self.with_conn(move |conn| { - if let Err(e) = execute_sql(conn, &sql, args_move, &tx) { + let conn = Arc::clone(&self.conn); + + run_blocking(move || { + let mut conn_guard = conn + .lock() + .map_err(|_| Error::Protocol("Failed to lock connection".into()))?; + if let Err(e) = execute_sql(&mut conn_guard, &sql, args_move, &tx) { let _ = tx.send(Err(e)); } Ok(()) }) .await?; + Ok(rx) } @@ -125,9 +166,59 @@ impl OdbcConnection { &mut self, sql: &str, ) -> Result<(u64, Vec, usize), Error> { + use std::collections::hash_map::DefaultHasher; + use std::hash::{Hash, Hasher}; + + let mut hasher = DefaultHasher::new(); + sql.hash(&mut hasher); + let key = hasher.finish(); + + // Check cache first + if let Some(metadata) = self.stmt_cache.get(&key) { + return Ok((key, metadata.columns.clone(), metadata.parameters)); + } + + // Create new prepared statement to get metadata let sql = sql.to_string(); - self.with_conn(move |conn| do_prepare(conn, sql.into())) - .await + let conn = Arc::clone(&self.conn); + + run_blocking(move || { + let conn_guard = conn + .lock() + .map_err(|_| Error::Protocol("Failed to lock connection".into()))?; + let mut prepared = conn_guard.prepare(&sql).map_err(Error::from)?; + let columns = collect_columns(&mut prepared); + let params = usize::from(prepared.num_params().unwrap_or(0)); + Ok::<_, Error>((columns, params)) + }) + .await + .map(|(columns, params)| { + // Cache the metadata + let metadata = crate::odbc::statement::OdbcStatementMetadata { + columns: columns.clone(), + parameters: params, + }; + self.stmt_cache.insert(key, metadata); + (key, columns, params) + }) + } + + pub(crate) async fn clear_cached_statements(&mut self) -> Result<(), Error> { + // Clear the statement metadata cache + self.stmt_cache.clear(); + Ok(()) + } + + pub async fn prepare(&mut self, sql: &str) -> Result, Error> { + let (_, columns, parameters) = self.prepare_metadata(sql).await?; + let metadata = OdbcStatementMetadata { + columns, + parameters, + }; + Ok(OdbcStatement { + sql: Cow::Owned(sql.to_string()), + metadata, + }) } } @@ -168,6 +259,10 @@ impl Connection for OdbcConnection { fn should_flush(&self) -> bool { false } + + fn clear_cached_statements(&mut self) -> BoxFuture<'_, Result<(), Error>> { + Box::pin(self.clear_cached_statements()) + } } // moved helpers to connection/inner.rs diff --git a/sqlx-core/src/odbc/connection/odbc_bridge.rs b/sqlx-core/src/odbc/connection/odbc_bridge.rs index 6dcf780871..f4788ce3c6 100644 --- a/sqlx-core/src/odbc/connection/odbc_bridge.rs +++ b/sqlx-core/src/odbc/connection/odbc_bridge.rs @@ -4,49 +4,45 @@ use crate::odbc::{ }; use either::Either; use flume::{SendError, Sender}; -use odbc_api::handles::StatementImpl; use odbc_api::{Cursor, CursorRow, IntoParameter, Nullable, Preallocated, ResultSetMetadata}; -use std::collections::HashMap; -#[derive(Debug)] -pub struct OdbcConn { - pub conn: odbc_api::Connection<'static>, - pub prepared_meta_cache: HashMap, usize)>, -} pub type ExecuteResult = Result, Error>; pub type ExecuteSender = Sender; -pub fn establish_connection(options: &crate::odbc::OdbcConnectOptions) -> Result { +pub fn establish_connection( + options: &crate::odbc::OdbcConnectOptions, +) -> Result, Error> { let env = odbc_api::environment().map_err(|e| Error::Configuration(e.to_string().into()))?; let conn = env .connect_with_connection_string(options.connection_string(), Default::default()) .map_err(|e| Error::Configuration(e.to_string().into()))?; - Ok(OdbcConn { - conn, - prepared_meta_cache: HashMap::new(), - }) + Ok(conn) } pub fn execute_sql( - inner: &mut OdbcConn, + conn: &mut odbc_api::Connection<'static>, sql: &str, args: Option, tx: &ExecuteSender, ) -> Result<(), Error> { let params = prepare_parameters(args); - let stmt = &mut inner.conn.preallocate().map_err(Error::from)?; - if let Some(mut cursor) = stmt.execute(sql, ¶ms[..])? { + let mut preallocated = conn.preallocate().map_err(Error::from)?; + + if let Some(mut cursor) = preallocated.execute(sql, ¶ms[..])? { handle_cursor(&mut cursor, tx); return Ok(()); } - let affected = extract_rows_affected(stmt); + let affected = extract_rows_affected(&mut preallocated); let _ = send_done(tx, affected); Ok(()) } -fn extract_rows_affected(stmt: &mut Preallocated>) -> u64 { +fn extract_rows_affected(stmt: &mut Preallocated) -> u64 +where + S: odbc_api::handles::AsStatementRef, +{ let count_opt = match stmt.row_count() { Ok(count_opt) => count_opt, Err(_) => { @@ -320,27 +316,3 @@ fn extract_binary( float: None, }) } - -pub fn do_prepare( - inner: &mut OdbcConn, - sql: Box, -) -> Result<(u64, Vec, usize), Error> { - use std::collections::hash_map::DefaultHasher; - use std::hash::{Hash, Hasher}; - - let mut hasher = DefaultHasher::new(); - sql.hash(&mut hasher); - let key = hasher.finish(); - - if let Some((cols, params)) = inner.prepared_meta_cache.get(&key) { - return Ok((key, cols.clone(), *params)); - } - - let mut prepared = inner.conn.prepare(&sql)?; - let columns = collect_columns(&mut prepared); - let params = usize::from(prepared.num_params().unwrap_or(0)); - inner - .prepared_meta_cache - .insert(key, (columns.clone(), params)); - Ok((key, columns, params)) -} diff --git a/sqlx-core/src/odbc/mod.rs b/sqlx-core/src/odbc/mod.rs index b4b4920857..8b793221b7 100644 --- a/sqlx-core/src/odbc/mod.rs +++ b/sqlx-core/src/odbc/mod.rs @@ -32,7 +32,7 @@ mod error; mod options; mod query_result; mod row; -mod statement; +pub mod statement; mod transaction; mod type_info; pub mod types; @@ -45,7 +45,7 @@ pub use database::Odbc; pub use options::OdbcConnectOptions; pub use query_result::OdbcQueryResult; pub use row::OdbcRow; -pub use statement::OdbcStatement; +pub use statement::{OdbcStatement, OdbcStatementMetadata}; pub use transaction::OdbcTransactionManager; pub use type_info::{DataTypeExt, OdbcTypeInfo}; pub use value::{OdbcValue, OdbcValueRef}; diff --git a/sqlx-core/src/odbc/statement.rs b/sqlx-core/src/odbc/statement.rs index beeef9807a..d7345deecc 100644 --- a/sqlx-core/src/odbc/statement.rs +++ b/sqlx-core/src/odbc/statement.rs @@ -8,8 +8,13 @@ use std::borrow::Cow; #[derive(Debug, Clone)] pub struct OdbcStatement<'q> { pub(crate) sql: Cow<'q, str>, - pub(crate) columns: Vec, - pub(crate) parameters: usize, + pub(crate) metadata: OdbcStatementMetadata, +} + +#[derive(Debug, Clone)] +pub struct OdbcStatementMetadata { + pub columns: Vec, + pub parameters: usize, } impl<'q> Statement<'q> for OdbcStatement<'q> { @@ -18,8 +23,7 @@ impl<'q> Statement<'q> for OdbcStatement<'q> { fn to_owned(&self) -> OdbcStatement<'static> { OdbcStatement { sql: Cow::Owned(self.sql.to_string()), - columns: self.columns.clone(), - parameters: self.parameters, + metadata: self.metadata.clone(), } } @@ -27,10 +31,10 @@ impl<'q> Statement<'q> for OdbcStatement<'q> { &self.sql } fn parameters(&self) -> Option> { - Some(Either::Right(self.parameters)) + Some(Either::Right(self.metadata.parameters)) } fn columns(&self) -> &[OdbcColumn] { - &self.columns + &self.metadata.columns } // ODBC arguments placeholder @@ -40,6 +44,7 @@ impl<'q> Statement<'q> for OdbcStatement<'q> { impl ColumnIndex> for &'_ str { fn index(&self, statement: &OdbcStatement<'_>) -> Result { statement + .metadata .columns .iter() .position(|c| c.name == *self) @@ -54,21 +59,22 @@ impl<'q> From> for crate::any::AnyStatement<'q> { // First build the columns and collect names let columns: Vec<_> = stmt + .metadata .columns - .into_iter() + .iter() .enumerate() .map(|(index, col)| { column_names.insert(crate::ext::ustr::UStr::new(&col.name), index); crate::any::AnyColumn { kind: crate::any::column::AnyColumnKind::Odbc(col.clone()), - type_info: crate::any::AnyTypeInfo::from(col.type_info), + type_info: crate::any::AnyTypeInfo::from(col.type_info.clone()), } }) .collect(); crate::any::AnyStatement { sql: stmt.sql, - parameters: Some(either::Either::Right(stmt.parameters)), + parameters: Some(either::Either::Right(stmt.metadata.parameters)), columns, column_names: std::sync::Arc::new(column_names), } diff --git a/tests/odbc/odbc.rs b/tests/odbc/odbc.rs index f92b73d881..c22dbc054e 100644 --- a/tests/odbc/odbc.rs +++ b/tests/odbc/odbc.rs @@ -149,7 +149,7 @@ async fn it_fetch_optional_some_and_none() -> anyhow::Result<()> { #[tokio::test] async fn it_can_prepare_then_query_without_params() -> anyhow::Result<()> { let mut conn = new::().await?; - let stmt = (&mut conn).prepare("SELECT 7 AS seven").await?; + let stmt = conn.prepare("SELECT 7 AS seven").await?; let row = stmt.query().fetch_one(&mut conn).await?; let col_name = row.column(0).name(); assert!( @@ -166,7 +166,7 @@ async fn it_can_prepare_then_query_without_params() -> anyhow::Result<()> { async fn it_can_prepare_then_query_with_params_integer_float_text() -> anyhow::Result<()> { let mut conn = new::().await?; - let stmt = (&mut conn).prepare("SELECT ? AS i, ? AS f, ? AS t").await?; + let stmt = conn.prepare("SELECT ? AS i, ? AS f, ? AS t").await?; let row = stmt .query() @@ -217,7 +217,7 @@ async fn it_can_bind_many_params_dynamically() -> anyhow::Result<()> { sql.push('?'); } - let stmt = (&mut conn).prepare(&sql).await?; + let stmt = conn.prepare(&sql).await?; let values: Vec = (1..=count as i32).collect(); let mut q = stmt.query(); @@ -237,7 +237,7 @@ async fn it_can_bind_many_params_dynamically() -> anyhow::Result<()> { async fn it_can_bind_heterogeneous_params() -> anyhow::Result<()> { let mut conn = new::().await?; - let stmt = (&mut conn).prepare("SELECT ?, ?, ?, ?, ?").await?; + let stmt = conn.prepare("SELECT ?, ?, ?, ?, ?").await?; let row = stmt .query() @@ -266,7 +266,7 @@ async fn it_can_bind_heterogeneous_params() -> anyhow::Result<()> { #[tokio::test] async fn it_binds_null_string_parameter() -> anyhow::Result<()> { let mut conn = new::().await?; - let stmt = (&mut conn).prepare("SELECT ?, ?").await?; + let stmt = conn.prepare("SELECT ?, ?").await?; let row = stmt .query() .bind("abc") @@ -424,7 +424,7 @@ async fn it_handles_large_strings() -> anyhow::Result<()> { // Test a moderately large string let large_string = "a".repeat(1000); - let stmt = (&mut conn).prepare("SELECT ? AS large_str").await?; + let stmt = conn.prepare("SELECT ? AS large_str").await?; let row = stmt .query() .bind(&large_string) @@ -443,7 +443,7 @@ async fn it_handles_binary_data() -> anyhow::Result<()> { // Test binary data - use UTF-8 safe bytes for PostgreSQL compatibility let binary_data = b"ABCDE"; - let stmt = (&mut conn).prepare("SELECT ? AS binary_data").await?; + let stmt = conn.prepare("SELECT ? AS binary_data").await?; let row = stmt .query() .bind(&binary_data[..]) @@ -459,7 +459,7 @@ async fn it_handles_binary_data() -> anyhow::Result<()> { async fn it_handles_mixed_null_and_values() -> anyhow::Result<()> { let mut conn = new::().await?; - let stmt = (&mut conn).prepare("SELECT ?, ?, ?, ?").await?; + let stmt = conn.prepare("SELECT ?, ?, ?, ?").await?; let row = stmt .query() .bind(42_i32) @@ -505,7 +505,7 @@ async fn it_handles_slice_types() -> anyhow::Result<()> { // Test slice types let test_data = b"Hello, ODBC!"; - let stmt = (&mut conn).prepare("SELECT ? AS slice_data").await?; + let stmt = conn.prepare("SELECT ? AS slice_data").await?; let row = stmt .query() .bind(&test_data[..]) @@ -528,7 +528,7 @@ async fn it_handles_uuid() -> anyhow::Result<()> { let uuid_str = test_uuid.to_string(); // Test UUID as string - let stmt = (&mut conn).prepare("SELECT ? AS uuid_data").await?; + let stmt = conn.prepare("SELECT ? AS uuid_data").await?; let row = stmt.query().bind(&uuid_str).fetch_one(&mut conn).await?; let result = row.try_get_raw(0)?.to_owned().decode::(); @@ -536,7 +536,7 @@ async fn it_handles_uuid() -> anyhow::Result<()> { // Test with a specific UUID string let specific_uuid_str = "550e8400-e29b-41d4-a716-446655440000"; - let stmt = (&mut conn).prepare("SELECT ? AS uuid_data").await?; + let stmt = conn.prepare("SELECT ? AS uuid_data").await?; let row = stmt .query() .bind(specific_uuid_str) @@ -563,7 +563,7 @@ async fn it_handles_json() -> anyhow::Result<()> { }); let json_str = test_json.to_string(); - let stmt = (&mut conn).prepare("SELECT ? AS json_data").await?; + let stmt = conn.prepare("SELECT ? AS json_data").await?; let row = stmt.query().bind(&json_str).fetch_one(&mut conn).await?; let result: Value = row.try_get_raw(0)?.to_owned().decode(); @@ -581,7 +581,7 @@ async fn it_handles_bigdecimal() -> anyhow::Result<()> { let test_decimal = BigDecimal::from_str("123.456789")?; let decimal_str = test_decimal.to_string(); - let stmt = (&mut conn).prepare("SELECT ? AS decimal_data").await?; + let stmt = conn.prepare("SELECT ? AS decimal_data").await?; let row = stmt.query().bind(&decimal_str).fetch_one(&mut conn).await?; let result = row.try_get_raw(0)?.to_owned().decode::(); @@ -598,7 +598,7 @@ async fn it_handles_rust_decimal() -> anyhow::Result<()> { let test_decimal = "123.456789".parse::()?; let decimal_str = test_decimal.to_string(); - let stmt = (&mut conn).prepare("SELECT ? AS decimal_data").await?; + let stmt = conn.prepare("SELECT ? AS decimal_data").await?; let row = stmt.query().bind(&decimal_str).fetch_one(&mut conn).await?; let result = row.try_get_raw(0)?.to_owned().decode::(); @@ -621,7 +621,7 @@ async fn it_handles_chrono_datetime() -> anyhow::Result<()> { let test_datetime = NaiveDateTime::new(test_date, test_time); // Test that we can encode chrono types (by storing them as strings) - let stmt = (&mut conn).prepare("SELECT ? AS date_data").await?; + let stmt = conn.prepare("SELECT ? AS date_data").await?; let row = stmt.query().bind(test_date).fetch_one(&mut conn).await?; // Decode as string and verify format @@ -629,14 +629,14 @@ async fn it_handles_chrono_datetime() -> anyhow::Result<()> { assert_eq!(result_str, "2023-12-25"); // Test time encoding - let stmt = (&mut conn).prepare("SELECT ? AS time_data").await?; + let stmt = conn.prepare("SELECT ? AS time_data").await?; let row = stmt.query().bind(test_time).fetch_one(&mut conn).await?; let result_str = row.try_get_raw(0)?.to_owned().decode::(); assert_eq!(result_str, "14:30:00"); // Test datetime encoding - let stmt = (&mut conn).prepare("SELECT ? AS datetime_data").await?; + let stmt = conn.prepare("SELECT ? AS datetime_data").await?; let row = stmt .query() .bind(test_datetime) @@ -764,7 +764,7 @@ async fn it_handles_prepare_statement_errors() -> anyhow::Result<()> { // So we test that execution fails even if preparation succeeds // Test executing prepared invalid SQL - if let Ok(stmt) = (&mut conn).prepare("INVALID PREPARE STATEMENT").await { + if let Ok(stmt) = conn.prepare("INVALID PREPARE STATEMENT").await { let result = stmt.query().fetch_one(&mut conn).await; let err = result.expect_err("should be an error"); assert!( @@ -775,7 +775,7 @@ async fn it_handles_prepare_statement_errors() -> anyhow::Result<()> { } // Test executing prepared SQL with syntax errors - match (&mut conn) + match conn .prepare("SELECT idonotexist FROM idonotexist WHERE idonotexist") .await { @@ -811,9 +811,7 @@ async fn it_handles_parameter_binding_errors() -> anyhow::Result<()> { let mut conn = new::().await?; // Test with completely missing parameters - this should more reliably fail - let stmt = (&mut conn) - .prepare("SELECT ? AS param1, ? AS param2") - .await?; + let stmt = conn.prepare("SELECT ? AS param1, ? AS param2").await?; // Test with no parameters when some are expected let result = stmt.query().fetch_one(&mut conn).await; @@ -824,7 +822,7 @@ async fn it_handles_parameter_binding_errors() -> anyhow::Result<()> { // Test that we can handle parameter binding gracefully // Even if the driver is permissive, the system should be robust - let stmt2 = (&mut conn).prepare("SELECT ? AS single_param").await?; + let stmt2 = conn.prepare("SELECT ? AS single_param").await?; // Bind correct number of parameters - this should work let result = stmt2.query().bind(42i32).fetch_one(&mut conn).await; @@ -842,7 +840,7 @@ async fn it_handles_parameter_execution_errors() -> anyhow::Result<()> { let mut conn = new::().await?; // Test parameter binding with incompatible operations that should fail at execution - let stmt = (&mut conn).prepare("SELECT ? / 0 AS div_by_zero").await?; + let stmt = conn.prepare("SELECT ? / 0 AS div_by_zero").await?; // This should execute but may produce a runtime error (division by zero) let result = stmt.query().bind(42i32).fetch_one(&mut conn).await; @@ -850,7 +848,7 @@ async fn it_handles_parameter_execution_errors() -> anyhow::Result<()> { let _ = result; // Test with a parameter in an invalid context that should fail - if let Ok(stmt) = (&mut conn).prepare("SELECT * FROM ?").await { + if let Ok(stmt) = conn.prepare("SELECT * FROM ?").await { // Using parameter as table name should fail at execution let result = stmt .query() @@ -1019,7 +1017,7 @@ async fn it_handles_prepared_statement_with_wrong_parameters() -> anyhow::Result let mut conn = new::().await?; // Prepare a statement expecting specific parameter types - let stmt = (&mut conn).prepare("SELECT ? + ? AS sum").await?; + let stmt = conn.prepare("SELECT ? + ? AS sum").await?; // Test binding incompatible types (if the database is strict about types) // Some databases/drivers are permissive, others are strict From 6148a217c4e6830878ef4ca3563264dedb6f206e Mon Sep 17 00:00:00 2001 From: lovasoa Date: Thu, 25 Sep 2025 00:44:16 +0200 Subject: [PATCH 10/15] refactor(odbc): implement with_conn helper for OdbcConnection methods This commit introduces the `with_conn` method to streamline connection handling in various OdbcConnection methods, including `dbms_name`, `ping_blocking`, `begin_blocking`, `commit_blocking`, `rollback_blocking`, `execute_stream`, and `prepare_metadata`. This refactor enhances code readability and reduces duplication by encapsulating the connection locking logic. --- sqlx-core/src/odbc/connection/mod.rs | 88 ++++++++++++---------------- 1 file changed, 37 insertions(+), 51 deletions(-) diff --git a/sqlx-core/src/odbc/connection/mod.rs b/sqlx-core/src/odbc/connection/mod.rs index 40571ce582..b1d3f275a4 100644 --- a/sqlx-core/src/odbc/connection/mod.rs +++ b/sqlx-core/src/odbc/connection/mod.rs @@ -57,6 +57,22 @@ pub struct OdbcConnection { } impl OdbcConnection { + pub(crate) async fn with_conn(&mut self, operation: S, f: F) -> Result + where + R: Send + 'static, + F: FnOnce(&mut odbc_api::Connection<'static>) -> Result + Send + 'static, + S: std::fmt::Display + Send + 'static, + { + let conn = Arc::clone(&self.conn); + run_blocking(move || { + let mut conn_guard = conn.lock().map_err(|_| { + Error::Protocol(format!("ODBC {}: failed to lock connection", operation)) + })?; + f(&mut conn_guard) + }) + .await + } + pub(crate) async fn establish(options: &OdbcConnectOptions) -> Result { let shared_conn = run_blocking({ let options = options.clone(); @@ -77,63 +93,42 @@ impl OdbcConnection { /// Returns the name of the actual Database Management System (DBMS) this /// connection is talking to as reported by the ODBC driver. pub async fn dbms_name(&mut self) -> Result { - let conn = Arc::clone(&self.conn); - run_blocking(move || { - let conn_guard = conn - .lock() - .map_err(|_| Error::Protocol("Failed to lock connection".into()))?; - conn_guard - .database_management_system_name() - .map_err(Error::from) + self.with_conn("dbms_name", move |conn| { + Ok(conn.database_management_system_name()?) }) .await } pub(crate) async fn ping_blocking(&mut self) -> Result<(), Error> { - let conn = Arc::clone(&self.conn); - run_blocking(move || { - let conn_guard = conn - .lock() - .map_err(|_| Error::Protocol("Failed to lock connection".into()))?; - conn_guard - .execute("SELECT 1", (), None) - .map_err(Error::from) - .map(|_| ()) + self.with_conn("ping", move |conn| { + conn.execute("SELECT 1", (), None)?; + Ok(()) }) .await } pub(crate) async fn begin_blocking(&mut self) -> Result<(), Error> { - let conn = Arc::clone(&self.conn); - run_blocking(move || { - let conn_guard = conn - .lock() - .map_err(|_| Error::Protocol("Failed to lock connection".into()))?; - conn_guard.set_autocommit(false).map_err(Error::from) + self.with_conn("begin", move |conn| { + conn.set_autocommit(false)?; + Ok(()) }) .await } pub(crate) async fn commit_blocking(&mut self) -> Result<(), Error> { - let conn = Arc::clone(&self.conn); - run_blocking(move || { - let conn_guard = conn - .lock() - .map_err(|_| Error::Protocol("Failed to lock connection".into()))?; - conn_guard.commit()?; - conn_guard.set_autocommit(true).map_err(Error::from) + self.with_conn("commit", move |conn| { + conn.commit()?; + conn.set_autocommit(true)?; + Ok(()) }) .await } pub(crate) async fn rollback_blocking(&mut self) -> Result<(), Error> { - let conn = Arc::clone(&self.conn); - run_blocking(move || { - let conn_guard = conn - .lock() - .map_err(|_| Error::Protocol("Failed to lock connection".into()))?; - conn_guard.rollback()?; - conn_guard.set_autocommit(true).map_err(Error::from) + self.with_conn("rollback", move |conn| { + conn.rollback()?; + conn.set_autocommit(true)?; + Ok(()) }) .await } @@ -146,13 +141,9 @@ impl OdbcConnection { let (tx, rx) = flume::bounded(64); let sql = sql.to_string(); let args_move = args; - let conn = Arc::clone(&self.conn); - run_blocking(move || { - let mut conn_guard = conn - .lock() - .map_err(|_| Error::Protocol("Failed to lock connection".into()))?; - if let Err(e) = execute_sql(&mut conn_guard, &sql, args_move, &tx) { + self.with_conn("execute_stream", move |conn| { + if let Err(e) = execute_sql(conn, &sql, args_move, &tx) { let _ = tx.send(Err(e)); } Ok(()) @@ -180,16 +171,11 @@ impl OdbcConnection { // Create new prepared statement to get metadata let sql = sql.to_string(); - let conn = Arc::clone(&self.conn); - - run_blocking(move || { - let conn_guard = conn - .lock() - .map_err(|_| Error::Protocol("Failed to lock connection".into()))?; - let mut prepared = conn_guard.prepare(&sql).map_err(Error::from)?; + self.with_conn("prepare_metadata", move |conn| { + let mut prepared = conn.prepare(&sql)?; let columns = collect_columns(&mut prepared); let params = usize::from(prepared.num_params().unwrap_or(0)); - Ok::<_, Error>((columns, params)) + Ok((columns, params)) }) .await .map(|(columns, params)| { From 0e780921d7b9a4e40214bcd2bc57abd75fb10a15 Mon Sep 17 00:00:00 2001 From: lovasoa Date: Thu, 25 Sep 2025 03:03:30 +0200 Subject: [PATCH 11/15] refactor(odbc): enhance OdbcConnection and SQL execution handling This commit refactors the OdbcConnection structure to utilize a new type for prepared statements, improving type safety and clarity. It also modifies the execute_sql function to handle both prepared and non-prepared SQL statements through a new MaybePrepared enum, streamlining execution logic. Additionally, the prepare method is updated to cache prepared statements more effectively, enhancing performance. --- sqlx-core/src/odbc/connection/mod.rs | 105 ++++++++----------- sqlx-core/src/odbc/connection/odbc_bridge.rs | 56 +++++----- 2 files changed, 78 insertions(+), 83 deletions(-) diff --git a/sqlx-core/src/odbc/connection/mod.rs b/sqlx-core/src/odbc/connection/mod.rs index b1d3f275a4..78f6ea75b0 100644 --- a/sqlx-core/src/odbc/connection/mod.rs +++ b/sqlx-core/src/odbc/connection/mod.rs @@ -9,27 +9,27 @@ use either::Either; mod odbc_bridge; use futures_core::future::BoxFuture; use futures_util::future; +use odbc_api::ConnectionTransitions; use odbc_bridge::{establish_connection, execute_sql}; // no direct spawn_blocking here; use run_blocking helper use crate::odbc::{OdbcStatement, OdbcStatementMetadata}; -use odbc_api::ResultSetMetadata; +use odbc_api::{handles::StatementConnection, Prepared, ResultSetMetadata, SharedConnection}; use std::borrow::Cow; use std::collections::HashMap; use std::sync::Arc; -fn collect_columns( - prepared: &mut odbc_api::Prepared>, -) -> Vec { +mod executor; + +type PreparedStatement = Prepared>>; + +fn collect_columns(prepared: &mut PreparedStatement) -> Vec { let count = prepared.num_result_cols().unwrap_or(0); (1..=count) .map(|i| create_column(prepared, i as u16)) .collect() } -fn create_column( - stmt: &mut odbc_api::Prepared>, - index: u16, -) -> OdbcColumn { +fn create_column(stmt: &mut PreparedStatement, index: u16) -> OdbcColumn { let mut cd = odbc_api::ColumnDescription::default(); let _ = stmt.describe_col(index, &mut cd); @@ -44,16 +44,21 @@ fn decode_column_name(name_bytes: Vec, index: u16) -> String { String::from_utf8(name_bytes).unwrap_or_else(|_| format!("col{}", index - 1)) } -mod executor; - /// A connection to an ODBC-accessible database. /// /// ODBC uses a blocking C API, so we offload blocking calls to the runtime's blocking /// thread-pool via `spawn_blocking` and synchronize access with a mutex. -#[derive(Debug)] pub struct OdbcConnection { - pub(crate) conn: odbc_api::SharedConnection<'static>, - pub(crate) stmt_cache: HashMap, + pub(crate) conn: SharedConnection<'static>, + pub(crate) stmt_cache: HashMap, PreparedStatement>, +} + +impl std::fmt::Debug for OdbcConnection { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("OdbcConnection") + .field("conn", &self.conn) + .finish() + } } impl OdbcConnection { @@ -139,11 +144,16 @@ impl OdbcConnection { args: Option, ) -> Result, Error>>, Error> { let (tx, rx) = flume::bounded(64); - let sql = sql.to_string(); - let args_move = args; + + // !!TODO!!!: Put back the prepared statement after usage + let maybe_prepared = if let Some(prepared) = self.stmt_cache.remove(sql) { + MaybePrepared::Prepared(prepared) + } else { + MaybePrepared::NotPrepared(sql.to_string()) + }; self.with_conn("execute_stream", move |conn| { - if let Err(e) = execute_sql(conn, &sql, args_move, &tx) { + if let Err(e) = execute_sql(conn, maybe_prepared, args, &tx) { let _ = tx.send(Err(e)); } Ok(()) @@ -153,61 +163,38 @@ impl OdbcConnection { Ok(rx) } - pub(crate) async fn prepare_metadata( - &mut self, - sql: &str, - ) -> Result<(u64, Vec, usize), Error> { - use std::collections::hash_map::DefaultHasher; - use std::hash::{Hash, Hasher}; - - let mut hasher = DefaultHasher::new(); - sql.hash(&mut hasher); - let key = hasher.finish(); - - // Check cache first - if let Some(metadata) = self.stmt_cache.get(&key) { - return Ok((key, metadata.columns.clone(), metadata.parameters)); - } - - // Create new prepared statement to get metadata - let sql = sql.to_string(); - self.with_conn("prepare_metadata", move |conn| { - let mut prepared = conn.prepare(&sql)?; - let columns = collect_columns(&mut prepared); - let params = usize::from(prepared.num_params().unwrap_or(0)); - Ok((columns, params)) - }) - .await - .map(|(columns, params)| { - // Cache the metadata - let metadata = crate::odbc::statement::OdbcStatementMetadata { - columns: columns.clone(), - parameters: params, - }; - self.stmt_cache.insert(key, metadata); - (key, columns, params) - }) - } - pub(crate) async fn clear_cached_statements(&mut self) -> Result<(), Error> { // Clear the statement metadata cache self.stmt_cache.clear(); Ok(()) } - pub async fn prepare(&mut self, sql: &str) -> Result, Error> { - let (_, columns, parameters) = self.prepare_metadata(sql).await?; - let metadata = OdbcStatementMetadata { - columns, - parameters, - }; + pub async fn prepare<'a>(&mut self, sql: &'a str) -> Result, Error> { + let conn = Arc::clone(&self.conn); + let sql_arc = Arc::from(sql.to_string()); + let sql_clone = Arc::clone(&sql_arc); + let (prepared, metadata) = run_blocking(move || { + let mut prepared = conn.into_prepared(&sql_clone)?; + let metadata = OdbcStatementMetadata { + columns: collect_columns(&mut prepared), + parameters: usize::from(prepared.num_params().unwrap_or(0)), + }; + Ok((prepared, metadata)) + }) + .await?; + self.stmt_cache.insert(Arc::clone(&sql_arc), prepared); Ok(OdbcStatement { - sql: Cow::Owned(sql.to_string()), + sql: Cow::Borrowed(sql), metadata, }) } } +pub(crate) enum MaybePrepared { + Prepared(PreparedStatement), + NotPrepared(String), +} + impl Connection for OdbcConnection { type Database = Odbc; diff --git a/sqlx-core/src/odbc/connection/odbc_bridge.rs b/sqlx-core/src/odbc/connection/odbc_bridge.rs index f4788ce3c6..31a3e9e066 100644 --- a/sqlx-core/src/odbc/connection/odbc_bridge.rs +++ b/sqlx-core/src/odbc/connection/odbc_bridge.rs @@ -1,10 +1,12 @@ use crate::error::Error; use crate::odbc::{ - OdbcArgumentValue, OdbcArguments, OdbcColumn, OdbcQueryResult, OdbcRow, OdbcTypeInfo, + connection::MaybePrepared, OdbcArgumentValue, OdbcArguments, OdbcColumn, OdbcQueryResult, + OdbcRow, OdbcTypeInfo, }; use either::Either; use flume::{SendError, Sender}; -use odbc_api::{Cursor, CursorRow, IntoParameter, Nullable, Preallocated, ResultSetMetadata}; +use odbc_api::handles::{AsStatementRef, Statement}; +use odbc_api::{Cursor, CursorRow, IntoParameter, Nullable, ResultSetMetadata}; pub type ExecuteResult = Result, Error>; pub type ExecuteSender = Sender; @@ -21,43 +23,49 @@ pub fn establish_connection( pub fn execute_sql( conn: &mut odbc_api::Connection<'static>, - sql: &str, + maybe_prepared: MaybePrepared, args: Option, tx: &ExecuteSender, ) -> Result<(), Error> { let params = prepare_parameters(args); - let mut preallocated = conn.preallocate().map_err(Error::from)?; - - if let Some(mut cursor) = preallocated.execute(sql, ¶ms[..])? { - handle_cursor(&mut cursor, tx); - return Ok(()); - } + let affected = match maybe_prepared { + MaybePrepared::Prepared(mut prepared) => { + if let Some(mut cursor) = prepared.execute(¶ms[..])? { + handle_cursor(&mut cursor, tx); + } + extract_rows_affected(&mut prepared) + } + MaybePrepared::NotPrepared(sql) => { + let mut preallocated = conn.preallocate().map_err(Error::from)?; + if let Some(mut cursor) = preallocated.execute(&sql, ¶ms[..])? { + handle_cursor(&mut cursor, tx); + } + extract_rows_affected(&mut preallocated) + } + }; - let affected = extract_rows_affected(&mut preallocated); let _ = send_done(tx, affected); Ok(()) } -fn extract_rows_affected(stmt: &mut Preallocated) -> u64 -where - S: odbc_api::handles::AsStatementRef, -{ - let count_opt = match stmt.row_count() { - Ok(count_opt) => count_opt, - Err(_) => { +fn extract_rows_affected(stmt: &mut S) -> u64 { + let mut stmt_ref = stmt.as_stmt_ref(); + let count = match stmt_ref.row_count().into_result(&stmt_ref) { + Ok(count) => count, + Err(e) => { + log::warn!("Failed to get row count: {}", e); return 0; } }; - let count = match count_opt { - Some(count) => count, - None => { - return 0; + match u64::try_from(count) { + Ok(count) => count, + Err(e) => { + log::warn!("Failed to get row count: {}", e); + 0 } - }; - - u64::try_from(count).unwrap_or_default() + } } fn prepare_parameters( From b7277b037982a89cf64427775f749ab442c56111 Mon Sep 17 00:00:00 2001 From: lovasoa Date: Thu, 25 Sep 2025 10:13:49 +0200 Subject: [PATCH 12/15] fix(odbc): reuse prepared statements across queries --- sqlx-core/src/odbc/connection/mod.rs | 14 ++++++++------ sqlx-core/src/odbc/connection/odbc_bridge.rs | 5 +++-- 2 files changed, 11 insertions(+), 8 deletions(-) diff --git a/sqlx-core/src/odbc/connection/mod.rs b/sqlx-core/src/odbc/connection/mod.rs index 78f6ea75b0..50e26c41c9 100644 --- a/sqlx-core/src/odbc/connection/mod.rs +++ b/sqlx-core/src/odbc/connection/mod.rs @@ -16,11 +16,12 @@ use crate::odbc::{OdbcStatement, OdbcStatementMetadata}; use odbc_api::{handles::StatementConnection, Prepared, ResultSetMetadata, SharedConnection}; use std::borrow::Cow; use std::collections::HashMap; -use std::sync::Arc; +use std::sync::{Arc, Mutex}; mod executor; type PreparedStatement = Prepared>>; +type SharedPreparedStatement = Arc>; fn collect_columns(prepared: &mut PreparedStatement) -> Vec { let count = prepared.num_result_cols().unwrap_or(0); @@ -50,7 +51,7 @@ fn decode_column_name(name_bytes: Vec, index: u16) -> String { /// thread-pool via `spawn_blocking` and synchronize access with a mutex. pub struct OdbcConnection { pub(crate) conn: SharedConnection<'static>, - pub(crate) stmt_cache: HashMap, PreparedStatement>, + pub(crate) stmt_cache: HashMap, SharedPreparedStatement>, } impl std::fmt::Debug for OdbcConnection { @@ -146,8 +147,8 @@ impl OdbcConnection { let (tx, rx) = flume::bounded(64); // !!TODO!!!: Put back the prepared statement after usage - let maybe_prepared = if let Some(prepared) = self.stmt_cache.remove(sql) { - MaybePrepared::Prepared(prepared) + let maybe_prepared = if let Some(prepared) = self.stmt_cache.get(sql) { + MaybePrepared::Prepared(Arc::clone(prepared)) } else { MaybePrepared::NotPrepared(sql.to_string()) }; @@ -182,7 +183,8 @@ impl OdbcConnection { Ok((prepared, metadata)) }) .await?; - self.stmt_cache.insert(Arc::clone(&sql_arc), prepared); + self.stmt_cache + .insert(Arc::clone(&sql_arc), Arc::new(Mutex::new(prepared))); Ok(OdbcStatement { sql: Cow::Borrowed(sql), metadata, @@ -191,7 +193,7 @@ impl OdbcConnection { } pub(crate) enum MaybePrepared { - Prepared(PreparedStatement), + Prepared(SharedPreparedStatement), NotPrepared(String), } diff --git a/sqlx-core/src/odbc/connection/odbc_bridge.rs b/sqlx-core/src/odbc/connection/odbc_bridge.rs index 31a3e9e066..d0e20262e3 100644 --- a/sqlx-core/src/odbc/connection/odbc_bridge.rs +++ b/sqlx-core/src/odbc/connection/odbc_bridge.rs @@ -30,11 +30,12 @@ pub fn execute_sql( let params = prepare_parameters(args); let affected = match maybe_prepared { - MaybePrepared::Prepared(mut prepared) => { + MaybePrepared::Prepared(prepared) => { + let mut prepared = prepared.lock().expect("prepared statement lock"); if let Some(mut cursor) = prepared.execute(¶ms[..])? { handle_cursor(&mut cursor, tx); } - extract_rows_affected(&mut prepared) + extract_rows_affected(&mut *prepared) } MaybePrepared::NotPrepared(sql) => { let mut preallocated = conn.preallocate().map_err(Error::from)?; From e6ffb26f8808c7bd63f9c22ec86ac5a34016f018 Mon Sep 17 00:00:00 2001 From: lovasoa Date: Thu, 25 Sep 2025 11:13:33 +0200 Subject: [PATCH 13/15] refactor(odbc): remove run_blocking helper and directly use spawn_blocking This commit eliminates the run_blocking function from the ODBC module, replacing its usage with the spawn_blocking function directly in the OdbcConnection implementation. This change simplifies the code and enhances clarity in handling blocking tasks. --- sqlx-core/src/odbc/blocking.rs | 19 ------------------- sqlx-core/src/odbc/connection/mod.rs | 15 +++++++-------- sqlx-core/src/odbc/mod.rs | 1 - sqlx-rt/src/rt_tokio.rs | 15 +++++++++++++-- 4 files changed, 20 insertions(+), 30 deletions(-) delete mode 100644 sqlx-core/src/odbc/blocking.rs diff --git a/sqlx-core/src/odbc/blocking.rs b/sqlx-core/src/odbc/blocking.rs deleted file mode 100644 index b25b657139..0000000000 --- a/sqlx-core/src/odbc/blocking.rs +++ /dev/null @@ -1,19 +0,0 @@ -use crate::error::Error; -use sqlx_rt::spawn_blocking; - -pub async fn run_blocking(f: F) -> Result -where - R: Send + 'static, - F: FnOnce() -> Result + Send + 'static, -{ - #[cfg(feature = "_rt-tokio")] - { - let join_result = spawn_blocking(f).await.map_err(|_| Error::WorkerCrashed)?; - join_result - } - - #[cfg(feature = "_rt-async-std")] - { - spawn_blocking(f).await - } -} diff --git a/sqlx-core/src/odbc/connection/mod.rs b/sqlx-core/src/odbc/connection/mod.rs index 50e26c41c9..9572e611f0 100644 --- a/sqlx-core/src/odbc/connection/mod.rs +++ b/sqlx-core/src/odbc/connection/mod.rs @@ -1,19 +1,18 @@ use crate::connection::Connection; use crate::error::Error; -use crate::odbc::blocking::run_blocking; use crate::odbc::{ Odbc, OdbcArguments, OdbcColumn, OdbcConnectOptions, OdbcQueryResult, OdbcRow, OdbcTypeInfo, }; use crate::transaction::Transaction; use either::Either; +use sqlx_rt::spawn_blocking; mod odbc_bridge; +use crate::odbc::{OdbcStatement, OdbcStatementMetadata}; use futures_core::future::BoxFuture; use futures_util::future; use odbc_api::ConnectionTransitions; -use odbc_bridge::{establish_connection, execute_sql}; -// no direct spawn_blocking here; use run_blocking helper -use crate::odbc::{OdbcStatement, OdbcStatementMetadata}; use odbc_api::{handles::StatementConnection, Prepared, ResultSetMetadata, SharedConnection}; +use odbc_bridge::{establish_connection, execute_sql}; use std::borrow::Cow; use std::collections::HashMap; use std::sync::{Arc, Mutex}; @@ -70,7 +69,7 @@ impl OdbcConnection { S: std::fmt::Display + Send + 'static, { let conn = Arc::clone(&self.conn); - run_blocking(move || { + spawn_blocking(move || { let mut conn_guard = conn.lock().map_err(|_| { Error::Protocol(format!("ODBC {}: failed to lock connection", operation)) })?; @@ -80,7 +79,7 @@ impl OdbcConnection { } pub(crate) async fn establish(options: &OdbcConnectOptions) -> Result { - let shared_conn = run_blocking({ + let shared_conn = spawn_blocking({ let options = options.clone(); move || { let conn = establish_connection(&options)?; @@ -174,13 +173,13 @@ impl OdbcConnection { let conn = Arc::clone(&self.conn); let sql_arc = Arc::from(sql.to_string()); let sql_clone = Arc::clone(&sql_arc); - let (prepared, metadata) = run_blocking(move || { + let (prepared, metadata) = spawn_blocking(move || { let mut prepared = conn.into_prepared(&sql_clone)?; let metadata = OdbcStatementMetadata { columns: collect_columns(&mut prepared), parameters: usize::from(prepared.num_params().unwrap_or(0)), }; - Ok((prepared, metadata)) + Ok::<_, Error>((prepared, metadata)) }) .await?; self.stmt_cache diff --git a/sqlx-core/src/odbc/mod.rs b/sqlx-core/src/odbc/mod.rs index 8b793221b7..492cc370b6 100644 --- a/sqlx-core/src/odbc/mod.rs +++ b/sqlx-core/src/odbc/mod.rs @@ -24,7 +24,6 @@ use crate::executor::Executor; mod arguments; -mod blocking; mod column; mod connection; mod database; diff --git a/sqlx-rt/src/rt_tokio.rs b/sqlx-rt/src/rt_tokio.rs index 855ff6269f..72b2cbb27b 100644 --- a/sqlx-rt/src/rt_tokio.rs +++ b/sqlx-rt/src/rt_tokio.rs @@ -1,7 +1,7 @@ pub use tokio::{ self, fs, io::AsyncRead, io::AsyncReadExt, io::AsyncWrite, io::AsyncWriteExt, io::ReadBuf, - net::TcpStream, runtime::Handle, sync::Mutex as AsyncMutex, task::spawn, task::spawn_blocking, - task::yield_now, time::sleep, time::timeout, + net::TcpStream, runtime::Handle, sync::Mutex as AsyncMutex, task::spawn, task::yield_now, + time::sleep, time::timeout, }; #[cfg(unix)] @@ -45,3 +45,14 @@ pub fn test_block_on(future: F) -> F::Output { .expect("failed to initialize Tokio test runtime") .block_on(future) } + +/// Spawn a blocking task. Panics if the task panics. +pub async fn spawn_blocking(f: F) -> R +where + F: FnOnce() -> R + Send + 'static, + R: Send + 'static, +{ + tokio::task::spawn_blocking(f) + .await + .expect("blocking task panicked") +} From 19e66fc4594b91bbe1367fba561f66805f76e91c Mon Sep 17 00:00:00 2001 From: lovasoa Date: Thu, 25 Sep 2025 11:24:30 +0200 Subject: [PATCH 14/15] refactor(odbc): simplify execute_stream implementation This commit refactors the execute_stream method in OdbcConnection to directly return the receiver from the execute_stream function, eliminating unnecessary complexity. Additionally, the method signature is updated to reflect the change in return type, enhancing clarity in the codebase. --- sqlx-core/src/odbc/connection/executor.rs | 9 +-------- sqlx-core/src/odbc/connection/mod.rs | 18 +++++++++--------- 2 files changed, 10 insertions(+), 17 deletions(-) diff --git a/sqlx-core/src/odbc/connection/executor.rs b/sqlx-core/src/odbc/connection/executor.rs index 48d1d606ef..21ab6c01f2 100644 --- a/sqlx-core/src/odbc/connection/executor.rs +++ b/sqlx-core/src/odbc/connection/executor.rs @@ -20,15 +20,8 @@ impl<'c> Executor<'c> for &'c mut OdbcConnection { 'c: 'e, E: Execute<'q, Self::Database> + 'q, { - let sql = query.sql().to_string(); let args = query.take_arguments(); - Box::pin(try_stream! { - let rx = self.execute_stream(&sql, args).await?; - while let Ok(item) = rx.recv_async().await { - r#yield!(item?); - } - Ok(()) - }) + Box::pin(self.execute_stream(query.sql(), args).into_stream()) } fn fetch_optional<'e, 'q: 'e, E>( diff --git a/sqlx-core/src/odbc/connection/mod.rs b/sqlx-core/src/odbc/connection/mod.rs index 9572e611f0..84d5572c84 100644 --- a/sqlx-core/src/odbc/connection/mod.rs +++ b/sqlx-core/src/odbc/connection/mod.rs @@ -138,29 +138,29 @@ impl OdbcConnection { .await } - pub(crate) async fn execute_stream( + /// Launches a background task to execute the SQL statement and send the results to the returned channel. + pub(crate) fn execute_stream( &mut self, sql: &str, args: Option, - ) -> Result, Error>>, Error> { + ) -> flume::Receiver, Error>> { let (tx, rx) = flume::bounded(64); - // !!TODO!!!: Put back the prepared statement after usage let maybe_prepared = if let Some(prepared) = self.stmt_cache.get(sql) { MaybePrepared::Prepared(Arc::clone(prepared)) } else { MaybePrepared::NotPrepared(sql.to_string()) }; - self.with_conn("execute_stream", move |conn| { - if let Err(e) = execute_sql(conn, maybe_prepared, args, &tx) { + let conn = Arc::clone(&self.conn); + sqlx_rt::spawn(sqlx_rt::spawn_blocking(move || { + let mut conn = conn.lock().expect("failed to lock connection"); + if let Err(e) = execute_sql(&mut conn, maybe_prepared, args, &tx) { let _ = tx.send(Err(e)); } - Ok(()) - }) - .await?; + })); - Ok(rx) + rx } pub(crate) async fn clear_cached_statements(&mut self) -> Result<(), Error> { From 5922a7cde10bad1607767b3c4b196f2207058f48 Mon Sep 17 00:00:00 2001 From: lovasoa Date: Thu, 25 Sep 2025 11:35:16 +0200 Subject: [PATCH 15/15] refactor(odbc): streamline fetch_many implementation in Executor This commit updates the fetch_many method in the Executor implementation for OdbcConnection to utilize into_future and then, simplifying the handling of query results and enhancing code clarity. --- sqlx-core/src/odbc/connection/executor.rs | 17 +++++++---------- 1 file changed, 7 insertions(+), 10 deletions(-) diff --git a/sqlx-core/src/odbc/connection/executor.rs b/sqlx-core/src/odbc/connection/executor.rs index 21ab6c01f2..b62ebe1531 100644 --- a/sqlx-core/src/odbc/connection/executor.rs +++ b/sqlx-core/src/odbc/connection/executor.rs @@ -5,7 +5,7 @@ use crate::odbc::{Odbc, OdbcConnection, OdbcQueryResult, OdbcRow, OdbcStatement, use either::Either; use futures_core::future::BoxFuture; use futures_core::stream::BoxStream; -use futures_util::TryStreamExt; +use futures_util::{future, FutureExt, StreamExt}; // run method removed; fetch_many implements streaming directly @@ -32,15 +32,12 @@ impl<'c> Executor<'c> for &'c mut OdbcConnection { 'c: 'e, E: Execute<'q, Self::Database> + 'q, { - let mut s = self.fetch_many(query); - Box::pin(async move { - while let Some(v) = s.try_next().await? { - if let Either::Right(r) = v { - return Ok(Some(r)); - } - } - Ok(None) - }) + Box::pin(self.fetch_many(query).into_future().then(|(v, _)| match v { + Some(Ok(Either::Right(r))) => future::ok(Some(r)), + Some(Ok(Either::Left(_))) => future::ok(None), + Some(Err(e)) => future::err(e), + None => future::ok(None), + })) } fn prepare_with<'e, 'q: 'e>(