feat: Add a spark.comet.exec.memoryPool configuration for experimenting with various DataFusion memory pool setups. #1021

Open · wants to merge 12 commits into main
9 changes: 9 additions & 0 deletions common/src/main/scala/org/apache/comet/CometConf.scala
@@ -431,6 +431,15 @@ object CometConf extends ShimCometConf {
.doubleConf
.createWithDefault(0.7)

val COMET_EXEC_MEMORY_POOL_TYPE: ConfigEntry[String] = conf("spark.comet.exec.memoryPool")
.doc(
"The type of memory pool to be used for Comet native execution. " +
"Available memory pool types are 'greedy', 'fair_spill', 'greedy_task_shared', " +
"'fair_spill_task_shared', 'greedy_global' and 'fair_spill_global', By default, " +
"this config is 'greedy_task_shared'.")
.stringConf
.createWithDefault("greedy_task_shared")

val COMET_PARQUET_ENABLE_DIRECT_BUFFER: ConfigEntry[Boolean] = conf(
"spark.comet.parquet.enable.directBuffer")
.doc("Whether to use Java direct byte buffer when reading Parquet. By default, this is false")
1 change: 1 addition & 0 deletions docs/source/user-guide/configs.md
@@ -49,6 +49,7 @@ Comet provides the following configuration settings.
| spark.comet.exec.hashJoin.enabled | Whether to enable hashJoin by default. | true |
| spark.comet.exec.localLimit.enabled | Whether to enable localLimit by default. | true |
| spark.comet.exec.memoryFraction | The fraction of memory from Comet memory overhead that the native memory manager can use for execution. The purpose of this config is to set aside memory for untracked data structures, as well as imprecise size estimation during memory acquisition. Default value is 0.7. | 0.7 |
| spark.comet.exec.memoryPool | The type of memory pool to be used for Comet native execution. Available memory pool types are 'greedy', 'fair_spill', 'greedy_task_shared', 'fair_spill_task_shared', 'greedy_global' and 'fair_spill_global'. By default, this config is 'greedy_task_shared'. | greedy_task_shared |
| spark.comet.exec.project.enabled | Whether to enable project by default. | true |
| spark.comet.exec.shuffle.codec | The codec of Comet native shuffle used to compress shuffle data. Only zstd is supported. | zstd |
| spark.comet.exec.shuffle.enabled | Whether to enable Comet native shuffle. Note that this requires setting 'spark.shuffle.manager' to 'org.apache.spark.sql.comet.execution.shuffle.CometShuffleManager'. 'spark.shuffle.manager' must be set before starting the Spark application and cannot be changed during the application. | true |
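As a usage sketch (illustrative values): submitting with `--conf spark.comet.exec.memoryPool=fair_spill_task_shared` selects a fair-spill pool shared by all native plans of a task attempt, while the `*_global` variants share a single pool across the whole executor process. The setting takes effect only when Comet is not delegating to Spark's unified memory manager (the `use_unified_memory_manager` path in the native code below).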
248 changes: 211 additions & 37 deletions native/core/src/execution/jni_api.rs
@@ -17,6 +17,15 @@

//! Define JNI APIs which can be called from Java/Scala.

use super::{serde, utils::SparkArrowConvert, CometMemoryPool};
use crate::{
errors::{try_unwrap_or_throw, CometError, CometResult},
execution::{
datafusion::planner::PhysicalPlanner, metrics::utils::update_comet_metric,
serde::to_arrow_datatype, shuffle::row::process_sorted_row_partition, sort::RdxSort,
},
jvm_bridge::{jni_new_global_ref, JVMClasses},
};
use arrow::datatypes::DataType as ArrowDataType;
use arrow_array::RecordBatch;
use datafusion::{
@@ -27,7 +36,13 @@ use datafusion::{
physical_plan::{display::DisplayableExecutionPlan, ExecutionPlan, SendableRecordBatchStream},
prelude::{SessionConfig, SessionContext},
};
use datafusion_comet_proto::spark_operator::Operator;
use datafusion_common::ScalarValue;
use datafusion_execution::memory_pool::{
FairSpillPool, GreedyMemoryPool, MemoryPool, TrackConsumersPool,
};
use futures::poll;
use futures::stream::StreamExt;
use jni::{
errors::Result as JNIResult,
objects::{
@@ -37,34 +52,25 @@ use jni::{
sys::{jbyteArray, jint, jlong, jlongArray},
JNIEnv,
};
use std::{collections::HashMap, sync::Arc, task::Poll};

use super::{serde, utils::SparkArrowConvert, CometMemoryPool};

use crate::{
errors::{try_unwrap_or_throw, CometError, CometResult},
execution::{
datafusion::planner::PhysicalPlanner, metrics::utils::update_comet_metric,
serde::to_arrow_datatype, shuffle::row::process_sorted_row_partition, sort::RdxSort,
},
jvm_bridge::{jni_new_global_ref, JVMClasses},
};
use datafusion_comet_proto::spark_operator::Operator;
use datafusion_common::ScalarValue;
use futures::stream::StreamExt;
use jni::{
objects::GlobalRef,
sys::{jboolean, jdouble, jintArray, jobjectArray, jstring},
};
use std::num::NonZeroUsize;
use std::sync::Mutex;
use std::{collections::HashMap, sync::Arc, task::Poll};
use tokio::runtime::Runtime;

use crate::execution::operators::ScanExec;
use log::info;
use once_cell::sync::{Lazy, OnceCell};

/// Comet native execution context. Kept alive across JNI calls.
struct ExecutionContext {
/// The id of the execution context.
pub id: i64,
/// Task attempt id
pub task_attempt_id: i64,
/// The deserialized Spark plan
pub spark_plan: Operator,
/// The DataFusion root operator converted from the `spark_plan`
@@ -87,6 +93,52 @@ struct ExecutionContext {
pub debug_native: bool,
/// Whether to write native plans with metrics to stdout
pub explain_native: bool,
/// Memory pool config
pub memory_pool_config: MemoryPoolConfig,
}

/// The kind of DataFusion memory pool selected via `spark.comet.exec.memoryPool`.
#[derive(PartialEq, Eq)]
enum MemoryPoolType {
    /// Fall back to DataFusion's default pool (no explicit pool configured)
    Default,
    /// Comet pool integrated with Spark's unified memory manager
    Unified,
    /// A greedy pool created per native plan
    Greedy,
    /// A fair-spill pool created per native plan
    FairSpill,
    /// A greedy pool shared by all native plans of a task attempt
    GreedyTaskShared,
    /// A fair-spill pool shared by all native plans of a task attempt
    FairSpillTaskShared,
    /// A single greedy pool shared by the whole executor process
    GreedyGlobal,
    /// A single fair-spill pool shared by the whole executor process
    FairSpillGlobal,
}

/// The resolved pool type together with its size in bytes.
struct MemoryPoolConfig {
    pool_type: MemoryPoolType,
    pool_size: usize,
}

impl MemoryPoolConfig {
fn new(pool_type: MemoryPoolType, pool_size: usize) -> Self {
Self {
pool_type,
pool_size,
}
}
}

/// The per-task memory pools keyed by task attempt id.
static TASK_SHARED_MEMORY_POOLS: Lazy<Mutex<HashMap<i64, PerTaskMemoryPool>>> =
Lazy::new(|| Mutex::new(HashMap::new()));

/// A memory pool shared by all native plans of a single task attempt, together
/// with the number of live plans currently holding a reference to it.
struct PerTaskMemoryPool {
    memory_pool: Arc<dyn MemoryPool>,
    num_plans: usize,
}

impl PerTaskMemoryPool {
fn new(memory_pool: Arc<dyn MemoryPool>) -> Self {
Self {
memory_pool,
num_plans: 0,
}
}
}
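The task-shared variants are reference-counted: `createPlan` bumps `num_plans` for the task attempt's pool and `releasePlan` decrements it, dropping the map entry when the last plan goes away (both appear later in this diff). A stand-alone sketch of the same scheme, with simplified names and a placeholder `Pool` standing in for `Arc<dyn MemoryPool>`:

```rust
use std::collections::HashMap;
use std::sync::{Arc, Mutex};

struct Pool; // stand-in for `Arc<dyn MemoryPool>`

struct PerTaskPool {
    pool: Arc<Pool>,
    num_plans: usize,
}

// Plan creation: get or create the task's shared pool and bump the count.
fn acquire(pools: &Mutex<HashMap<i64, PerTaskPool>>, task_attempt_id: i64) -> Arc<Pool> {
    let mut map = pools.lock().unwrap();
    let entry = map.entry(task_attempt_id).or_insert_with(|| PerTaskPool {
        pool: Arc::new(Pool),
        num_plans: 0,
    });
    entry.num_plans += 1;
    Arc::clone(&entry.pool)
}

// Plan release: decrement the count and drop the entry at zero.
fn release(pools: &Mutex<HashMap<i64, PerTaskPool>>, task_attempt_id: i64) {
    let mut map = pools.lock().unwrap();
    if let Some(entry) = map.get_mut(&task_attempt_id) {
        entry.num_plans -= 1;
        if entry.num_plans == 0 {
            map.remove(&task_attempt_id);
        }
    }
}

fn main() {
    let pools = Mutex::new(HashMap::new());
    let a = acquire(&pools, 42);
    let b = acquire(&pools, 42);
    assert!(Arc::ptr_eq(&a, &b)); // both plans of task 42 share one pool
    release(&pools, 42);
    release(&pools, 42);
    assert!(pools.lock().unwrap().is_empty()); // last release drops the pool
}
```

Keying by task attempt id means concurrent tasks on one executor never compete inside the same pool, while multiple native plans within one task do share a budget.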

/// Accept serialized query plan and return the address of the native query plan.
@@ -138,6 +190,7 @@ pub unsafe extern "system" fn Java_org_apache_comet_Native_createPlan(
.map(String::as_str)
.unwrap_or("10")
.parse::<usize>()?;
let task_attempt_id: i64 = configs.get("task_attempt_id").unwrap().parse()?;

// Use multi-threaded tokio runtime to prevent blocking spawned tasks if any
let runtime = tokio::runtime::Builder::new_multi_thread()
@@ -160,13 +213,18 @@
let task_memory_manager =
Arc::new(jni_new_global_ref!(env, comet_task_memory_manager_obj)?);

let memory_pool_config = parse_memory_pool_config(&configs)?;
let memory_pool =
create_memory_pool(&memory_pool_config, task_memory_manager, task_attempt_id);

// We need to keep the session context alive. Some session state like temporary
// dictionaries are stored in session context. If it is dropped, the temporary
// dictionaries will be dropped as well.
let session = prepare_datafusion_session_context(&configs, task_memory_manager)?;
let session = prepare_datafusion_session_context(&configs, memory_pool)?;

let exec_context = Box::new(ExecutionContext {
id,
task_attempt_id,
spark_plan,
root_op: None,
scans: vec![],
Expand All @@ -178,6 +236,7 @@ pub unsafe extern "system" fn Java_org_apache_comet_Native_createPlan(
session_ctx: Arc::new(session),
debug_native,
explain_native,
memory_pool_config,
});

Ok(Box::into_raw(exec_context) as i64)
@@ -187,7 +246,7 @@ pub unsafe extern "system" fn Java_org_apache_comet_Native_createPlan(
/// Parse Comet configs and configure DataFusion session context.
fn prepare_datafusion_session_context(
conf: &HashMap<String, String>,
comet_task_memory_manager: Arc<GlobalRef>,
memory_pool: Option<Arc<dyn MemoryPool>>,
) -> CometResult<SessionContext> {
// Get the batch size from Comet JVM side
let batch_size = conf
@@ -199,26 +258,8 @@ fn prepare_datafusion_session_context(

let mut rt_config = RuntimeConfig::new().with_disk_manager(DiskManagerConfig::NewOs);

// Check if we are using unified memory manager integrated with Spark. Default to false if not
// set.
let use_unified_memory_manager = parse_bool(conf, "use_unified_memory_manager")?;

if use_unified_memory_manager {
// Set Comet memory pool for native
let memory_pool = CometMemoryPool::new(comet_task_memory_manager);
rt_config = rt_config.with_memory_pool(Arc::new(memory_pool));
} else {
// Use the memory pool from DF
if conf.contains_key("memory_limit") {
let memory_limit = conf.get("memory_limit").unwrap().parse::<usize>()?;
let memory_fraction = conf
.get("memory_fraction")
.ok_or(CometError::Internal(
"Config 'memory_fraction' is not specified from Comet JVM side".to_string(),
))?
.parse::<f64>()?;
rt_config = rt_config.with_memory_limit(memory_limit, memory_fraction)
}
if let Some(memory_pool) = memory_pool {
rt_config = rt_config.with_memory_pool(memory_pool);
}

// Get Datafusion configuration from Spark Execution context
@@ -250,6 +291,125 @@ fn prepare_datafusion_session_context(
Ok(session_ctx)
}

fn parse_memory_pool_config(conf: &HashMap<String, String>) -> CometResult<MemoryPoolConfig> {
// Check if we are using unified memory manager integrated with Spark. Default to false if not
// set.
let use_unified_memory_manager = parse_bool(conf, "use_unified_memory_manager")?;

let memory_pool_config = if use_unified_memory_manager {
MemoryPoolConfig::new(MemoryPoolType::Unified, 0)
} else {
// Use the memory pool from DF
if conf.contains_key("memory_limit") || conf.contains_key("memory_limit_per_task") {
let memory_pool_type = conf.get("memory_pool_type").unwrap();
let memory_limit = conf.get("memory_limit").unwrap().parse::<usize>()?;
let memory_fraction = conf
.get("memory_fraction")
.ok_or(CometError::Internal(
"Config 'memory_fraction' is not specified from Comet JVM side".to_string(),
))?
.parse::<f64>()?;
let pool_size = (memory_limit as f64 * memory_fraction) as usize;

let memory_limit_per_task = conf.get("memory_limit_per_task").unwrap().parse::<usize>();
let pool_size_per_task =
memory_limit_per_task.map(|v| (v as f64 * memory_fraction) as usize);
match memory_pool_type.as_str() {
"fair_spill_task_shared" => {
MemoryPoolConfig::new(MemoryPoolType::FairSpillTaskShared, pool_size_per_task?)
}
"greedy_task_shared" => {
MemoryPoolConfig::new(MemoryPoolType::GreedyTaskShared, pool_size_per_task?)
}
"fair_spill_global" => {
MemoryPoolConfig::new(MemoryPoolType::FairSpillGlobal, pool_size)
}
"greedy_global" => MemoryPoolConfig::new(MemoryPoolType::GreedyGlobal, pool_size),
"fair_spill" => {
MemoryPoolConfig::new(MemoryPoolType::FairSpill, pool_size_per_task?)
}
"greedy" => MemoryPoolConfig::new(MemoryPoolType::Greedy, pool_size_per_task?),
_ => {
return Err(CometError::Config(format!(
"Unsupported memory pool type: {}",
memory_pool_type
)))
}
}
} else {
MemoryPoolConfig::new(MemoryPoolType::Default, 0)
}
};
Ok(memory_pool_config)
}
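To make the sizing arithmetic concrete (illustrative values): with `memory_limit` = 4294967296 bytes (4 GiB) and `memory_fraction` = 0.7, the derived `pool_size` is 3006477107 bytes, roughly 2.8 GiB; the per-task size is computed the same way from `memory_limit_per_task`.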

fn create_memory_pool(
memory_pool_config: &MemoryPoolConfig,
comet_task_memory_manager: Arc<GlobalRef>,
task_attempt_id: i64,
) -> Option<Arc<dyn MemoryPool>> {
match memory_pool_config.pool_type {
MemoryPoolType::Unified => {
// Set Comet memory pool for native
let memory_pool = CometMemoryPool::new(comet_task_memory_manager);
Some(Arc::new(memory_pool))
}
MemoryPoolType::Greedy => Some(Arc::new(TrackConsumersPool::new(
GreedyMemoryPool::new(memory_pool_config.pool_size),
NonZeroUsize::new(10).unwrap(),
))),
MemoryPoolType::FairSpill => Some(Arc::new(TrackConsumersPool::new(
FairSpillPool::new(memory_pool_config.pool_size),
NonZeroUsize::new(10).unwrap(),
))),
MemoryPoolType::GreedyGlobal => {
static GLOBAL_MEMORY_POOL_GREEDY: OnceCell<Arc<dyn MemoryPool>> = OnceCell::new();
let memory_pool = GLOBAL_MEMORY_POOL_GREEDY.get_or_init(|| {
Arc::new(TrackConsumersPool::new(
GreedyMemoryPool::new(memory_pool_config.pool_size),
NonZeroUsize::new(10).unwrap(),
))
});
Some(Arc::clone(memory_pool))
}
MemoryPoolType::FairSpillGlobal => {
static GLOBAL_MEMORY_POOL_FAIR: OnceCell<Arc<dyn MemoryPool>> = OnceCell::new();
let memory_pool = GLOBAL_MEMORY_POOL_FAIR.get_or_init(|| {
Arc::new(TrackConsumersPool::new(
FairSpillPool::new(memory_pool_config.pool_size),
NonZeroUsize::new(10).unwrap(),
))
});
Some(Arc::clone(memory_pool))
}
MemoryPoolType::GreedyTaskShared | MemoryPoolType::FairSpillTaskShared => {
let mut memory_pool_map = TASK_SHARED_MEMORY_POOLS.lock().unwrap();
let per_task_memory_pool =
memory_pool_map.entry(task_attempt_id).or_insert_with(|| {
let pool: Arc<dyn MemoryPool> =
if memory_pool_config.pool_type == MemoryPoolType::GreedyTaskShared {
Arc::new(TrackConsumersPool::new(
GreedyMemoryPool::new(memory_pool_config.pool_size),
NonZeroUsize::new(10).unwrap(),
))
} else {
Arc::new(TrackConsumersPool::new(
FairSpillPool::new(memory_pool_config.pool_size),
NonZeroUsize::new(10).unwrap(),
))
};
PerTaskMemoryPool::new(pool)
});
per_task_memory_pool.num_plans += 1;
Some(Arc::clone(&per_task_memory_pool.memory_pool))
}
MemoryPoolType::Default => {
// Use the memory pool from DF
None
}
}
}
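Every non-unified pool above is wrapped in DataFusion's `TrackConsumersPool`, which decorates out-of-memory errors with the largest registered consumers (the `NonZeroUsize::new(10)` in the code above caps how many are named). A minimal sketch of that behavior, assuming an illustrative 1 MiB limit and consumer names:

```rust
use std::{num::NonZeroUsize, sync::Arc};

use datafusion_execution::memory_pool::{
    GreedyMemoryPool, MemoryConsumer, MemoryPool, TrackConsumersPool,
};

fn main() {
    // A 1 MiB greedy pool that reports the top 3 consumers on failure.
    let pool: Arc<dyn MemoryPool> = Arc::new(TrackConsumersPool::new(
        GreedyMemoryPool::new(1024 * 1024),
        NonZeroUsize::new(3).unwrap(),
    ));

    // Register two consumers against the shared pool.
    let mut sort = MemoryConsumer::new("sort").register(&pool);
    let mut join = MemoryConsumer::new("join").register(&pool);

    sort.try_grow(768 * 1024).unwrap(); // fits within the limit
    let err = join.try_grow(512 * 1024).unwrap_err(); // exceeds it
    // The error message names the largest consumers, e.g. "sort consumed 786432 bytes".
    println!("{err}");
}
```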

fn parse_bool(conf: &HashMap<String, String>, name: &str) -> CometResult<bool> {
conf.get(name)
.map(String::as_str)
@@ -435,6 +595,20 @@ pub extern "system" fn Java_org_apache_comet_Native_releasePlan(
) {
try_unwrap_or_throw(&e, |_| unsafe {
let execution_context = get_execution_context(exec_context);
        if execution_context.memory_pool_config.pool_type == MemoryPoolType::GreedyTaskShared
            || execution_context.memory_pool_config.pool_type == MemoryPoolType::FairSpillTaskShared
        {
// Decrement the number of native plans using the per-task shared memory pool, and
// remove the memory pool if the released native plan is the last native plan using it.
let task_attempt_id = execution_context.task_attempt_id;
let mut memory_pool_map = TASK_SHARED_MEMORY_POOLS.lock().unwrap();
if let Some(per_task_memory_pool) = memory_pool_map.get_mut(&task_attempt_id) {
per_task_memory_pool.num_plans -= 1;
if per_task_memory_pool.num_plans == 0 {
// Drop the memory pool from the per-task memory pool map if there are no
// more native plans using it.
memory_pool_map.remove(&task_attempt_id);
}
}
}
let _: Box<ExecutionContext> = Box::from_raw(execution_context);
Ok(())
})
Expand Down