From 905a38122b037979df81e27135b362435fb808f1 Mon Sep 17 00:00:00 2001 From: Peter Neumark Date: Thu, 12 Jan 2023 22:51:21 +0100 Subject: [PATCH 1/3] Strip quotes from schemas and tables prior to writing them to catalog --- src/repository/default.rs | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/src/repository/default.rs b/src/repository/default.rs index fc34552c..f48e32a9 100644 --- a/src/repository/default.rs +++ b/src/repository/default.rs @@ -164,8 +164,9 @@ impl Repository for $repo { } async fn create_database(&self, database_name: &str) -> Result { + let unqoted_name = database_name.trim_matches('"'); let id = sqlx::query(r#"INSERT INTO database (name) VALUES ($1) RETURNING (id)"#) - .bind(database_name) + .bind(unqoted_name) .fetch_one(&self.executor) .await.map_err($repo::interpret_error)? .try_get("id").map_err($repo::interpret_error)?; @@ -212,9 +213,10 @@ impl Repository for $repo { database_id: DatabaseId, collection_name: &str, ) -> Result { + let unquoted_name = collection_name.trim_matches('"'); let id = sqlx::query( r#"INSERT INTO "collection" (database_id, name) VALUES ($1, $2) RETURNING (id)"#, - ).bind(database_id).bind(collection_name) + ).bind(database_id).bind(unquoted_name) .fetch_one(&self.executor) .await.map_err($repo::interpret_error)? .try_get("id").map_err($repo::interpret_error)?; @@ -229,11 +231,12 @@ impl Repository for $repo { schema: &Schema, ) -> Result<(TableId, TableVersionId), Error> { // Create new (empty) table + let unquoted_name = table_name.trim_matches('"'); let new_table_id: i64 = sqlx::query( r#"INSERT INTO "table" (collection_id, name) VALUES ($1, $2) RETURNING (id)"#, ) .bind(collection_id) - .bind(table_name) + .bind(unquoted_name) .fetch_one(&self.executor) .await.map_err($repo::interpret_error)? .try_get("id").map_err($repo::interpret_error)?; @@ -515,6 +518,7 @@ impl Repository for $repo { details: &CreateFunctionDetails, ) -> Result { let input_types = serde_json::to_string(&details.input_types).expect("Couldn't serialize input types!"); + let unquoted_name = function_name.trim_matches('"'); let new_function_id: i64 = sqlx::query( r#" @@ -522,7 +526,7 @@ impl Repository for $repo { VALUES ($1, $2, $3, $4, $5, $6, $7, $8) RETURNING (id); "#) .bind(database_id) - .bind(function_name) + .bind(unquoted_name) .bind(details.entrypoint.clone()) .bind(details.language.to_string()) .bind(input_types) From 2a3739a10846740b8ff78ba401659e29f72b6296 Mon Sep 17 00:00:00 2001 From: Peter Neumark Date: Fri, 13 Jan 2023 12:45:59 +0100 Subject: [PATCH 2/3] integration tests --- src/config/context.rs | 2 +- src/repository/interface.rs | 88 ++++++++++++++++++++++++++++++++----- 2 files changed, 77 insertions(+), 13 deletions(-) diff --git a/src/config/context.rs b/src/config/context.rs index 02fef27e..0ea31b8b 100644 --- a/src/config/context.rs +++ b/src/config/context.rs @@ -89,7 +89,7 @@ fn build_object_store(cfg: &schema::SeafowlConfig) -> Arc { let mut builder = AmazonS3Builder::new() .with_access_key_id(access_key_id) .with_secret_access_key(secret_access_key) - .with_region(region.clone().unwrap_or("".to_string())) + .with_region(region.clone().unwrap_or_default()) .with_bucket_name(bucket) .with_allow_http(true); diff --git a/src/repository/interface.rs b/src/repository/interface.rs index 79234bf9..02bf8942 100644 --- a/src/repository/interface.rs +++ b/src/repository/interface.rs @@ -242,15 +242,19 @@ pub mod tests { } } - async fn make_database_with_single_table( + // bumped into rustc bug: https://github.com/rust-lang/rust/issues/96771#issuecomment-1119886703 + async fn make_database_with_single_table<'a>( repository: Arc, + database_name: &'a str, + collection_name: &'a str, + table_name: &'a str, ) -> (DatabaseId, CollectionId, TableId, TableVersionId) { let database_id = repository - .create_database("testdb") + .create_database(database_name) .await .expect("Error creating database"); let collection_id = repository - .create_collection(database_id, "testcol") + .create_collection(database_id, collection_name) .await .expect("Error creating collection"); @@ -263,7 +267,7 @@ pub mod tests { }; let (table_id, table_version_id) = repository - .create_table(collection_id, "testtable", &schema) + .create_table(collection_id, table_name, &schema) .await .expect("Error creating table"); @@ -279,6 +283,7 @@ pub mod tests { test_create_functions(repository.clone(), database_id).await; test_rename_table(repository.clone(), database_id, table_id, new_version_id) .await; + test_create_with_name_in_quotes(repository.clone()).await; test_error_propagation(repository, table_id).await; } @@ -296,12 +301,13 @@ pub mod tests { version: TableVersionId, collection_name: String, table_name: String, + table_id: i64, ) -> Vec { vec![ AllDatabaseColumnsResult { collection_name: collection_name.clone(), table_name: table_name.clone(), - table_id: 1, + table_id, table_version_id: version, column_name: "date".to_string(), column_type: "{\"children\":[],\"name\":\"date\",\"nullable\":false,\"type\":{\"name\":\"date\",\"unit\":\"MILLISECOND\"}}".to_string(), @@ -309,7 +315,7 @@ pub mod tests { AllDatabaseColumnsResult { collection_name, table_name, - table_id: 1, + table_id, table_version_id: version, column_name: "value".to_string(), column_type: "{\"children\":[],\"name\":\"value\",\"nullable\":false,\"type\":{\"name\":\"floatingpoint\",\"precision\":\"DOUBLE\"}}" @@ -322,7 +328,13 @@ pub mod tests { repository: Arc, ) -> (DatabaseId, TableId, TableVersionId) { let (database_id, _, table_id, table_version_id) = - make_database_with_single_table(repository.clone()).await; + make_database_with_single_table( + repository.clone(), + "testdb", + "testcol", + "testtable", + ) + .await; // Test loading all columns @@ -333,7 +345,7 @@ pub mod tests { assert_eq!( all_columns, - expected(1, "testcol".to_string(), "testtable".to_string()) + expected(1, "testcol".to_string(), "testtable".to_string(), 1) ); // Duplicate the table @@ -353,7 +365,8 @@ pub mod tests { expected( new_version_id, "testcol".to_string(), - "testtable".to_string() + "testtable".to_string(), + 1 ) ); @@ -365,7 +378,7 @@ pub mod tests { assert_eq!( all_columns, - expected(1, "testcol".to_string(), "testtable".to_string()) + expected(1, "testcol".to_string(), "testtable".to_string(), 1) ); // Check the existing table versions @@ -382,6 +395,55 @@ pub mod tests { (database_id, table_id, table_version_id) } + async fn test_create_with_name_in_quotes(repository: Arc) { + let (database_id, _, _, table_version_id) = make_database_with_single_table( + repository.clone(), + "testdb2", + "\"testcol\"", + "\"testtable\"", + ) + .await; + + // Test loading all columns + + let all_columns = repository + .get_all_columns_in_database(database_id, None) + .await + .expect("Error getting all columns"); + + assert_eq!( + all_columns, + expected( + table_version_id, + "testcol".to_string(), + "testtable".to_string(), + 2 + ) + ); + + // Duplicate the table + let new_version_id = repository + .create_new_table_version(table_version_id, true) + .await + .unwrap(); + + // Test all columns again: we should have the schema for the latest table version + let all_columns = repository + .get_all_columns_in_database(database_id, None) + .await + .expect("Error getting all columns"); + + assert_eq!( + all_columns, + expected( + new_version_id, + "testcol".to_string(), + "testtable".to_string(), + 2 + ) + ); + } + async fn test_create_append_partition( repository: Arc, table_version_id: TableVersionId, @@ -529,7 +591,8 @@ pub mod tests { expected( table_version_id, "testcol".to_string(), - "testtable2".to_string() + "testtable2".to_string(), + 1 ) ); @@ -553,7 +616,8 @@ pub mod tests { expected( table_version_id, "testcol2".to_string(), - "testtable2".to_string() + "testtable2".to_string(), + 1 ) ); } From e4851c083ff52ae5929d4914d423610c59f7c48b Mon Sep 17 00:00:00 2001 From: Peter Neumark Date: Mon, 16 Jan 2023 16:33:32 +0100 Subject: [PATCH 3/3] Clippy-induced improvements --- datafusion_remote_tables/src/factory.rs | 6 ++--- src/wasm_udf/data_types.rs | 14 ++++------- src/wasm_udf/wasm.rs | 32 ++++++++++++------------- 3 files changed, 23 insertions(+), 29 deletions(-) diff --git a/datafusion_remote_tables/src/factory.rs b/datafusion_remote_tables/src/factory.rs index 97c8736e..93874e9c 100644 --- a/datafusion_remote_tables/src/factory.rs +++ b/datafusion_remote_tables/src/factory.rs @@ -22,9 +22,9 @@ impl TableProviderFactory for RemoteTableFactory { let table = RemoteTable::new( cmd.options .get("name") - .ok_or(DataFusionError::Execution( - "Missing 'name' option".to_string(), - ))? + .ok_or_else(|| { + DataFusionError::Execution("Missing 'name' option".to_string()) + })? .clone(), cmd.location.clone(), SchemaRef::from(cmd.schema.deref().clone()), diff --git a/src/wasm_udf/data_types.rs b/src/wasm_udf/data_types.rs index 3d34a5d7..616b7cdf 100644 --- a/src/wasm_udf/data_types.rs +++ b/src/wasm_udf/data_types.rs @@ -62,28 +62,22 @@ pub enum CreateFunctionDataType { #[derive(Serialize, Deserialize, Debug, PartialEq, Eq, EnumString, Display, Clone)] #[serde(rename_all = "camelCase")] +#[derive(Default)] pub enum CreateFunctionVolatility { Immutable, Stable, + #[default] Volatile, } -impl Default for CreateFunctionVolatility { - fn default() -> Self { - CreateFunctionVolatility::Volatile - } -} #[derive(Serialize, Deserialize, Debug, PartialEq, Eq, EnumString, Display, Clone)] #[serde(rename_all = "camelCase")] +#[derive(Default)] pub enum CreateFunctionLanguage { + #[default] Wasm, WasmMessagePack, } -impl Default for CreateFunctionLanguage { - fn default() -> Self { - CreateFunctionLanguage::Wasm - } -} fn parse_create_function_data_type( raw: &str, diff --git a/src/wasm_udf/wasm.rs b/src/wasm_udf/wasm.rs index 424083eb..e595a6d6 100644 --- a/src/wasm_udf/wasm.rs +++ b/src/wasm_udf/wasm.rs @@ -112,11 +112,11 @@ impl WasmMessagePackUDFInstance { let alloc = get_wasm_module_exported_fn(&instance, &mut store, "alloc")?; let dealloc = get_wasm_module_exported_fn(&instance, &mut store, "dealloc")?; let udf = get_wasm_module_exported_fn(&instance, &mut store, function_name)?; - let memory = instance.get_memory(&mut store, "memory").ok_or( + let memory = instance.get_memory(&mut store, "memory").ok_or_else(|| { DataFusionError::Internal( "could not find module's exported memory".to_string(), - ), - )?; + ) + })?; Ok(Self { store, alloc, @@ -376,7 +376,7 @@ fn messagepack_decode_results( arrow::datatypes::Int16Type, >(encoded_results, &|v| { v.as_i64() - .ok_or(DataFusionError::Internal(format!( + .ok_or_else(|| DataFusionError::Internal(format!( "Expected to find i64 value, but received {v:?} instead" ))) .and_then(|v_i64| { @@ -392,7 +392,7 @@ fn messagepack_decode_results( encoded_results, &|v| { v.as_i64() - .ok_or(DataFusionError::Internal(format!( + .ok_or_else(|| DataFusionError::Internal(format!( "Expected to find i64 value, but received {v:?} instead" ))) .and_then(|v_i64| { @@ -409,7 +409,7 @@ fn messagepack_decode_results( decode_udf_result_primitive_array::( encoded_results, &|v| { - v.as_i64().ok_or(DataFusionError::Internal(format!( + v.as_i64().ok_or_else(|| DataFusionError::Internal(format!( "Expected to find i64 value, but received {v:?} instead" ))) }, @@ -420,7 +420,7 @@ fn messagepack_decode_results( | CreateFunctionDataType::TEXT => encoded_results .iter() .map(|i| { - Some(i.as_str().ok_or(DataFusionError::Internal(format!( + Some(i.as_str().ok_or_else(|| DataFusionError::Internal(format!( "Expected to find string value, received {:?} instead", &i )))) @@ -432,7 +432,7 @@ fn messagepack_decode_results( arrow::datatypes::Date32Type, >(encoded_results, &|v| { v.as_i64() - .ok_or(DataFusionError::Internal(format!( + .ok_or_else(|| DataFusionError::Internal(format!( "Expected to find i64 value, but received {v:?} instead" ))) .and_then(|v_i64| { @@ -446,14 +446,14 @@ fn messagepack_decode_results( CreateFunctionDataType::TIMESTAMP => decode_udf_result_primitive_array::< arrow::datatypes::TimestampNanosecondType, >(encoded_results, &|v| { - v.as_i64().ok_or(DataFusionError::Internal(format!( + v.as_i64().ok_or_else(|| DataFusionError::Internal(format!( "Expected to find i64 value, but received {v:?} instead" ))) }), CreateFunctionDataType::BOOLEAN => encoded_results .iter() .map(|i| { - Some(i.as_bool().ok_or(DataFusionError::Internal(format!( + Some(i.as_bool().ok_or_else(|| DataFusionError::Internal(format!( "Expected to find string value, received {i:?} instead" )))) .transpose() @@ -465,7 +465,7 @@ fn messagepack_decode_results( decode_udf_result_primitive_array::( encoded_results, &|v| { - v.as_f64().ok_or(DataFusionError::Internal(format!( + v.as_f64().ok_or_else(|| DataFusionError::Internal(format!( "Expected to find f64 value, but received {v:?} instead" ))) }, @@ -491,7 +491,7 @@ fn messagepack_decode_results( .map(|i| { Some( i.as_array() - .ok_or(DataFusionError::Internal(format!( + .ok_or_else(|| DataFusionError::Internal(format!( "Expected to find array containing decimal parts, received {i:?} instead" ))) .and_then(|decimal_array| { @@ -499,7 +499,7 @@ fn messagepack_decode_results( return Err(DataFusionError::Internal(format!("DECIMAL UDF result array should have 4 elements, found {:?} instead.", decimal_array.len()))); } decimal_array[0].as_u64() - .ok_or(DataFusionError::Internal(format!("Decimal precision expected to be integer, found {:?} instead", decimal_array[0]))) + .ok_or_else(|| DataFusionError::Internal(format!("Decimal precision expected to be integer, found {:?} instead", decimal_array[0]))) .and_then(|p_u64| { let p_u8:u8 = p_u64.try_into().map_err(|err| DataFusionError::Internal(format!("Couldn't convert 64-bit precision value {p_u64:?} to u8 {err:?}")))?; if p_u8 != *p { @@ -508,7 +508,7 @@ fn messagepack_decode_results( Ok(p_u8) })?; decimal_array[1].as_u64() - .ok_or(DataFusionError::Internal(format!("Decimal scale expected to be integer, found {:?} instead", decimal_array[1]))) + .ok_or_else(|| DataFusionError::Internal(format!("Decimal scale expected to be integer, found {:?} instead", decimal_array[1]))) .and_then(|s_u64| { let s_i8: i8 = s_u64.try_into().map_err(|err| DataFusionError::Internal(format!("Couldn't convert 64-bit scale value {s_u64:?} to i8 {err:?}")))?; if s_i8 != *s { @@ -517,9 +517,9 @@ fn messagepack_decode_results( Ok(s_i8) })?; let high = decimal_array[2].as_i64() - .ok_or(DataFusionError::Internal(format!("Decimal value high half expected to be integer, found {:?} instead", decimal_array[2])))?; + .ok_or_else(|| DataFusionError::Internal(format!("Decimal value high half expected to be integer, found {:?} instead", decimal_array[2])))?; let low = decimal_array[3].as_i64() - .ok_or(DataFusionError::Internal(format!("Decimal value low half expected to be integer, found {:?} instead", decimal_array[3])))?; + .ok_or_else(|| DataFusionError::Internal(format!("Decimal value low half expected to be integer, found {:?} instead", decimal_array[3])))?; let value:i128 = (low as i128) + ((high as i128) << 64); Ok(value) }),