diff --git a/native/core/src/execution/jni_api.rs b/native/core/src/execution/jni_api.rs
index 79322b182..ac884009a 100644
--- a/native/core/src/execution/jni_api.rs
+++ b/native/core/src/execution/jni_api.rs
@@ -18,6 +18,14 @@
 //! Define JNI APIs which can be called from Java/Scala.
 
 use super::{serde, utils::SparkArrowConvert, CometMemoryPool};
+use crate::{
+    errors::{try_unwrap_or_throw, CometError, CometResult},
+    execution::{
+        datafusion::planner::PhysicalPlanner, metrics::utils::update_comet_metric,
+        serde::to_arrow_datatype, shuffle::row::process_sorted_row_partition, sort::RdxSort,
+    },
+    jvm_bridge::{jni_new_global_ref, JVMClasses},
+};
 use arrow::datatypes::DataType as ArrowDataType;
 use arrow_array::RecordBatch;
 use datafusion::{
@@ -28,7 +36,13 @@ use datafusion::{
     physical_plan::{display::DisplayableExecutionPlan, ExecutionPlan, SendableRecordBatchStream},
     prelude::{SessionConfig, SessionContext},
 };
+use datafusion_comet_proto::spark_operator::Operator;
+use datafusion_common::ScalarValue;
+use datafusion_execution::memory_pool::{
+    FairSpillPool, GreedyMemoryPool, MemoryPool, TrackConsumersPool,
+};
 use futures::poll;
+use futures::stream::StreamExt;
 use jni::{
     errors::Result as JNIResult,
     objects::{
@@ -38,37 +52,25 @@ use jni::{
     sys::{jbyteArray, jint, jlong, jlongArray},
     JNIEnv,
 };
-use std::num::NonZeroUsize;
-use std::{collections::HashMap, sync::Arc, task::Poll};
-
-use crate::{
-    errors::{try_unwrap_or_throw, CometError, CometResult},
-    execution::{
-        datafusion::planner::PhysicalPlanner, metrics::utils::update_comet_metric,
-        serde::to_arrow_datatype, shuffle::row::process_sorted_row_partition, sort::RdxSort,
-    },
-    jvm_bridge::{jni_new_global_ref, JVMClasses},
-};
-use datafusion_comet_proto::spark_operator::Operator;
-use datafusion_common::ScalarValue;
-use datafusion_execution::memory_pool::{
-    FairSpillPool, GreedyMemoryPool, MemoryPool, TrackConsumersPool,
-};
-use futures::stream::StreamExt;
 use jni::{
     objects::GlobalRef,
     sys::{jboolean, jdouble, jintArray, jobjectArray, jstring},
 };
+use std::num::NonZeroUsize;
+use std::sync::Mutex;
+use std::{collections::HashMap, sync::Arc, task::Poll};
 use tokio::runtime::Runtime;
 
 use crate::execution::operators::ScanExec;
 use log::info;
-use once_cell::sync::OnceCell;
+use once_cell::sync::{Lazy, OnceCell};
 
 /// Comet native execution context. Kept alive across JNI calls.
 struct ExecutionContext {
     /// The id of the execution context.
     pub id: i64,
+    /// Task attempt id
+    pub task_attempt_id: i64,
     /// The deserialized Spark plan
     pub spark_plan: Operator,
     /// The DataFusion root operator converted from the `spark_plan`
@@ -91,6 +93,51 @@ struct ExecutionContext {
     pub debug_native: bool,
     /// Whether to write native plans with metrics to stdout
     pub explain_native: bool,
+    /// Memory pool config
+    pub memory_pool_config: MemoryPoolConfig,
+}
+
+#[derive(PartialEq, Eq)]
+enum MemoryPoolType {
+    Default,
+    Unified,
+    Greedy,
+    FairSpill,
+    FairSpillTaskShared,
+    GreedyGlobal,
+    FairSpillGlobal,
+}
+
+struct MemoryPoolConfig {
+    pool_type: MemoryPoolType,
+    pool_size: usize,
+}
+
+impl MemoryPoolConfig {
+    fn new(pool_type: MemoryPoolType, pool_size: usize) -> Self {
+        Self {
+            pool_type,
+            pool_size,
+        }
+    }
+}
+
+/// The per-task memory pools keyed by task attempt id.
+static TASK_SHARED_MEMORY_POOLS: Lazy<Mutex<HashMap<i64, PerTaskMemoryPool>>> =
+    Lazy::new(|| Mutex::new(HashMap::new()));
+
+struct PerTaskMemoryPool {
+    memory_pool: Arc<dyn MemoryPool>,
+    num_plans: usize,
+}
+
+impl PerTaskMemoryPool {
+    fn new(memory_pool: Arc<dyn MemoryPool>) -> Self {
+        Self {
+            memory_pool,
+            num_plans: 0,
+        }
+    }
 }
 
 /// Accept serialized query plan and return the address of the native query plan.
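Every non-unified pool type above ends up wrapped in DataFusion's `TrackConsumersPool`, which decorates an inner pool and names the largest consumers when an allocation fails; the `NonZeroUsize` argument is the number of top consumers reported. A standalone sketch of that composition, separate from this patch (the 1 MiB size and the consumer name "example" are illustrative):

```rust
use datafusion_execution::memory_pool::{
    FairSpillPool, MemoryConsumer, MemoryPool, TrackConsumersPool,
};
use std::num::NonZeroUsize;
use std::sync::Arc;

fn main() {
    // A 1 MiB fair-spill pool that reports the top 10 consumers on failure,
    // mirroring how the pools in this file are constructed.
    let pool: Arc<dyn MemoryPool> = Arc::new(TrackConsumersPool::new(
        FairSpillPool::new(1024 * 1024),
        NonZeroUsize::new(10).unwrap(),
    ));

    let mut reservation = MemoryConsumer::new("example").register(&pool);
    // Growing past the pool size fails with an error naming the consumers.
    assert!(reservation.try_grow(2 * 1024 * 1024).is_err());
}
```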
@@ -106,7 +153,6 @@ pub unsafe extern "system" fn Java_org_apache_comet_Native_createPlan(
     serialized_query: jbyteArray,
     metrics_node: JObject,
     comet_task_memory_manager_obj: JObject,
-    memory_pool_address: jlong,
 ) -> jlong {
     try_unwrap_or_throw(&e, |mut env| {
         // Init JVM classes
@@ -143,6 +189,7 @@ pub unsafe extern "system" fn Java_org_apache_comet_Native_createPlan(
             .map(String::as_str)
             .unwrap_or("10")
             .parse::<usize>()?;
+        let task_attempt_id: i64 = configs.get("task_attempt_id").unwrap().parse()?;
 
         // Use multi-threaded tokio runtime to prevent blocking spawned tasks if any
         let runtime = tokio::runtime::Builder::new_multi_thread()
@@ -165,14 +212,18 @@ pub unsafe extern "system" fn Java_org_apache_comet_Native_createPlan(
         let task_memory_manager =
             Arc::new(jni_new_global_ref!(env, comet_task_memory_manager_obj)?);
 
+        let memory_pool_config = parse_memory_pool_config(&configs)?;
+        let memory_pool =
+            create_memory_pool(&memory_pool_config, task_memory_manager, task_attempt_id);
+
         // We need to keep the session context alive. Some session state like temporary
         // dictionaries are stored in session context. If it is dropped, the temporary
         // dictionaries will be dropped as well.
-        let session =
-            prepare_datafusion_session_context(&configs, task_memory_manager, memory_pool_address)?;
+        let session = prepare_datafusion_session_context(&configs, memory_pool)?;
 
         let exec_context = Box::new(ExecutionContext {
             id,
+            task_attempt_id,
             spark_plan,
             root_op: None,
             scans: vec![],
@@ -184,6 +235,7 @@ pub unsafe extern "system" fn Java_org_apache_comet_Native_createPlan(
             session_ctx: Arc::new(session),
             debug_native,
             explain_native,
+            memory_pool_config,
         });
 
         Ok(Box::into_raw(exec_context) as i64)
@@ -193,8 +245,7 @@ pub unsafe extern "system" fn Java_org_apache_comet_Native_createPlan(
 /// Parse Comet configs and configure DataFusion session context.
 fn prepare_datafusion_session_context(
     conf: &HashMap<String, String>,
-    comet_task_memory_manager: Arc<GlobalRef>,
-    memory_pool_address: jlong,
+    memory_pool: Option<Arc<dyn MemoryPool>>,
 ) -> CometResult<SessionContext> {
     // Get the batch size from Comet JVM side
     let batch_size = conf
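`prepare_datafusion_session_context` now takes the pool as an `Option<Arc<dyn MemoryPool>>` instead of a JNI global reference plus a raw pool address, so pool selection happens once in `createPlan` and `None` falls through to DataFusion's default pool. A minimal standalone sketch of that control flow, assuming the DataFusion APIs already used in this file (`RuntimeConfig`, `RuntimeEnv::new`, `SessionContext::new_with_config_rt`):

```rust
use datafusion::execution::memory_pool::MemoryPool;
use datafusion::execution::runtime_env::{RuntimeConfig, RuntimeEnv};
use datafusion::prelude::{SessionConfig, SessionContext};
use std::sync::Arc;

fn session_with_optional_pool(
    memory_pool: Option<Arc<dyn MemoryPool>>,
) -> datafusion::error::Result<SessionContext> {
    let mut rt_config = RuntimeConfig::new();
    // `None` (MemoryPoolType::Default) keeps DataFusion's default pool.
    if let Some(pool) = memory_pool {
        rt_config = rt_config.with_memory_pool(pool);
    }
    let runtime = Arc::new(RuntimeEnv::new(rt_config)?);
    Ok(SessionContext::new_with_config_rt(
        SessionConfig::new(),
        runtime,
    ))
}
```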
@@ -206,71 +257,8 @@ fn prepare_datafusion_session_context(
 
     let mut rt_config = RuntimeConfig::new().with_disk_manager(DiskManagerConfig::NewOs);
 
-    // Check if we are using unified memory manager integrated with Spark. Default to false if not
-    // set.
-    let use_unified_memory_manager = parse_bool(conf, "use_unified_memory_manager")?;
-
-    if use_unified_memory_manager {
-        // Set Comet memory pool for native
-        let memory_pool = CometMemoryPool::new(comet_task_memory_manager);
-        rt_config = rt_config.with_memory_pool(Arc::new(memory_pool));
-    } else if memory_pool_address != 0 {
-        // Use the task-shared memory pool allocated by `createTaskMemoryPool`
-        let memory_pool = unsafe {
-            let pool = memory_pool_address as *mut Arc<TrackConsumersPool<FairSpillPool>>;
-            Box::from_raw(pool)
-        };
-        let memory_pool_arc = Arc::clone(Box::leak(memory_pool));
-        rt_config = rt_config.with_memory_pool(memory_pool_arc);
-    } else {
-        // Use the memory pool from DF
-        if conf.contains_key("memory_limit") || conf.contains_key("memory_limit_per_task") {
-            let memory_limit = conf.get("memory_limit").unwrap().parse::<usize>()?;
-            let memory_limit_per_task = conf
-                .get("memory_limit_per_task")
-                .unwrap()
-                .parse::<usize>()?;
-            let memory_fraction = conf
-                .get("memory_fraction")
-                .ok_or(CometError::Internal(
-                    "Config 'memory_fraction' is not specified from Comet JVM side".to_string(),
-                ))?
-                .parse::<f64>()?;
-            let pool_size = (memory_limit as f64 * memory_fraction) as usize;
-            let pool_size_per_task = (memory_limit_per_task as f64 * memory_fraction) as usize;
-            let memory_pool_type = conf.get("memory_pool_type").unwrap();
-            if memory_pool_type == "fair_spill_global" {
-                static GLOBAL_MEMORY_POOL_FAIR: OnceCell<Arc<dyn MemoryPool>> = OnceCell::new();
-                let memory_pool = GLOBAL_MEMORY_POOL_FAIR.get_or_init(|| {
-                    Arc::new(TrackConsumersPool::new(
-                        FairSpillPool::new(pool_size),
-                        NonZeroUsize::new(10).unwrap(),
-                    ))
-                });
-                rt_config = rt_config.with_memory_pool(Arc::clone(memory_pool));
-            } else if memory_pool_type == "greedy_global" {
-                static GLOBAL_MEMORY_POOL_GREEDY: OnceCell<Arc<dyn MemoryPool>> = OnceCell::new();
-                let memory_pool = GLOBAL_MEMORY_POOL_GREEDY.get_or_init(|| {
-                    Arc::new(TrackConsumersPool::new(
-                        GreedyMemoryPool::new(pool_size),
-                        NonZeroUsize::new(10).unwrap(),
-                    ))
-                });
-                rt_config = rt_config.with_memory_pool(Arc::clone(memory_pool));
-            } else if memory_pool_type == "fair_spill" {
-                rt_config = rt_config.with_memory_pool(Arc::new(TrackConsumersPool::new(
-                    FairSpillPool::new(pool_size_per_task),
-                    NonZeroUsize::new(10).unwrap(),
-                )));
-            } else if memory_pool_type == "greedy" {
-                rt_config = rt_config.with_memory_limit(memory_limit_per_task, memory_fraction);
-            } else {
-                return Err(CometError::Config(format!(
-                    "Unsupported memory pool type: {}",
-                    memory_pool_type
-                )));
-            }
-        }
+    if let Some(memory_pool) = memory_pool {
+        rt_config = rt_config.with_memory_pool(memory_pool);
    }
 
     // Get Datafusion configuration from Spark Execution context
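The inline branching removed above is replaced by the `parse_memory_pool_config`/`create_memory_pool` pair added below. The two pool families they distinguish behave differently under contention: a greedy pool is first-come-first-served, while a fair-spill pool caps each spillable consumer at an equal share of the pool. A standalone sketch of the difference (the sizes and the consumer names "a"/"b" are illustrative):

```rust
use datafusion_execution::memory_pool::{
    FairSpillPool, GreedyMemoryPool, MemoryConsumer, MemoryPool,
};
use std::sync::Arc;

fn demo(pool: Arc<dyn MemoryPool>, name: &str) {
    let mut a = MemoryConsumer::new("a").with_can_spill(true).register(&pool);
    let mut b = MemoryConsumer::new("b").with_can_spill(true).register(&pool);
    a.try_grow(40).unwrap();
    // greedy: b may take everything that is left (60 of 100) -> Ok
    // fair_spill: each spillable consumer is capped at 100 / 2 = 50 -> Err
    println!("{name}: b.try_grow(60) ok = {}", b.try_grow(60).is_ok());
}

fn main() {
    demo(Arc::new(GreedyMemoryPool::new(100)), "greedy");
    demo(Arc::new(FairSpillPool::new(100)), "fair_spill");
}
```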
@@ -302,6 +290,113 @@ fn prepare_datafusion_session_context(
 
     Ok(session_ctx)
 }
 
+fn parse_memory_pool_config(conf: &HashMap<String, String>) -> CometResult<MemoryPoolConfig> {
+    // Check if we are using unified memory manager integrated with Spark. Default to false if not
+    // set.
+    let use_unified_memory_manager = parse_bool(conf, "use_unified_memory_manager")?;
+
+    let memory_pool_config = if use_unified_memory_manager {
+        MemoryPoolConfig::new(MemoryPoolType::Unified, 0)
+    } else {
+        // Use the memory pool from DF
+        if conf.contains_key("memory_limit") || conf.contains_key("memory_limit_per_task") {
+            let memory_pool_type = conf.get("memory_pool_type").unwrap();
+            let memory_limit = conf.get("memory_limit").unwrap().parse::<usize>()?;
+            let memory_fraction = conf
+                .get("memory_fraction")
+                .ok_or(CometError::Internal(
+                    "Config 'memory_fraction' is not specified from Comet JVM side".to_string(),
+                ))?
+                .parse::<f64>()?;
+            let pool_size = (memory_limit as f64 * memory_fraction) as usize;
+
+            let memory_limit_per_task = conf.get("memory_limit_per_task").unwrap().parse::<usize>();
+            let pool_size_per_task =
+                memory_limit_per_task.map(|v| (v as f64 * memory_fraction) as usize);
+            match memory_pool_type.as_str() {
+                "fair_spill_task_shared" => {
+                    MemoryPoolConfig::new(MemoryPoolType::FairSpillTaskShared, pool_size_per_task?)
+                }
+                "fair_spill_global" => {
+                    MemoryPoolConfig::new(MemoryPoolType::FairSpillGlobal, pool_size)
+                }
+                "greedy_global" => MemoryPoolConfig::new(MemoryPoolType::GreedyGlobal, pool_size),
+                "fair_spill" => {
+                    MemoryPoolConfig::new(MemoryPoolType::FairSpill, pool_size_per_task?)
+                }
+                "greedy" => MemoryPoolConfig::new(MemoryPoolType::Greedy, pool_size_per_task?),
+                _ => {
+                    return Err(CometError::Config(format!(
+                        "Unsupported memory pool type: {}",
+                        memory_pool_type
+                    )))
+                }
+            }
+        } else {
+            MemoryPoolConfig::new(MemoryPoolType::Default, 0)
+        }
+    };
+    Ok(memory_pool_config)
+}
+
+fn create_memory_pool(
+    memory_pool_config: &MemoryPoolConfig,
+    comet_task_memory_manager: Arc<GlobalRef>,
+    task_attempt_id: i64,
+) -> Option<Arc<dyn MemoryPool>> {
+    match memory_pool_config.pool_type {
+        MemoryPoolType::Unified => {
+            // Set Comet memory pool for native
+            let memory_pool = CometMemoryPool::new(comet_task_memory_manager);
+            Some(Arc::new(memory_pool))
+        }
+        MemoryPoolType::Greedy => Some(Arc::new(TrackConsumersPool::new(
+            GreedyMemoryPool::new(memory_pool_config.pool_size),
+            NonZeroUsize::new(10).unwrap(),
+        ))),
+        MemoryPoolType::FairSpill => Some(Arc::new(TrackConsumersPool::new(
+            FairSpillPool::new(memory_pool_config.pool_size),
+            NonZeroUsize::new(10).unwrap(),
+        ))),
+        MemoryPoolType::GreedyGlobal => {
+            static GLOBAL_MEMORY_POOL_GREEDY: OnceCell<Arc<dyn MemoryPool>> = OnceCell::new();
+            let memory_pool = GLOBAL_MEMORY_POOL_GREEDY.get_or_init(|| {
+                Arc::new(TrackConsumersPool::new(
+                    GreedyMemoryPool::new(memory_pool_config.pool_size),
+                    NonZeroUsize::new(10).unwrap(),
+                ))
+            });
+            Some(Arc::clone(memory_pool))
+        }
+        MemoryPoolType::FairSpillGlobal => {
+            static GLOBAL_MEMORY_POOL_FAIR: OnceCell<Arc<dyn MemoryPool>> = OnceCell::new();
+            let memory_pool = GLOBAL_MEMORY_POOL_FAIR.get_or_init(|| {
+                Arc::new(TrackConsumersPool::new(
+                    FairSpillPool::new(memory_pool_config.pool_size),
+                    NonZeroUsize::new(10).unwrap(),
+                ))
+            });
+            Some(Arc::clone(memory_pool))
+        }
+        MemoryPoolType::FairSpillTaskShared => {
+            let mut memory_pool_map = TASK_SHARED_MEMORY_POOLS.lock().unwrap();
+            let per_task_memory_pool =
+                memory_pool_map.entry(task_attempt_id).or_insert_with(|| {
+                    PerTaskMemoryPool::new(Arc::new(TrackConsumersPool::new(
+                        FairSpillPool::new(memory_pool_config.pool_size),
+                        NonZeroUsize::new(10).unwrap(),
+                    )))
+                });
+            per_task_memory_pool.num_plans += 1;
+            Some(Arc::clone(&per_task_memory_pool.memory_pool))
+        }
+        MemoryPoolType::Default => {
+            // Use the memory pool from DF
+            None
+        }
+    }
+}
+
 fn parse_bool(conf: &HashMap<String, String>, name: &str) -> CometResult<bool> {
     conf.get(name)
         .map(String::as_str)
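For the `*_global` variants, `create_memory_pool` keeps the pool in a function-local `static OnceCell`, so the pool has process lifetime, every plan shares one instance, and the `pool_size` of whichever plan initializes it first wins. A standalone sketch of that initialize-once pattern (the `global_buffer` helper is illustrative, not Comet code):

```rust
use once_cell::sync::OnceCell;
use std::sync::Arc;

fn global_buffer(size: usize) -> Arc<Vec<u8>> {
    // Function-local statics live for the whole process, like the
    // GLOBAL_MEMORY_POOL_* cells above.
    static GLOBAL: OnceCell<Arc<Vec<u8>>> = OnceCell::new();
    Arc::clone(GLOBAL.get_or_init(|| Arc::new(vec![0u8; size])))
}

fn main() {
    let a = global_buffer(10);
    let b = global_buffer(99); // already initialized; the argument is ignored
    assert!(Arc::ptr_eq(&a, &b));
    assert_eq!(b.len(), 10);
}
```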
@@ -487,6 +582,20 @@ pub extern "system" fn Java_org_apache_comet_Native_releasePlan(
 ) {
     try_unwrap_or_throw(&e, |_| unsafe {
         let execution_context = get_execution_context(exec_context);
+        if execution_context.memory_pool_config.pool_type == MemoryPoolType::FairSpillTaskShared {
+            // Decrement the number of native plans using the per-task shared memory pool, and
+            // remove the memory pool if the released native plan is the last native plan using it.
+            let task_attempt_id = execution_context.task_attempt_id;
+            let mut memory_pool_map = TASK_SHARED_MEMORY_POOLS.lock().unwrap();
+            if let Some(per_task_memory_pool) = memory_pool_map.get_mut(&task_attempt_id) {
+                per_task_memory_pool.num_plans -= 1;
+                if per_task_memory_pool.num_plans == 0 {
+                    // Drop the memory pool from the per-task memory pool map if there are no
+                    // more native plans using it.
+                    memory_pool_map.remove(&task_attempt_id);
+                }
+            }
+        }
         let _: Box<ExecutionContext> = Box::from_raw(execution_context);
         Ok(())
     })
@@ -616,32 +725,3 @@ pub extern "system" fn Java_org_apache_comet_Native_sortRowPartitionsNative(
         Ok(())
     })
 }
-
-#[no_mangle]
-pub extern "system" fn Java_org_apache_comet_Native_createTaskMemoryPool(
-    e: JNIEnv,
-    _class: JClass,
-    memory_limit_per_task: jlong,
-    memory_fraction: jdouble,
-) -> jlong {
-    try_unwrap_or_throw(&e, |_| {
-        let pool_size_per_task = (memory_limit_per_task as f64 * memory_fraction) as usize;
-        let memory_pool = Box::new(Arc::new(TrackConsumersPool::new(
-            FairSpillPool::new(pool_size_per_task),
-            NonZeroUsize::new(10).unwrap(),
-        )));
-        Ok(Box::into_raw(memory_pool) as jlong)
-    })
-}
-
-#[no_mangle]
-pub extern "system" fn Java_org_apache_comet_Native_releaseTaskMemoryPool(
-    e: JNIEnv,
-    _class: JClass,
-    pool_address: jlong,
-) {
-    try_unwrap_or_throw(&e, |_| unsafe {
-        let _ = Box::from_raw(pool_address as *mut Arc<TrackConsumersPool<FairSpillPool>>);
-        Ok(())
-    })
-}
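Taken together, `create_memory_pool` and `Java_org_apache_comet_Native_releasePlan` maintain a reference count on the per-task shared pool: each plan created for a task attempt increments `num_plans`, and the map entry (and with it the pool) is dropped when the last plan is released. A standalone model of that lifecycle with a plain `HashMap` (no locking, JNI, or real pools):

```rust
use std::collections::HashMap;

struct Entry {
    num_plans: usize,
}

fn create_plan(map: &mut HashMap<i64, Entry>, task_attempt_id: i64) {
    map.entry(task_attempt_id)
        .or_insert(Entry { num_plans: 0 })
        .num_plans += 1;
}

fn release_plan(map: &mut HashMap<i64, Entry>, task_attempt_id: i64) {
    if let Some(entry) = map.get_mut(&task_attempt_id) {
        entry.num_plans -= 1;
        if entry.num_plans == 0 {
            map.remove(&task_attempt_id);
        }
    }
}

fn main() {
    let mut map = HashMap::new();
    create_plan(&mut map, 7);
    create_plan(&mut map, 7); // a second plan in the same task shares the entry
    release_plan(&mut map, 7);
    assert!(map.contains_key(&7)); // one plan still alive
    release_plan(&mut map, 7);
    assert!(map.is_empty()); // last release drops the shared entry
}
```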
diff --git a/spark/src/main/scala/org/apache/comet/CometExecIterator.scala b/spark/src/main/scala/org/apache/comet/CometExecIterator.scala
index fd49f1977..4f9a5f612 100644
--- a/spark/src/main/scala/org/apache/comet/CometExecIterator.scala
+++ b/spark/src/main/scala/org/apache/comet/CometExecIterator.scala
@@ -19,8 +19,6 @@
 
 package org.apache.comet
 
-import java.util.Map
-
 import org.apache.spark._
 import org.apache.spark.sql.comet.CometMetricNode
 import org.apache.spark.sql.vectorized._
@@ -57,8 +55,7 @@
   }.toArray
   private val plan = {
     val configs = createNativeConf
-    CometExecIterator.createPlan(
-      nativeLib,
+    nativeLib.createPlan(
       id,
       configs,
       cometBatchIterators,
@@ -101,6 +98,10 @@
     result.put("worker_threads", String.valueOf(COMET_WORKER_THREADS.get()))
     result.put("blocking_threads", String.valueOf(COMET_BLOCKING_THREADS.get()))
 
+    val taskContext = TaskContext.get()
+    val taskAttemptId = taskContext.taskAttemptId()
+    result.put("task_attempt_id", String.valueOf(taskAttemptId))
+
     // Strip mandatory prefix spark. which is not required for DataFusion session params
     conf.getAll.foreach {
       case (k, v) if k.startsWith("spark.datafusion") =>
@@ -199,47 +200,3 @@ class CometExecIterator(
     }
   }
 }
-
-object CometExecIterator {
-  private val taskMemoryPoolAddressMap = new java.util.concurrent.ConcurrentHashMap[Long, Long]()
-
-  private def createPlan(
-      nativeLib: Native,
-      id: Long,
-      configMap: Map[String, String],
-      iterators: Array[CometBatchIterator],
-      protobufQueryPlan: Array[Byte],
-      metrics: CometMetricNode,
-      taskMemoryManager: CometTaskMemoryManager): Long = {
-    val taskContext = TaskContext.get()
-    val taskAttemptId = taskContext.taskAttemptId()
-
-    val poolAddress =
-      if (!configMap.get("use_unified_memory_manager").toBoolean && configMap.get(
-          "memory_pool_type") == "fair_spill_task_shared") {
-        taskMemoryPoolAddressMap.computeIfAbsent(
-          taskAttemptId,
-          _ => {
-            val memoryLimitPerTask = configMap.get("memory_limit_per_task").toLong
-            val memoryFraction = configMap.get("memory_fraction").toDouble
-            val poolAddress = nativeLib.createTaskMemoryPool(memoryLimitPerTask, memoryFraction)
-            taskContext.addTaskCompletionListener[Unit] { _ =>
-              nativeLib.releaseTaskMemoryPool(poolAddress)
-              taskMemoryPoolAddressMap.remove(taskAttemptId)
-            }
-            poolAddress
-          })
-      } else {
-        0
-      }
-
-    nativeLib.createPlan(
-      id,
-      configMap,
-      iterators,
-      protobufQueryPlan,
-      metrics,
-      taskMemoryManager,
-      poolAddress)
-  }
-}
diff --git a/spark/src/main/scala/org/apache/comet/Native.scala b/spark/src/main/scala/org/apache/comet/Native.scala
index 5eedac539..03a9dea0c 100644
--- a/spark/src/main/scala/org/apache/comet/Native.scala
+++ b/spark/src/main/scala/org/apache/comet/Native.scala
@@ -42,8 +42,6 @@ class Native extends NativeBase {
    * @param taskMemoryManager
    *   the task-level memory manager that is responsible for tracking memory usage across JVM and
    *   native side.
-   * @param memoryPoolAddress
-   *   the address of the task-level memory pool.
    * @return
    *   the address to native query plan.
    */
@@ -53,8 +51,7 @@
       iterators: Array[CometBatchIterator],
       plan: Array[Byte],
       metrics: CometMetricNode,
-      taskMemoryManager: CometTaskMemoryManager,
-      memoryPoolAddress: Long): Long
+      taskMemoryManager: CometTaskMemoryManager): Long
 
   /**
    * Execute a native query plan based on given input Arrow arrays.
@@ -126,22 +123,4 @@
    *   the size of the array.
    */
   @native def sortRowPartitionsNative(addr: Long, size: Long): Unit
-
-  /**
-   * Create a task-level datafusion memory pool.
-   * @param memoryLimit
-   *   the memory limit of the memory pool.
-   * @param memoryFraction
-   *   the fraction of reservable memory in the memory pool.
-   * @return
-   *   the address of the memory pool.
-   */
-  @native def createTaskMemoryPool(memoryLimit: Long, memoryFraction: Double): Long
-
-  /**
-   * Release the task-level datafusion memory pool.
-   * @param addr
-   *   the address of the memory pool.
-   */
-  @native def releaseTaskMemoryPool(addr: Long): Unit
 }
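With the raw-pointer JNI calls gone, the remaining contract between the Scala and native sides is the string config map built in `createNativeConf` and parsed in `createPlan`/`parse_memory_pool_config`. A standalone sketch of that contract; the values (8 GiB limit, 512 MiB per task, fraction 0.7, task attempt 42) are illustrative:

```rust
use std::collections::HashMap;

fn main() {
    // Keys written by the Scala side; every value crosses JNI as a string.
    let mut conf: HashMap<String, String> = HashMap::new();
    conf.insert("use_unified_memory_manager".into(), "false".into());
    conf.insert("memory_pool_type".into(), "fair_spill_task_shared".into());
    conf.insert("memory_limit".into(), (8_u64 * 1024 * 1024 * 1024).to_string());
    conf.insert("memory_limit_per_task".into(), (512_u64 * 1024 * 1024).to_string());
    conf.insert("memory_fraction".into(), "0.7".into());
    conf.insert("task_attempt_id".into(), "42".into());

    // Mirrors the arithmetic in parse_memory_pool_config.
    let per_task: usize = conf["memory_limit_per_task"].parse().unwrap();
    let fraction: f64 = conf["memory_fraction"].parse().unwrap();
    let pool_size_per_task = (per_task as f64 * fraction) as usize;
    println!(
        "task {} shares one fair-spill pool of {pool_size_per_task} bytes",
        conf["task_attempt_id"]
    );
}
```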