diff --git a/datafusion/core/Cargo.toml b/datafusion/core/Cargo.toml index f3834c4cd9c5..8f40bac70153 100644 --- a/datafusion/core/Cargo.toml +++ b/datafusion/core/Cargo.toml @@ -110,6 +110,7 @@ doc-comment = "0.3" env_logger = "0.10" parquet-test-utils = { path = "../../parquet-test-utils" } rstest = "0.16.0" +sqlparser = "0.27" test-utils = { path = "../../test-utils" } [[bench]] diff --git a/datafusion/core/src/datasource/memory.rs b/datafusion/core/src/datasource/memory.rs index 632ef8d287d2..a80c4b94d999 100644 --- a/datafusion/core/src/datasource/memory.rs +++ b/datafusion/core/src/datasource/memory.rs @@ -26,6 +26,7 @@ use std::sync::Arc; use arrow::datatypes::SchemaRef; use arrow::record_batch::RecordBatch; use async_trait::async_trait; +use parking_lot::RwLock; use crate::datasource::{TableProvider, TableType}; use crate::error::{DataFusionError, Result}; @@ -40,7 +41,7 @@ use crate::physical_plan::{repartition::RepartitionExec, Partitioning}; #[derive(Debug)] pub struct MemTable { schema: SchemaRef, - batches: Vec>, + batches: Arc>>>, } impl MemTable { @@ -53,7 +54,7 @@ impl MemTable { { Ok(Self { schema, - batches: partitions, + batches: Arc::new(RwLock::new(partitions)), }) } else { Err(DataFusionError::Plan( @@ -117,6 +118,11 @@ impl MemTable { } MemTable::try_new(schema.clone(), data) } + + /// Get record batches in MemTable + pub fn get_batches(&self) -> Arc>>> { + self.batches.clone() + } } #[async_trait] @@ -140,8 +146,9 @@ impl TableProvider for MemTable { _filters: &[Expr], _limit: Option, ) -> Result> { + let batches = self.batches.read(); Ok(Arc::new(MemoryExec::try_new( - &self.batches.clone(), + &(*batches).clone(), self.schema(), projection.cloned(), )?)) diff --git a/datafusion/core/src/execution/context.rs b/datafusion/core/src/execution/context.rs index 6b6a4fb1f6e6..9115b2efbb47 100644 --- a/datafusion/core/src/execution/context.rs +++ b/datafusion/core/src/execution/context.rs @@ -958,6 +958,22 @@ impl SessionContext { } } + /// Return a [`TabelProvider`] for the specified table. + pub fn table_provider<'a>( + &self, + table_ref: impl Into>, + ) -> Result> { + let table_ref = table_ref.into(); + let schema = self.state.read().schema_for_ref(table_ref)?; + match schema.table(table_ref.table()) { + Some(ref provider) => Ok(Arc::clone(provider)), + _ => Err(DataFusionError::Plan(format!( + "No table named '{}'", + table_ref.table() + ))), + } + } + /// Returns the set of available tables in the default catalog and /// schema. /// diff --git a/datafusion/core/tests/sqllogictests/src/error.rs b/datafusion/core/tests/sqllogictests/src/error.rs new file mode 100644 index 000000000000..8ac4821413ff --- /dev/null +++ b/datafusion/core/tests/sqllogictests/src/error.rs @@ -0,0 +1,83 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use datafusion_common::DataFusionError; +use sqllogictest::TestError; +use sqlparser::parser::ParserError; +use std::error; +use std::fmt::{Display, Formatter}; + +pub type Result = std::result::Result; + +/// DataFusion sql-logicaltest error +#[derive(Debug)] +pub enum DFSqlLogicTestError { + /// Error from sqllogictest-rs + SqlLogicTest(TestError), + /// Error from datafusion + DataFusion(DataFusionError), + /// Error returned when SQL is syntactically incorrect. + Sql(ParserError), + /// Error returned on a branch that we know it is possible + /// but to which we still have no implementation for. + /// Often, these errors are tracked in our issue tracker. + NotImplemented(String), + /// Error returned from DFSqlLogicTest inner + Internal(String), +} + +impl From for DFSqlLogicTestError { + fn from(value: TestError) -> Self { + DFSqlLogicTestError::SqlLogicTest(value) + } +} + +impl From for DFSqlLogicTestError { + fn from(value: DataFusionError) -> Self { + DFSqlLogicTestError::DataFusion(value) + } +} + +impl From for DFSqlLogicTestError { + fn from(value: ParserError) -> Self { + DFSqlLogicTestError::Sql(value) + } +} + +impl Display for DFSqlLogicTestError { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + match self { + DFSqlLogicTestError::SqlLogicTest(error) => write!( + f, + "SqlLogicTest error(from sqllogictest-rs crate): {}", + error + ), + DFSqlLogicTestError::DataFusion(error) => { + write!(f, "DataFusion error: {}", error) + } + DFSqlLogicTestError::Sql(error) => write!(f, "SQL Parser error: {}", error), + DFSqlLogicTestError::NotImplemented(error) => { + write!(f, "This feature is not implemented yet: {}", error) + } + DFSqlLogicTestError::Internal(error) => { + write!(f, "Internal error: {}", error) + } + } + } +} + +impl error::Error for DFSqlLogicTestError {} diff --git a/datafusion/core/tests/sqllogictests/src/insert/mod.rs b/datafusion/core/tests/sqllogictests/src/insert/mod.rs new file mode 100644 index 000000000000..100fa1184e7d --- /dev/null +++ b/datafusion/core/tests/sqllogictests/src/insert/mod.rs @@ -0,0 +1,96 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +mod util; + +use crate::error::{DFSqlLogicTestError, Result}; +use crate::insert::util::LogicTestContextProvider; +use datafusion::datasource::MemTable; +use datafusion::prelude::SessionContext; +use datafusion_common::{DFSchema, DataFusionError}; +use datafusion_expr::Expr as DFExpr; +use datafusion_sql::parser::{DFParser, Statement}; +use datafusion_sql::planner::SqlToRel; +use sqlparser::ast::{Expr, SetExpr, Statement as SQLStatement}; +use std::collections::HashMap; + +pub async fn insert(ctx: &SessionContext, sql: String) -> Result { + // First, use sqlparser to get table name and insert values + let mut table_name = "".to_string(); + let mut insert_values: Vec> = vec![]; + if let Statement::Statement(statement) = &DFParser::parse_sql(&sql)?[0] { + if let SQLStatement::Insert { + table_name: name, + source, + .. + } = &**statement + { + // Todo: check columns match table schema + table_name = name.to_string(); + match &*source.body { + SetExpr::Values(values) => { + insert_values = values.0.clone(); + } + _ => { + return Err(DFSqlLogicTestError::NotImplemented( + "Only support insert values".to_string(), + )); + } + } + } + } else { + return Err(DFSqlLogicTestError::Internal(format!( + "{:?} not an insert statement", + sql + ))); + } + + // Second, get table by table name + // Here we assume table must be in memory table. + let table_provider = ctx.table_provider(table_name.as_str())?; + let table_batches = table_provider + .as_any() + .downcast_ref::() + .ok_or_else(|| { + DFSqlLogicTestError::NotImplemented( + "only support use memory table in logictest".to_string(), + ) + })? + .get_batches(); + + // Third, transfer insert values to `RecordBatch` + // Attention: schema info can be ignored. (insert values don't contain schema info) + let sql_to_rel = SqlToRel::new(&LogicTestContextProvider {}); + let mut insert_batches = Vec::with_capacity(insert_values.len()); + for row in insert_values.into_iter() { + let logical_exprs = row + .into_iter() + .map(|expr| { + sql_to_rel.sql_to_rex(expr, &DFSchema::empty(), &mut HashMap::new()) + }) + .collect::, DataFusionError>>()?; + // Directly use `select` to get `RecordBatch` + let dataframe = ctx.read_empty()?; + insert_batches.push(dataframe.select(logical_exprs)?.collect().await?) + } + + // Final, append the `RecordBatch` to memtable's batches + let mut table_batches = table_batches.write(); + table_batches.extend(insert_batches); + + Ok("".to_string()) +} diff --git a/datafusion/core/tests/sqllogictests/src/insert/util.rs b/datafusion/core/tests/sqllogictests/src/insert/util.rs new file mode 100644 index 000000000000..03dbb72995ff --- /dev/null +++ b/datafusion/core/tests/sqllogictests/src/insert/util.rs @@ -0,0 +1,50 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use arrow::datatypes::DataType; +use datafusion_common::{ScalarValue, TableReference}; +use datafusion_expr::{AggregateUDF, ScalarUDF, TableSource}; +use datafusion_sql::planner::ContextProvider; +use std::sync::Arc; + +pub struct LogicTestContextProvider {} + +// Only a mock, don't need to implement +impl ContextProvider for LogicTestContextProvider { + fn get_table_provider( + &self, + _name: TableReference, + ) -> datafusion_common::Result> { + todo!() + } + + fn get_function_meta(&self, _name: &str) -> Option> { + todo!() + } + + fn get_aggregate_meta(&self, _name: &str) -> Option> { + todo!() + } + + fn get_variable_type(&self, _variable_names: &[String]) -> Option { + todo!() + } + + fn get_config_option(&self, _variable: &str) -> Option { + todo!() + } +} diff --git a/datafusion/core/tests/sqllogictests/src/main.rs b/datafusion/core/tests/sqllogictests/src/main.rs index fc27773c010d..0efac489f3e6 100644 --- a/datafusion/core/tests/sqllogictests/src/main.rs +++ b/datafusion/core/tests/sqllogictests/src/main.rs @@ -22,9 +22,11 @@ use datafusion::prelude::{SessionConfig, SessionContext}; use std::path::Path; use std::time::Duration; -use sqllogictest::TestError; -pub type Result = std::result::Result; +use crate::error::{DFSqlLogicTestError, Result}; +use crate::insert::insert; +mod error; +mod insert; mod setup; mod utils; @@ -37,7 +39,7 @@ pub struct DataFusion { #[async_trait] impl sqllogictest::AsyncDB for DataFusion { - type Error = TestError; + type Error = DFSqlLogicTestError; async fn run(&mut self, sql: &str) -> Result { println!("[{}] Running query: \"{}\"", self.file_name, sql); @@ -138,7 +140,14 @@ fn format_batches(batches: &[RecordBatch]) -> Result { } async fn run_query(ctx: &SessionContext, sql: impl Into) -> Result { - let df = ctx.sql(&sql.into()).await.unwrap(); + let sql = sql.into(); + // Check if the sql is `insert` + if sql.trim_start().to_lowercase().starts_with("insert") { + // Process the insert statement + insert(ctx, sql).await?; + return Ok("".to_string()); + } + let df = ctx.sql(sql.as_str()).await.unwrap(); let results: Vec = df.collect().await.unwrap(); let formatted_batches = format_batches(&results)?; Ok(formatted_batches) diff --git a/datafusion/core/tests/sqllogictests/test_files/insert.slt b/datafusion/core/tests/sqllogictests/test_files/insert.slt new file mode 100644 index 000000000000..0927b3777ddc --- /dev/null +++ b/datafusion/core/tests/sqllogictests/test_files/insert.slt @@ -0,0 +1,50 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +statement ok +CREATE TABLE users AS VALUES(1,2),(2,3); + +query II rowsort +select * from users; +---- +1 2 +2 3 + +statement ok +insert into users values(2, 4); + +query II rowsort +select * from users; +---- +1 2 +2 3 +2 4 + +statement ok +insert into users values(1 + 10, 20); + +query II rowsort +select * from users; +---- +1 2 +2 3 +2 4 +11 20 + +# Test insert into a undefined table +statement error +insert into user values(1, 20);