Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: implement datafusion TableProviderFactory #162

Merged
merged 9 commits into from
Oct 14, 2024
168 changes: 141 additions & 27 deletions crates/datafusion/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,14 @@
*/

use std::any::Any;
use std::collections::HashMap;
use std::fmt::Debug;
use std::sync::Arc;
use std::thread;

use arrow_schema::SchemaRef;
use async_trait::async_trait;
use datafusion::catalog::Session;
use datafusion::catalog::{Session, TableProviderFactory};
use datafusion::datasource::listing::PartitionedFile;
use datafusion::datasource::object_store::ObjectStoreUrl;
use datafusion::datasource::physical_plan::parquet::ParquetExecBuilder;
Expand All @@ -35,7 +36,7 @@
use datafusion_common::DFSchema;
use datafusion_common::DataFusionError::Execution;
use datafusion_common::Result;
use datafusion_expr::{Expr, TableType};
use datafusion_expr::{CreateExternalTable, Expr, TableType};
use datafusion_physical_expr::create_physical_expr;

use hudi_core::config::read::HudiReadConfig::InputPartitions;
Expand Down Expand Up @@ -150,14 +151,63 @@
}
}

/// Factory that builds Hudi table providers for DataFusion.
///
/// The factory carries no state of its own; everything needed to create a
/// table is taken from the session and the `CREATE EXTERNAL TABLE` command
/// passed to `create`.
#[derive(Debug)]
pub struct HudiTableFactory {}

// `Default` support for the factory: the default value is whatever
// `HudiTableFactory::new` produces.
impl Default for HudiTableFactory {
/// Returns a factory built via [`HudiTableFactory::new`].
fn default() -> Self {
Self::new()

Check warning on line 158 in crates/datafusion/src/lib.rs

View check run for this annotation

Codecov / codecov/patch

crates/datafusion/src/lib.rs#L157-L158

Added lines #L157 - L158 were not covered by tests
}
}

impl HudiTableFactory {
pub fn new() -> Self {
Self {}
}

fn resolve_options(
state: &dyn Session,
cmd: &CreateExternalTable,
) -> Result<HashMap<String, String>> {
let mut options: HashMap<_, _> = state
.config_options()
.entries()
.iter()
.filter_map(|e| {
let value = e.value.as_ref().filter(|v| !v.is_empty())?;
Some((e.key.clone(), value.clone()))
})
.collect();

// options from the command take precedence
options.extend(cmd.options.iter().map(|(k, v)| (k.clone(), v.clone())));

Ok(options)
}
}

#[async_trait]
impl TableProviderFactory for HudiTableFactory {
    /// Creates a table provider for the table described by `cmd`.
    ///
    /// The table location is taken from `cmd.location` and the options come
    /// from `HudiTableFactory::resolve_options`. Any error from resolving the
    /// options or from constructing the `HudiDataSource` is propagated.
    async fn create(
        &self,
        state: &dyn Session,
        cmd: &CreateExternalTable,
    ) -> Result<Arc<dyn TableProvider>> {
        let base_uri: &str = &cmd.location;
        let resolved = Self::resolve_options(state, cmd)?;
        let source = HudiDataSource::new_with_options(base_uri, resolved).await?;
        let provider: Arc<dyn TableProvider> = Arc::new(source);
        Ok(provider)
    }
}

#[cfg(test)]
mod tests {
use super::*;
use datafusion::execution::session_state::SessionStateBuilder;
use datafusion::prelude::{SessionConfig, SessionContext};
use datafusion_common::{DataFusionError, ScalarValue};
use std::fs::canonicalize;
use std::path::Path;
use std::sync::Arc;

use datafusion::prelude::{SessionConfig, SessionContext};
use datafusion_common::ScalarValue;
use url::Url;

use hudi_core::config::read::HudiReadConfig::InputPartitions;
Expand All @@ -170,6 +220,7 @@
use utils::{get_bool_column, get_i32_column, get_str_column};

use crate::HudiDataSource;
use crate::HudiTableFactory;

#[tokio::test]
async fn get_default_input_partitions() {
Expand All @@ -180,22 +231,81 @@
assert_eq!(hudi.get_input_partitions(), 0)
}

async fn prepare_session_context(
async fn register_test_table_with_session<I, K, V>(
test_table: &TestTable,
options: Vec<(&str, &str)>,
) -> SessionContext {
options: I,
use_sql: bool,
) -> Result<SessionContext, DataFusionError>
where
I: IntoIterator<Item = (K, V)>,
K: AsRef<str>,
V: Into<String>,
{
let ctx = create_test_session().await;
if use_sql {
let create_table_sql = format!(
"CREATE EXTERNAL TABLE {} STORED AS HUDI LOCATION '{}' {}",
test_table.as_ref(),
test_table.path(),
concat_as_sql_options(options)
);
ctx.sql(create_table_sql.as_str()).await?;
} else {
let base_url = test_table.url();
let hudi = HudiDataSource::new_with_options(base_url.as_str(), options).await?;
ctx.register_table(test_table.as_ref(), Arc::new(hudi))?;
}
Ok(ctx)
}

async fn create_test_session() -> SessionContext {
let config = SessionConfig::new().set(
"datafusion.sql_parser.enable_ident_normalization",
&ScalarValue::from(false),
);
let ctx = SessionContext::new_with_config(config);
let base_url = test_table.url();
let hudi = HudiDataSource::new_with_options(base_url.as_str(), options)
.await
.unwrap();
ctx.register_table(test_table.as_ref(), Arc::new(hudi))
.unwrap();
ctx
let mut session_state = SessionStateBuilder::new()
.with_default_features()
.with_config(config)
.build();
session_state
.table_factories_mut()
.insert("HUDI".to_string(), Arc::new(HudiTableFactory::new()));

SessionContext::new_with_state(session_state)
}

/// Renders key/value pairs as a SQL `OPTIONS (...)` clause.
///
/// Each pair becomes `'key' 'value'`, and pairs are joined with `", "`.
/// Single quotes inside a key or value are doubled (`'` becomes `''`) so the
/// generated string literals stay well-formed. Returns an empty string when
/// `options` yields no pairs, so the clause can be appended unconditionally.
fn concat_as_sql_options<I, K, V>(options: I) -> String
where
    I: IntoIterator<Item = (K, V)>,
    K: AsRef<str>,
    V: Into<String>,
{
    let kv_pairs: Vec<String> = options
        .into_iter()
        .map(|(k, v)| {
            let key: &str = k.as_ref();
            let value: String = v.into();
            // Standard SQL escapes a quote inside a literal by doubling it.
            format!(
                "'{}' '{}'",
                key.replace('\'', "''"),
                value.replace('\'', "''")
            )
        })
        .collect();

    if kv_pairs.is_empty() {
        String::new()
    } else {
        format!("OPTIONS ({})", kv_pairs.join(", "))
    }
}

/// Creating an external table with a format that has no registered factory
/// must return an error.
#[tokio::test]
async fn test_create_table_with_unknown_format() {
    let table = V6Nonpartitioned;
    let statement = format!(
        "CREATE EXTERNAL TABLE {} STORED AS {} LOCATION '{}'",
        table.as_ref(),
        "UNKNOWN_FORMAT",
        table.path()
    );

    let session = create_test_session().await;
    let outcome = session.sql(&statement).await;
    assert!(outcome.is_err());
}

async fn verify_plan(
Expand Down Expand Up @@ -236,16 +346,18 @@

#[tokio::test]
async fn datafusion_read_hudi_table() {
for (test_table, planned_input_partitions) in &[
(V6ComplexkeygenHivestyle, 2),
(V6Nonpartitioned, 1),
(V6SimplekeygenNonhivestyle, 2),
(V6SimplekeygenHivestyleNoMetafields, 2),
(V6TimebasedkeygenNonhivestyle, 2),
for (test_table, use_sql, planned_input_partitions) in &[
(V6ComplexkeygenHivestyle, true, 2),
(V6Nonpartitioned, true, 1),
(V6SimplekeygenNonhivestyle, false, 2),
(V6SimplekeygenHivestyleNoMetafields, true, 2),
(V6TimebasedkeygenNonhivestyle, false, 2),
] {
println!(">>> testing for {}", test_table.as_ref());
let options = vec![(InputPartitions.as_ref(), "2")];
let ctx = prepare_session_context(test_table, options).await;
let options = [(InputPartitions, "2")];
let ctx = register_test_table_with_session(test_table, options, *use_sql)
.await
.unwrap();

let sql = format!(
r#"
Expand Down Expand Up @@ -275,12 +387,14 @@

#[tokio::test]
async fn datafusion_read_hudi_table_with_replacecommits() {
for (test_table, planned_input_partitions) in
&[(V6SimplekeygenNonhivestyleOverwritetable, 1)]
for (test_table, use_sql, planned_input_partitions) in
&[(V6SimplekeygenNonhivestyleOverwritetable, true, 1)]
{
println!(">>> testing for {}", test_table.as_ref());
let ctx =
prepare_session_context(test_table, vec![(InputPartitions.as_ref(), "2")]).await;
register_test_table_with_session(test_table, [(InputPartitions, "2")], *use_sql)
.await
.unwrap();

let sql = format!(
r#"
Expand Down
Loading