Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 26 additions & 7 deletions native/core/src/execution/jni_api.rs
Original file line number Diff line number Diff line change
Expand Up @@ -87,19 +87,31 @@ use crate::execution::tracing::{log_memory_usage, trace_begin, trace_end, with_t
use crate::execution::memory_pools::logging_pool::LoggingMemoryPool;
use crate::execution::spark_config::{
SparkConfig, COMET_DEBUG_ENABLED, COMET_DEBUG_MEMORY, COMET_EXPLAIN_NATIVE_ENABLED,
COMET_MAX_TEMP_DIRECTORY_SIZE, COMET_TRACING_ENABLED,
COMET_MAX_TEMP_DIRECTORY_SIZE, COMET_TRACING_ENABLED, SPARK_EXECUTOR_CORES,
};
use crate::parquet::encryption_support::{CometEncryptionFactory, ENCRYPTION_FACTORY_ID};
use datafusion_comet_proto::spark_operator::operator::OpStruct;
use log::info;
use once_cell::sync::Lazy;
use std::sync::OnceLock;
#[cfg(feature = "jemalloc")]
use tikv_jemalloc_ctl::{epoch, stats};

static TOKIO_RUNTIME: Lazy<Runtime> = Lazy::new(|| {
static TOKIO_RUNTIME: OnceLock<Runtime> = OnceLock::new();

/// Read the environment variable `name` and parse it as a `usize`.
/// Returns `None` when the variable is unset, is not valid UTF-8, or does not
/// parse as an unsigned integer.
fn parse_usize_env_var(name: &str) -> Option<usize> {
    let raw = std::env::var_os(name)?;
    raw.to_str()?.parse::<usize>().ok()
}

fn build_runtime(default_worker_threads: Option<usize>) -> Runtime {
let mut builder = tokio::runtime::Builder::new_multi_thread();
if let Some(n) = parse_usize_env_var("COMET_WORKER_THREADS") {
info!("Comet tokio runtime: using COMET_WORKER_THREADS={n}");
builder.worker_threads(n);
} else if let Some(n) = default_worker_threads {
info!("Comet tokio runtime: using spark.executor.cores={n} worker threads");
builder.worker_threads(n);
} else {
info!("Comet tokio runtime: using default thread count");
}
if let Some(n) = parse_usize_env_var("COMET_MAX_BLOCKING_THREADS") {
builder.max_blocking_threads(n);
Expand All @@ -108,15 +120,17 @@ static TOKIO_RUNTIME: Lazy<Runtime> = Lazy::new(|| {
.enable_all()
.build()
.expect("Failed to create Tokio runtime")
});
}

fn parse_usize_env_var(name: &str) -> Option<usize> {
std::env::var_os(name).and_then(|n| n.to_str().and_then(|s| s.parse::<usize>().ok()))
/// Initialize the global Tokio runtime with the given default worker thread count.
/// If the runtime is already initialized, this is a no-op.
pub fn init_runtime(default_worker_threads: usize) {
TOKIO_RUNTIME.get_or_init(|| build_runtime(Some(default_worker_threads)));
}

/// Function to get a handle to the global Tokio runtime.
///
/// If `init_runtime` has not run yet, this lazily creates the runtime with no
/// Spark-derived default worker thread count (environment variables, if set,
/// still apply — see `build_runtime`).
pub fn get_runtime() -> &'static Runtime {
    TOKIO_RUNTIME.get_or_init(|| build_runtime(None))
}

/// Comet native execution context. Kept alive across JNI calls.
Expand Down Expand Up @@ -192,6 +206,11 @@ pub unsafe extern "system" fn Java_org_apache_comet_Native_createPlan(
let spark_configs = serde::deserialize_config(bytes.as_slice())?;
let spark_config: HashMap<String, String> = spark_configs.entries.into_iter().collect();

// Initialize the tokio runtime with spark.executor.cores as the default
// worker thread count, falling back to 1 if not set.
let executor_cores = spark_config.get_usize(SPARK_EXECUTOR_CORES, 1);
init_runtime(executor_cores);

// Access Comet configs
let debug_native = spark_config.get_bool(COMET_DEBUG_ENABLED);
let explain_native = spark_config.get_bool(COMET_EXPLAIN_NATIVE_ENABLED);
Expand Down
8 changes: 8 additions & 0 deletions native/core/src/execution/spark_config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,12 @@ pub(crate) const COMET_DEBUG_ENABLED: &str = "spark.comet.debug.enabled";
pub(crate) const COMET_EXPLAIN_NATIVE_ENABLED: &str = "spark.comet.explain.native.enabled";
pub(crate) const COMET_MAX_TEMP_DIRECTORY_SIZE: &str = "spark.comet.maxTempDirectorySize";
pub(crate) const COMET_DEBUG_MEMORY: &str = "spark.comet.debug.memory";
pub(crate) const SPARK_EXECUTOR_CORES: &str = "spark.executor.cores";

/// Typed accessors over a Spark configuration key/value map.
pub(crate) trait SparkConfig {
/// Boolean value for `name`. NOTE(review): fallback when the key is absent or
/// unparsable is defined by the impl, which is not fully visible here — confirm.
fn get_bool(&self, name: &str) -> bool;
/// `u64` value for `name`, or `default_value` when the key is absent or unparsable.
fn get_u64(&self, name: &str, default_value: u64) -> u64;
/// `usize` value for `name`, or `default_value` when the key is absent or unparsable.
fn get_usize(&self, name: &str, default_value: usize) -> usize;
}

impl SparkConfig for HashMap<String, String> {
Expand All @@ -40,4 +42,10 @@ impl SparkConfig for HashMap<String, String> {
.and_then(|str_val| str_val.parse::<u64>().ok())
.unwrap_or(default_value)
}

/// Look up `name` and parse it as a `usize`, falling back to `default_value`
/// when the key is missing or its value does not parse.
fn get_usize(&self, name: &str, default_value: usize) -> usize {
    match self.get(name) {
        Some(raw) => raw.parse::<usize>().unwrap_or(default_value),
        None => default_value,
    }
}
}
5 changes: 5 additions & 0 deletions spark/src/main/scala/org/apache/comet/CometExecIterator.scala
Original file line number Diff line number Diff line change
Expand Up @@ -278,6 +278,11 @@ object CometExecIterator extends Logging {
builder.putEntries(k, v)
}
}
// Inject the resolved executor cores so the native side can use it
// for tokio runtime thread count
val executorCores = numDriverOrExecutorCores(SparkEnv.get.conf)
builder.putEntries("spark.executor.cores", executorCores.toString)

builder.build().toByteArray
}

Expand Down
Loading