diff --git a/src/ast/ddl.rs b/src/ast/ddl.rs index c6497d7c9..051181493 100644 --- a/src/ast/ddl.rs +++ b/src/ast/ddl.rs @@ -31,7 +31,7 @@ use sqlparser_derive::{Visit, VisitMut}; use crate::ast::value::escape_single_quote_string; use crate::ast::{ display_comma_separated, display_separated, DataType, Expr, Ident, MySQLColumnPosition, - ObjectName, OrderByExpr, ProjectionSelect, SequenceOptions, SqlOption, Value, + ObjectName, OrderByExpr, ProjectionSelect, SequenceOptions, SqlOption, Tag, Value, }; use crate::tokenizer::Token; @@ -1058,18 +1058,17 @@ impl fmt::Display for ColumnOptionDef { #[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)] #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] #[cfg_attr(feature = "visitor", derive(Visit, VisitMut))] -pub struct IdentityProperty { - pub format: IdentityPropertyCommand, - pub parameters: Option, - pub order: Option, +pub enum Identity { + Autoincrement(IdentityProperty), + Identity(IdentityProperty), } #[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)] #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] #[cfg_attr(feature = "visitor", derive(Visit, VisitMut))] -pub enum IdentityPropertyCommand { - Autoincrement, - Identity, +pub struct IdentityProperty { + pub parameters: Option, + pub order: Option, } #[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)] @@ -1097,16 +1096,17 @@ pub enum IdentityOrder { Noorder, } -impl fmt::Display for IdentityProperty { +impl fmt::Display for Identity { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - match self.format { - IdentityPropertyCommand::Identity => write!(f, "IDENTITY")?, - IdentityPropertyCommand::Autoincrement => write!(f, "AUTOINCREMENT")?, - } - if let Some(parameters) = &self.parameters { + let (command, property) = match self { + Identity::Identity(property) => ("IDENTITY", property), + Identity::Autoincrement(property) => ("AUTOINCREMENT", property), + }; + write!(f, "{command}")?; + if let Some(parameters) = &property.parameters { write!(f, "{parameters}")?; } - if let Some(order) = &self.order { + if let Some(order) = &property.order { write!(f, "{order}")?; } Ok(()) @@ -1135,6 +1135,36 @@ impl fmt::Display for IdentityOrder { } } +#[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +#[cfg_attr(feature = "visitor", derive(Visit, VisitMut))] +pub enum ColumnPolicy { + MaskingPolicy(ColumnPolicyProperty), + ProjectionPolicy(ColumnPolicyProperty), +} + +#[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +#[cfg_attr(feature = "visitor", derive(Visit, VisitMut))] +pub struct ColumnPolicyProperty { + pub policy_name: Ident, + pub using_columns: Option>, +} + +impl fmt::Display for ColumnPolicy { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + let (command, property) = match self { + ColumnPolicy::MaskingPolicy(property) => ("MASKING POLICY", property), + ColumnPolicy::ProjectionPolicy(property) => ("PROJECTION POLICY", property), + }; + write!(f, "{command} {}", property.policy_name)?; + if let Some(using_columns) = &property.using_columns { + write!(f, "USING ({})", display_comma_separated(using_columns))?; + } + Ok(()) + } +} + /// `ColumnOption`s are modifiers that follow a column definition in a `CREATE /// TABLE` statement. #[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)] @@ -1212,7 +1242,9 @@ pub enum ColumnOption { /// ``` /// [MS SQL Server]: https://learn.microsoft.com/en-us/sql/t-sql/statements/create-table-transact-sql-identity-property /// [Snowflake]: https://docs.snowflake.com/en/sql-reference/sql/create-table - Identity(IdentityProperty), + Identity(Identity), + Policy(ColumnPolicy), + Tags(Vec), } impl fmt::Display for ColumnOption { @@ -1315,8 +1347,13 @@ impl fmt::Display for ColumnOption { write!(f, "OPTIONS({})", display_comma_separated(options)) } Identity(parameters) => { - write!(f, "{parameters}")?; - Ok(()) + write!(f, "{parameters}") + } + Policy(parameters) => { + write!(f, "{parameters}") + } + Tags(tags) => { + write!(f, "{}", display_comma_separated(tags)) } } } diff --git a/src/ast/mod.rs b/src/ast/mod.rs index 6d726df87..2f5a316da 100644 --- a/src/ast/mod.rs +++ b/src/ast/mod.rs @@ -40,9 +40,9 @@ pub use self::data_type::{ pub use self::dcl::{AlterRoleOperation, ResetConfig, RoleOption, SetConfigValue, Use}; pub use self::ddl::{ AlterColumnOperation, AlterIndexOperation, AlterTableOperation, ClusteredBy, ColumnDef, - ColumnOption, ColumnOptionDef, ConstraintCharacteristics, Deduplicate, DeferrableInitial, - GeneratedAs, GeneratedExpressionMode, IdentityOrder, IdentityParameters, - IdentityParametersFormat, IdentityProperty, IdentityPropertyCommand, IndexOption, IndexType, + ColumnOption, ColumnOptionDef, ColumnPolicy, ColumnPolicyProperty, ConstraintCharacteristics, + Deduplicate, DeferrableInitial, GeneratedAs, GeneratedExpressionMode, Identity, IdentityOrder, + IdentityParameters, IdentityParametersFormat, IdentityProperty, IndexOption, IndexType, KeyOrIndexDisplay, Owner, Partition, ProcedureParam, ReferentialAction, TableConstraint, UserDefinedTypeCompositeAttributeDef, UserDefinedTypeRepresentation, ViewColumnDef, }; diff --git a/src/dialect/snowflake.rs b/src/dialect/snowflake.rs index e14746d9a..6fcbfe09b 100644 --- a/src/dialect/snowflake.rs +++ b/src/dialect/snowflake.rs @@ -22,9 +22,7 @@ use crate::ast::helpers::stmt_data_loading::{ DataLoadingOption, DataLoadingOptionType, DataLoadingOptions, StageLoadSelectItem, StageParamsObject, }; -use crate::ast::{ - CommentDef, Ident, ObjectName, RowAccessPolicy, Statement, Tag, WrappedCollection, -}; +use crate::ast::{CommentDef, Ident, ObjectName, RowAccessPolicy, Statement, WrappedCollection}; use crate::dialect::{Dialect, Precedence}; use crate::keywords::Keyword; use crate::parser::{Parser, ParserError}; @@ -313,16 +311,8 @@ pub fn parse_create_table( builder.with_row_access_policy(Some(RowAccessPolicy::new(policy, columns))) } Keyword::TAG => { - fn parse_tag(parser: &mut Parser) -> Result { - let name = parser.parse_identifier(false)?; - parser.expect_token(&Token::Eq)?; - let value = parser.parse_literal_string()?; - - Ok(Tag::new(name, value)) - } - parser.expect_token(&Token::LParen)?; - let tags = parser.parse_comma_separated(parse_tag)?; + let tags = parser.parse_comma_separated(Parser::parse_tag)?; parser.expect_token(&Token::RParen)?; builder = builder.with_tags(Some(tags)); } diff --git a/src/keywords.rs b/src/keywords.rs index 6db3ed25c..53e9f235e 100644 --- a/src/keywords.rs +++ b/src/keywords.rs @@ -453,6 +453,7 @@ define_keywords!( MACRO, MANAGEDLOCATION, MAP, + MASKING, MATCH, MATCHED, MATCHES, diff --git a/src/parser/mod.rs b/src/parser/mod.rs index 068f52c71..3b262373b 100644 --- a/src/parser/mod.rs +++ b/src/parser/mod.rs @@ -6168,6 +6168,7 @@ impl<'a> Parser<'a> { && dialect_of!(self is SnowflakeDialect | SQLiteDialect | GenericDialect) { if dialect_of!(self is SnowflakeDialect) { + self.prev_token(); return self.parse_snowflake_autoincrement_or_identity_option_column(); } @@ -6179,6 +6180,7 @@ impl<'a> Parser<'a> { && dialect_of!(self is MsSqlDialect | SnowflakeDialect | GenericDialect) { if dialect_of!(self is SnowflakeDialect) { + self.prev_token(); return self.parse_snowflake_autoincrement_or_identity_option_column(); } @@ -6196,11 +6198,34 @@ impl<'a> Parser<'a> { } else { None }; - Ok(Some(ColumnOption::Identity(IdentityProperty { - format: IdentityPropertyCommand::Identity, - parameters, - order: None, - }))) + Ok(Some(ColumnOption::Identity(Identity::Identity( + IdentityProperty { + parameters, + order: None, + }, + )))) + } else if ((self.parse_keyword(Keyword::WITH) + && self + .parse_one_of_keywords(&[Keyword::MASKING, Keyword::PROJECTION]) + .is_some()) + || self + .parse_one_of_keywords(&[Keyword::MASKING, Keyword::PROJECTION]) + .is_some()) + && dialect_of!(self is SnowflakeDialect | GenericDialect) + { + self.prev_token(); + let Some(policy) = self.parse_snowflake_column_policy()? else { + return Ok(None); + }; + Ok(Some(ColumnOption::Policy(policy))) + } else if self.parse_keywords(&[Keyword::TAG]) + && dialect_of!(self is SnowflakeDialect | GenericDialect) + { + self.expect_token(&Token::LParen)?; + let tags = self.parse_comma_separated(Self::parse_tag)?; + self.expect_token(&Token::RParen)?; + + Ok(Some(ColumnOption::Tags(tags))) } else { Ok(None) } @@ -6209,12 +6234,11 @@ impl<'a> Parser<'a> { fn parse_snowflake_autoincrement_or_identity_option_column( &mut self, ) -> Result, ParserError> { - self.prev_token(); - let format = match self.parse_one_of_keywords(&[Keyword::IDENTITY, Keyword::AUTOINCREMENT]) - { - Some(Keyword::IDENTITY) => IdentityPropertyCommand::Identity, - Some(Keyword::AUTOINCREMENT) => IdentityPropertyCommand::Autoincrement, - _ => self.expected("one of IDENTITY or AUTOINCREMENT", self.peek_token())?, + let token_location = self.peek_token(); + let Some(keyword) = + self.parse_one_of_keywords(&[Keyword::IDENTITY, Keyword::AUTOINCREMENT]) + else { + return self.expected("IDENTITY or AUTOINCREMENT", token_location.clone()); }; let parameters = if self.consume_token(&Token::LParen) { let seed = self.parse_number()?; @@ -6245,11 +6269,56 @@ impl<'a> Parser<'a> { Some(Keyword::NOORDER) => Some(IdentityOrder::Noorder), _ => None, }; - Ok(Some(ColumnOption::Identity(IdentityProperty { - format, - parameters, - order, - }))) + let property = IdentityProperty { parameters, order }; + let identity = match keyword { + Keyword::IDENTITY => Identity::Identity(property), + Keyword::AUTOINCREMENT => Identity::Autoincrement(property), + _ => self.expected("IDENTITY or AUTOINCREMENT", token_location)?, + }; + Ok(Some(ColumnOption::Identity(identity))) + } + + fn parse_snowflake_column_policy(&mut self) -> Result, ParserError> { + if self.parse_keywords(&[Keyword::MASKING, Keyword::POLICY]) + && dialect_of!(self is SnowflakeDialect | GenericDialect) + { + let property = self.parse_snowflake_column_policy_property()?; + Ok(Some(ColumnPolicy::MaskingPolicy(property))) + } else if self.parse_keywords(&[Keyword::PROJECTION, Keyword::POLICY]) + && dialect_of!(self is SnowflakeDialect | GenericDialect) + { + let property = self.parse_snowflake_column_policy_property()?; + Ok(Some(ColumnPolicy::ProjectionPolicy(property))) + } else { + Ok(None) + } + } + + fn parse_snowflake_column_policy_property( + &mut self, + ) -> Result { + let policy_name = self.parse_identifier(false)?; + let using_columns = if self.parse_keyword(Keyword::USING) { + self.expect_token(&Token::LParen)?; + let columns = self.parse_comma_separated(Self::parse_identifier)?; + self.expect_token(&Token::RParen)?; + Some(columns) + } else { + None + }; + + Ok(ColumnPolicyProperty { + policy_name, + using_columns, + }) + } + + pub fn parse_tag(&mut self) -> Result { + let name = self.parse_identifier(false)?; + self.expect_token(&Token::Eq)?; + let value = self.parse_literal_string()?; + + Ok(Tag::new(name, value)) } fn parse_optional_column_option_generated( diff --git a/tests/sqlparser_mssql.rs b/tests/sqlparser_mssql.rs index 13e76650f..c0671b536 100644 --- a/tests/sqlparser_mssql.rs +++ b/tests/sqlparser_mssql.rs @@ -921,11 +921,10 @@ fn parse_create_table_with_identity_column() { vec![ ColumnOptionDef { name: None, - option: ColumnOption::Identity(IdentityProperty { - format: IdentityPropertyCommand::Identity, + option: ColumnOption::Identity(Identity::Identity(IdentityProperty { parameters: None, order: None, - }), + })), }, ColumnOptionDef { name: None, @@ -949,8 +948,7 @@ fn parse_create_table_with_identity_column() { order: None, }), #[cfg(feature = "bigdecimal")] - option: ColumnOption::Identity(IdentityProperty { - format: IdentityPropertyCommand::Identity, + option: ColumnOption::Identity(Identity::Identity(IdentityProperty { parameters: Some(IdentityParameters { format: IdentityParametersFormat::FunctionCall, seed: Expr::Value(Value::Number( @@ -963,7 +961,7 @@ fn parse_create_table_with_identity_column() { )), }), order: None, - }), + })), }, ColumnOptionDef { name: None, diff --git a/tests/sqlparser_snowflake.rs b/tests/sqlparser_snowflake.rs index ba2dbb000..d3632e7a6 100644 --- a/tests/sqlparser_snowflake.rs +++ b/tests/sqlparser_snowflake.rs @@ -539,6 +539,27 @@ fn test_snowflake_create_table_with_autoincrement_columns() { snowflake().verified_stmt(sql); } +#[test] +fn test_snowflake_create_table_with_collated_columns() { + snowflake().verified_stmt("CREATE TABLE my_table (a TEXT COLLATE 'de_DE')"); +} + +#[test] +fn test_snowflake_create_table_with_masking_policy() { + // let sql = concat!( + // "CREATE TABLE my_table (", + // "a INT AUTOINCREMENT ORDER WITH MASKING POLICY masking_policy_name", + // ")", + // ); + // snowflake().verified_stmt(sql); + let sql = concat!( + "CREATE TABLE my_table (", + "a INT AUTOINCREMENT ORDER MASKING POLICY masking_policy_name USING (a, b)", + ")", + ); + snowflake().verified_stmt(sql); +} + #[test] fn parse_sf_create_or_replace_view_with_comment_missing_equal() { assert!(snowflake_and_generic()