diff --git a/native/core/src/execution/jni_api.rs b/native/core/src/execution/jni_api.rs
index bafde3a0c2..0193f3012c 100644
--- a/native/core/src/execution/jni_api.rs
+++ b/native/core/src/execution/jni_api.rs
@@ -87,19 +87,31 @@ use crate::execution::tracing::{log_memory_usage, trace_begin, trace_end, with_t
 use crate::execution::memory_pools::logging_pool::LoggingMemoryPool;
 use crate::execution::spark_config::{
     SparkConfig, COMET_DEBUG_ENABLED, COMET_DEBUG_MEMORY, COMET_EXPLAIN_NATIVE_ENABLED,
-    COMET_MAX_TEMP_DIRECTORY_SIZE, COMET_TRACING_ENABLED,
+    COMET_MAX_TEMP_DIRECTORY_SIZE, COMET_TRACING_ENABLED, SPARK_EXECUTOR_CORES,
 };
 use crate::parquet::encryption_support::{CometEncryptionFactory, ENCRYPTION_FACTORY_ID};
 use datafusion_comet_proto::spark_operator::operator::OpStruct;
 use log::info;
-use once_cell::sync::Lazy;
+use std::sync::OnceLock;
 #[cfg(feature = "jemalloc")]
 use tikv_jemalloc_ctl::{epoch, stats};
 
-static TOKIO_RUNTIME: Lazy<Runtime> = Lazy::new(|| {
+static TOKIO_RUNTIME: OnceLock<Runtime> = OnceLock::new();
+
+fn parse_usize_env_var(name: &str) -> Option<usize> {
+    std::env::var_os(name).and_then(|n| n.to_str().and_then(|s| s.parse::<usize>().ok()))
+}
+
+fn build_runtime(default_worker_threads: Option<usize>) -> Runtime {
     let mut builder = tokio::runtime::Builder::new_multi_thread();
     if let Some(n) = parse_usize_env_var("COMET_WORKER_THREADS") {
+        info!("Comet tokio runtime: using COMET_WORKER_THREADS={n}");
+        builder.worker_threads(n);
+    } else if let Some(n) = default_worker_threads {
+        info!("Comet tokio runtime: using spark.executor.cores={n} worker threads");
         builder.worker_threads(n);
+    } else {
+        info!("Comet tokio runtime: using default thread count");
     }
     if let Some(n) = parse_usize_env_var("COMET_MAX_BLOCKING_THREADS") {
         builder.max_blocking_threads(n);
@@ -108,15 +120,17 @@ static TOKIO_RUNTIME: Lazy<Runtime> = Lazy::new(|| {
         .enable_all()
         .build()
         .expect("Failed to create Tokio runtime")
-});
+}
 
-fn parse_usize_env_var(name: &str) -> Option<usize> {
-    std::env::var_os(name).and_then(|n| n.to_str().and_then(|s| s.parse::<usize>().ok()))
+/// Initialize the global Tokio runtime with the given default worker thread count.
+/// If the runtime is already initialized, this is a no-op.
+pub fn init_runtime(default_worker_threads: usize) {
+    TOKIO_RUNTIME.get_or_init(|| build_runtime(Some(default_worker_threads)));
 }
 
 /// Function to get a handle to the global Tokio runtime
 pub fn get_runtime() -> &'static Runtime {
-    &TOKIO_RUNTIME
+    TOKIO_RUNTIME.get_or_init(|| build_runtime(None))
 }
 
 /// Comet native execution context. Kept alive across JNI calls.
@@ -192,6 +206,11 @@ pub unsafe extern "system" fn Java_org_apache_comet_Native_createPlan(
     let spark_configs = serde::deserialize_config(bytes.as_slice())?;
     let spark_config: HashMap<String, String> = spark_configs.entries.into_iter().collect();
 
+    // Initialize the tokio runtime with spark.executor.cores as the default
+    // worker thread count, falling back to 1 if not set.
+    let executor_cores = spark_config.get_usize(SPARK_EXECUTOR_CORES, 1);
+    init_runtime(executor_cores);
+
     // Access Comet configs
     let debug_native = spark_config.get_bool(COMET_DEBUG_ENABLED);
     let explain_native = spark_config.get_bool(COMET_EXPLAIN_NATIVE_ENABLED);
diff --git a/native/core/src/execution/spark_config.rs b/native/core/src/execution/spark_config.rs
index b257a5ba68..277c0eb43b 100644
--- a/native/core/src/execution/spark_config.rs
+++ b/native/core/src/execution/spark_config.rs
@@ -22,10 +22,12 @@ pub(crate) const COMET_DEBUG_ENABLED: &str = "spark.comet.debug.enabled";
 pub(crate) const COMET_EXPLAIN_NATIVE_ENABLED: &str = "spark.comet.explain.native.enabled";
 pub(crate) const COMET_MAX_TEMP_DIRECTORY_SIZE: &str = "spark.comet.maxTempDirectorySize";
 pub(crate) const COMET_DEBUG_MEMORY: &str = "spark.comet.debug.memory";
+pub(crate) const SPARK_EXECUTOR_CORES: &str = "spark.executor.cores";
 
 pub(crate) trait SparkConfig {
     fn get_bool(&self, name: &str) -> bool;
     fn get_u64(&self, name: &str, default_value: u64) -> u64;
+    fn get_usize(&self, name: &str, default_value: usize) -> usize;
 }
 
 impl SparkConfig for HashMap<String, String> {
@@ -40,4 +42,10 @@ impl SparkConfig for HashMap<String, String> {
             .and_then(|str_val| str_val.parse::<u64>().ok())
             .unwrap_or(default_value)
     }
+
+    fn get_usize(&self, name: &str, default_value: usize) -> usize {
+        self.get(name)
+            .and_then(|str_val| str_val.parse::<usize>().ok())
+            .unwrap_or(default_value)
+    }
 }
diff --git a/spark/src/main/scala/org/apache/comet/CometExecIterator.scala b/spark/src/main/scala/org/apache/comet/CometExecIterator.scala
index d27f88b496..f17d8f4f72 100644
--- a/spark/src/main/scala/org/apache/comet/CometExecIterator.scala
+++ b/spark/src/main/scala/org/apache/comet/CometExecIterator.scala
@@ -278,6 +278,11 @@ object CometExecIterator extends Logging {
         builder.putEntries(k, v)
       }
     }
+    // Inject the resolved executor cores so the native side can use it
+    // for tokio runtime thread count
+    val executorCores = numDriverOrExecutorCores(SparkEnv.get.conf)
+    builder.putEntries("spark.executor.cores", executorCores.toString)
+
     builder.build().toByteArray
   }
 
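
The key mechanism in this patch is replacing once_cell::sync::Lazy with std::sync::OnceLock, so the global runtime can be initialized either explicitly (with a worker thread count derived from spark.executor.cores) or lazily on first use, with first-caller-wins semantics. Below is a minimal standalone sketch of that pattern; the names (Pool, build_pool, init_pool, get_pool, POOL_THREADS) are illustrative stand-ins, not Comet APIs:

    use std::sync::OnceLock;

    // Stand-in for the expensive global (the Tokio runtime in the patch).
    struct Pool {
        threads: usize,
    }

    // Precedence mirrors the patch: env var override first, then the
    // explicit default passed by the initializer, then a built-in fallback.
    fn build_pool(default_threads: Option<usize>) -> Pool {
        let threads = std::env::var("POOL_THREADS")
            .ok()
            .and_then(|s| s.parse::<usize>().ok())
            .or(default_threads)
            .unwrap_or(1);
        Pool { threads }
    }

    static POOL: OnceLock<Pool> = OnceLock::new();

    // Explicit initialization with a caller-supplied default;
    // a no-op if the global has already been built.
    fn init_pool(default_threads: usize) {
        POOL.get_or_init(|| build_pool(Some(default_threads)));
    }

    // Lazy access path: builds with no default if init_pool never ran.
    fn get_pool() -> &'static Pool {
        POOL.get_or_init(|| build_pool(None))
    }

    fn main() {
        init_pool(4);
        // First initializer wins: this sees the pool built with 4 threads.
        println!("threads = {}", get_pool().threads);
    }

One consequence of first-caller-wins, visible in the sketch as in the patch: if get_runtime() is reached before createPlan has called init_runtime, the runtime is built without the spark.executor.cores default and keeps whatever thread count it got.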
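
The JVM-to-native handoff itself is just a string map: the Scala side puts "spark.executor.cores" into the serialized config entries, and the native side reads it back through the SparkConfig extension trait. A small free-standing sketch of that trait-on-HashMap pattern with the new get_usize accessor (a simplified copy for illustration, not the actual Comet module):

    use std::collections::HashMap;

    trait SparkConfig {
        fn get_usize(&self, name: &str, default_value: usize) -> usize;
    }

    impl SparkConfig for HashMap<String, String> {
        // Missing keys and unparsable values both fall back to the default,
        // matching the existing get_bool/get_u64 accessors.
        fn get_usize(&self, name: &str, default_value: usize) -> usize {
            self.get(name)
                .and_then(|str_val| str_val.parse::<usize>().ok())
                .unwrap_or(default_value)
        }
    }

    fn main() {
        let mut conf: HashMap<String, String> = HashMap::new();
        conf.insert("spark.executor.cores".to_string(), "8".to_string());
        assert_eq!(conf.get_usize("spark.executor.cores", 1), 8);
        // Absent key: the caller's default (1 in the patch) is used.
        assert_eq!(conf.get_usize("spark.missing.key", 1), 1);
        println!("ok");
    }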