From 3b95e3b9df82ad33c50c03bdfab3d25d0243f0fc Mon Sep 17 00:00:00 2001 From: "aleksei.p" Date: Wed, 2 Oct 2024 16:49:43 +0200 Subject: [PATCH] update --- src/ast/ddl.rs | 48 ++++-- src/ast/mod.rs | 6 +- src/parser/mod.rs | 17 +- tests/sqlparser_mssql.rs | 15 +- tests/sqlparser_snowflake.rs | 304 ++++++++++++++++++++++++++++++++--- 5 files changed, 333 insertions(+), 57 deletions(-) diff --git a/src/ast/ddl.rs b/src/ast/ddl.rs index 051181493..801c6a6b4 100644 --- a/src/ast/ddl.rs +++ b/src/ast/ddl.rs @@ -1067,25 +1067,24 @@ pub enum Identity { #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] #[cfg_attr(feature = "visitor", derive(Visit, VisitMut))] pub struct IdentityProperty { - pub parameters: Option, + pub parameters: Option, pub order: Option, } #[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)] #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] #[cfg_attr(feature = "visitor", derive(Visit, VisitMut))] -pub struct IdentityParameters { - pub format: IdentityParametersFormat, - pub seed: Expr, - pub increment: Expr, +pub enum IdentityFormat { + FunctionCall(IdentityParameters), + StartAndIncrement(IdentityParameters), } #[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)] #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] #[cfg_attr(feature = "visitor", derive(Visit, VisitMut))] -pub enum IdentityParametersFormat { - FunctionCall, - StartIncrement, +pub struct IdentityParameters { + pub seed: Expr, + pub increment: Expr, } #[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)] @@ -1113,14 +1112,18 @@ impl fmt::Display for Identity { } } -impl fmt::Display for IdentityParameters { +impl fmt::Display for IdentityFormat { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - match self.format { - IdentityParametersFormat::FunctionCall => { - write!(f, "({}, {})", self.seed, self.increment) + match self { + IdentityFormat::FunctionCall(parameters) => { + write!(f, "({}, {})", parameters.seed, parameters.increment) } - IdentityParametersFormat::StartIncrement => { - write!(f, " START {} INCREMENT {}", self.seed, self.increment) + IdentityFormat::StartAndIncrement(parameters) => { + write!( + f, + " START {} INCREMENT {}", + parameters.seed, parameters.increment + ) } } } @@ -1157,7 +1160,7 @@ impl fmt::Display for ColumnPolicy { ColumnPolicy::MaskingPolicy(property) => ("MASKING POLICY", property), ColumnPolicy::ProjectionPolicy(property) => ("PROJECTION POLICY", property), }; - write!(f, "{command} {}", property.policy_name)?; + write!(f, "WITH {command} {}", property.policy_name)?; if let Some(using_columns) = &property.using_columns { write!(f, "USING ({})", display_comma_separated(using_columns))?; } @@ -1243,7 +1246,20 @@ pub enum ColumnOption { /// [MS SQL Server]: https://learn.microsoft.com/en-us/sql/t-sql/statements/create-table-transact-sql-identity-property /// [Snowflake]: https://docs.snowflake.com/en/sql-reference/sql/create-table Identity(Identity), + /// Snowflake specific: an option of specifying security masking or projection policy to set on a column. + /// Syntax: + /// ```sql + /// [ WITH ] MASKING POLICY [ USING ( , , ... ) ] + /// [ WITH ] PROJECTION POLICY + /// ``` + /// [Snowflake]: https://docs.snowflake.com/en/sql-reference/sql/create-table Policy(ColumnPolicy), + /// Snowflake specific: Specifies the tag name and the tag string value. + /// Syntax: + /// ```sql + /// [ WITH ] TAG ( = '' [ , = '' , ... ] ) + /// ``` + /// [Snowflake]: https://docs.snowflake.com/en/sql-reference/sql/create-table Tags(Vec), } @@ -1353,7 +1369,7 @@ impl fmt::Display for ColumnOption { write!(f, "{parameters}") } Tags(tags) => { - write!(f, "{}", display_comma_separated(tags)) + write!(f, "WITH TAG ({})", display_comma_separated(tags)) } } } diff --git a/src/ast/mod.rs b/src/ast/mod.rs index 2f5a316da..a328609e3 100644 --- a/src/ast/mod.rs +++ b/src/ast/mod.rs @@ -41,9 +41,9 @@ pub use self::dcl::{AlterRoleOperation, ResetConfig, RoleOption, SetConfigValue, pub use self::ddl::{ AlterColumnOperation, AlterIndexOperation, AlterTableOperation, ClusteredBy, ColumnDef, ColumnOption, ColumnOptionDef, ColumnPolicy, ColumnPolicyProperty, ConstraintCharacteristics, - Deduplicate, DeferrableInitial, GeneratedAs, GeneratedExpressionMode, Identity, IdentityOrder, - IdentityParameters, IdentityParametersFormat, IdentityProperty, IndexOption, IndexType, - KeyOrIndexDisplay, Owner, Partition, ProcedureParam, ReferentialAction, TableConstraint, + Deduplicate, DeferrableInitial, GeneratedAs, GeneratedExpressionMode, Identity, IdentityFormat, + IdentityOrder, IdentityParameters, IdentityProperty, IndexOption, IndexType, KeyOrIndexDisplay, + Owner, Partition, ProcedureParam, ReferentialAction, TableConstraint, UserDefinedTypeCompositeAttributeDef, UserDefinedTypeRepresentation, ViewColumnDef, }; pub use self::dml::{CreateIndex, CreateTable, Delete, Insert}; diff --git a/src/parser/mod.rs b/src/parser/mod.rs index 3b262373b..580ed6e5b 100644 --- a/src/parser/mod.rs +++ b/src/parser/mod.rs @@ -6190,11 +6190,10 @@ impl<'a> Parser<'a> { let increment = self.parse_number()?; self.expect_token(&Token::RParen)?; - Some(IdentityParameters { - format: IdentityParametersFormat::FunctionCall, + Some(IdentityFormat::FunctionCall(IdentityParameters { seed, increment, - }) + })) } else { None }; @@ -6246,21 +6245,19 @@ impl<'a> Parser<'a> { let increment = self.parse_number()?; self.expect_token(&Token::RParen)?; - Some(IdentityParameters { - format: IdentityParametersFormat::FunctionCall, + Some(IdentityFormat::FunctionCall(IdentityParameters { seed, increment, - }) + })) } else if self.parse_keyword(Keyword::START) { let seed = self.parse_number()?; self.expect_keyword(Keyword::INCREMENT)?; let increment = self.parse_number()?; - Some(IdentityParameters { - format: IdentityParametersFormat::StartIncrement, + Some(IdentityFormat::StartAndIncrement(IdentityParameters { seed, increment, - }) + })) } else { None }; @@ -6300,7 +6297,7 @@ impl<'a> Parser<'a> { let policy_name = self.parse_identifier(false)?; let using_columns = if self.parse_keyword(Keyword::USING) { self.expect_token(&Token::LParen)?; - let columns = self.parse_comma_separated(Self::parse_identifier)?; + let columns = self.parse_comma_separated(|p| p.parse_identifier(false))?; self.expect_token(&Token::RParen)?; Some(columns) } else { diff --git a/tests/sqlparser_mssql.rs b/tests/sqlparser_mssql.rs index c0671b536..a2a45dfdb 100644 --- a/tests/sqlparser_mssql.rs +++ b/tests/sqlparser_mssql.rs @@ -938,19 +938,16 @@ fn parse_create_table_with_identity_column() { ColumnOptionDef { name: None, #[cfg(not(feature = "bigdecimal"))] - option: ColumnOption::Identity(IdentityProperty { - format: IdentityPropertyCommand::Identity, - parameters: Some(IdentityParameters { - format: IdentityParametersFormat::FunctionCall, + option: ColumnOption::Identity(Identity::Identity(IdentityProperty { + parameters: Some(IdentityFormat::FunctionCall(IdentityParameters { seed: Expr::Value(Value::Number("1".to_string(), false)), increment: Expr::Value(Value::Number("1".to_string(), false)), - }), + })), order: None, - }), + })), #[cfg(feature = "bigdecimal")] option: ColumnOption::Identity(Identity::Identity(IdentityProperty { - parameters: Some(IdentityParameters { - format: IdentityParametersFormat::FunctionCall, + parameters: Some(IdentityFormat::FunctionCall(IdentityParameters { seed: Expr::Value(Value::Number( bigdecimal::BigDecimal::from(1), false, @@ -959,7 +956,7 @@ fn parse_create_table_with_identity_column() { bigdecimal::BigDecimal::from(1), false, )), - }), + })), order: None, })), }, diff --git a/tests/sqlparser_snowflake.rs b/tests/sqlparser_snowflake.rs index d3632e7a6..727e7d100 100644 --- a/tests/sqlparser_snowflake.rs +++ b/tests/sqlparser_snowflake.rs @@ -530,34 +530,300 @@ fn test_snowflake_create_table_with_autoincrement_columns() { let sql = concat!( "CREATE TABLE my_table (", "a INT AUTOINCREMENT ORDER, ", - "b INT AUTOINCREMENT(100, -1) NOORDER, ", + "b INT AUTOINCREMENT(100, 1) NOORDER, ", "c INT IDENTITY, ", - "d INT IDENTITY START 100 INCREMENT -1 ORDER, ", - "a INT AUTOINCREMENT ORDER NOT NULL", + "d INT IDENTITY START 100 INCREMENT 1 ORDER", ")" ); - snowflake().verified_stmt(sql); + // it is a snowflake specific options (AUTOINCREMENT/IDENTITY) + match snowflake().verified_stmt(sql) { + Statement::CreateTable(CreateTable { columns, .. }) => { + assert_eq!( + columns, + vec![ + ColumnDef { + name: "a".into(), + data_type: DataType::Int(None), + collation: None, + options: vec![ColumnOptionDef { + name: None, + option: ColumnOption::Identity(Identity::Autoincrement( + IdentityProperty { + parameters: None, + order: Some(IdentityOrder::Order), + } + )) + }] + }, + ColumnDef { + name: "b".into(), + data_type: DataType::Int(None), + collation: None, + options: vec![ColumnOptionDef { + name: None, + option: ColumnOption::Identity(Identity::Autoincrement( + IdentityProperty { + #[cfg(not(feature = "bigdecimal"))] + parameters: Some(IdentityFormat::FunctionCall( + IdentityParameters { + seed: Expr::Value(Value::Number( + "100".to_string(), + false + )), + increment: Expr::Value(Value::Number( + "1".to_string(), + false + )), + } + )), + #[cfg(feature = "bigdecimal")] + parameters: Some(IdentityFormat::FunctionCall( + IdentityParameters { + seed: Expr::Value(Value::Number( + bigdecimal::BigDecimal::from(100), + false, + )), + increment: Expr::Value(Value::Number( + bigdecimal::BigDecimal::from(1), + false, + )), + } + )), + order: Some(IdentityOrder::Noorder), + } + )) + }] + }, + ColumnDef { + name: "c".into(), + data_type: DataType::Int(None), + collation: None, + options: vec![ColumnOptionDef { + name: None, + option: ColumnOption::Identity(Identity::Identity(IdentityProperty { + parameters: None, + order: None, + })) + }] + }, + ColumnDef { + name: "d".into(), + data_type: DataType::Int(None), + collation: None, + options: vec![ColumnOptionDef { + name: None, + option: ColumnOption::Identity(Identity::Identity(IdentityProperty { + #[cfg(not(feature = "bigdecimal"))] + parameters: Some(IdentityFormat::StartAndIncrement( + IdentityParameters { + seed: Expr::Value(Value::Number("100".to_string(), false)), + increment: Expr::Value(Value::Number( + "1".to_string(), + false + )), + } + )), + #[cfg(feature = "bigdecimal")] + parameters: Some(IdentityFormat::StartAndIncrement( + IdentityParameters { + seed: Expr::Value(Value::Number( + bigdecimal::BigDecimal::from(100), + false, + )), + increment: Expr::Value(Value::Number( + bigdecimal::BigDecimal::from(1), + false, + )), + } + )), + order: Some(IdentityOrder::Order), + })) + }] + }, + ] + ); + } + _ => unreachable!(), + } } #[test] -fn test_snowflake_create_table_with_collated_columns() { - snowflake().verified_stmt("CREATE TABLE my_table (a TEXT COLLATE 'de_DE')"); +fn test_snowflake_create_table_with_collated_column() { + match snowflake_and_generic().verified_stmt("CREATE TABLE my_table (a TEXT COLLATE 'de_DE')") { + Statement::CreateTable(CreateTable { columns, .. }) => { + assert_eq!( + columns, + vec![ColumnDef { + name: "a".into(), + data_type: DataType::Text, + collation: Some(ObjectName(vec![Ident::with_quote('\'', "de_DE")])), + options: vec![] + },] + ); + } + _ => unreachable!(), + } } #[test] -fn test_snowflake_create_table_with_masking_policy() { - // let sql = concat!( - // "CREATE TABLE my_table (", - // "a INT AUTOINCREMENT ORDER WITH MASKING POLICY masking_policy_name", - // ")", - // ); - // snowflake().verified_stmt(sql); - let sql = concat!( - "CREATE TABLE my_table (", - "a INT AUTOINCREMENT ORDER MASKING POLICY masking_policy_name USING (a, b)", - ")", - ); - snowflake().verified_stmt(sql); +fn test_snowflake_create_table_with_columns_masking_policy() { + match snowflake_and_generic() + .verified_stmt("CREATE TABLE my_table (a INT WITH MASKING POLICY p)") + { + Statement::CreateTable(CreateTable { columns, .. }) => { + assert_eq!( + columns, + vec![ColumnDef { + name: "a".into(), + data_type: DataType::Int(None), + collation: None, + options: vec![ColumnOptionDef { + name: None, + option: ColumnOption::Policy(ColumnPolicy::MaskingPolicy( + ColumnPolicyProperty { + policy_name: "p".into(), + using_columns: None, + } + )) + }], + },] + ); + } + _ => unreachable!(), + } + match snowflake_and_generic() + .parse_sql_statements("CREATE TABLE my_table (a INT MASKING POLICY p USING (a, b))") + .unwrap() + .pop() + .unwrap() + { + Statement::CreateTable(CreateTable { columns, .. }) => { + assert_eq!( + columns, + vec![ColumnDef { + name: "a".into(), + data_type: DataType::Int(None), + collation: None, + options: vec![ColumnOptionDef { + name: None, + option: ColumnOption::Policy(ColumnPolicy::MaskingPolicy( + ColumnPolicyProperty { + policy_name: "p".into(), + using_columns: Some(vec!["a".into(), "b".into()]), + } + )) + }], + },] + ); + } + _ => unreachable!(), + } +} + +#[test] +fn test_snowflake_create_table_with_columns_projection_policy() { + match snowflake_and_generic() + .verified_stmt("CREATE TABLE my_table (a INT WITH PROJECTION POLICY p)") + { + Statement::CreateTable(CreateTable { columns, .. }) => { + assert_eq!( + columns, + vec![ColumnDef { + name: "a".into(), + data_type: DataType::Int(None), + collation: None, + options: vec![ColumnOptionDef { + name: None, + option: ColumnOption::Policy(ColumnPolicy::ProjectionPolicy( + ColumnPolicyProperty { + policy_name: "p".into(), + using_columns: None, + } + )) + }], + },] + ); + } + _ => unreachable!(), + } + match snowflake_and_generic() + .parse_sql_statements("CREATE TABLE my_table (a INT PROJECTION POLICY p)") + .unwrap() + .pop() + .unwrap() + { + Statement::CreateTable(CreateTable { columns, .. }) => { + assert_eq!( + columns, + vec![ColumnDef { + name: "a".into(), + data_type: DataType::Int(None), + collation: None, + options: vec![ColumnOptionDef { + name: None, + option: ColumnOption::Policy(ColumnPolicy::ProjectionPolicy( + ColumnPolicyProperty { + policy_name: "p".into(), + using_columns: None, + } + )) + }], + },] + ); + } + _ => unreachable!(), + } +} + +#[test] +fn test_snowflake_create_table_with_columns_tags() { + match snowflake_and_generic() + .verified_stmt("CREATE TABLE my_table (a INT WITH TAG (A='TAG A', B='TAG B'))") + { + Statement::CreateTable(CreateTable { columns, .. }) => { + assert_eq!( + columns, + vec![ColumnDef { + name: "a".into(), + data_type: DataType::Int(None), + collation: None, + options: vec![ColumnOptionDef { + name: None, + option: ColumnOption::Tags(vec![ + Tag::new("A".into(), "TAG A".into()), + Tag::new("B".into(), "TAG B".into()), + ]), + }], + },] + ); + } + _ => unreachable!(), + } + match snowflake_and_generic() + .parse_sql_statements("CREATE TABLE my_table (a INT TAG (A='TAG A', B='TAG B'))") + .unwrap() + .pop() + .unwrap() + { + Statement::CreateTable(CreateTable { columns, .. }) => { + assert_eq!( + columns, + vec![ColumnDef { + name: "a".into(), + data_type: DataType::Int(None), + collation: None, + options: vec![ColumnOptionDef { + name: None, + option: ColumnOption::Tags(vec![ + Tag::new("A".into(), "TAG A".into()), + Tag::new("B".into(), "TAG B".into()), + ]), + }], + },] + ); + } + _ => unreachable!(), + } } #[test]