From 8a534c0e279dde2fad06e55423fba9cbffcca0a2 Mon Sep 17 00:00:00 2001 From: hulk Date: Thu, 26 Sep 2024 01:32:04 +0800 Subject: [PATCH] Implements CREATE POLICY syntax for PostgreSQL (#1440) --- src/ast/mod.rs | 85 +++++++++++++++++++++++++++ src/keywords.rs | 2 + src/parser/mod.rs | 118 +++++++++++++++++++++++++++++++++----- tests/sqlparser_common.rs | 102 ++++++++++++++++++++++++++++++++ 4 files changed, 292 insertions(+), 15 deletions(-) diff --git a/src/ast/mod.rs b/src/ast/mod.rs index 9b7a66650..83646d298 100644 --- a/src/ast/mod.rs +++ b/src/ast/mod.rs @@ -2135,6 +2135,35 @@ pub enum FromTable { WithoutKeyword(Vec), } +/// Policy type for a `CREATE POLICY` statement. +/// ```sql +/// AS [ PERMISSIVE | RESTRICTIVE ] +/// ``` +/// [PostgreSQL](https://www.postgresql.org/docs/current/sql-createpolicy.html) +#[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +#[cfg_attr(feature = "visitor", derive(Visit, VisitMut))] +pub enum CreatePolicyType { + Permissive, + Restrictive, +} + +/// Policy command for a `CREATE POLICY` statement. +/// ```sql +/// FOR [ALL | SELECT | INSERT | UPDATE | DELETE] +/// ``` +/// [PostgreSQL](https://www.postgresql.org/docs/current/sql-createpolicy.html) +#[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +#[cfg_attr(feature = "visitor", derive(Visit, VisitMut))] +pub enum CreatePolicyCommand { + All, + Select, + Insert, + Update, + Delete, +} + /// A top-level statement (SELECT, INSERT, CREATE, etc.) #[allow(clippy::large_enum_variant)] #[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)] @@ -2375,6 +2404,20 @@ pub enum Statement { options: Vec, }, /// ```sql + /// CREATE POLICY + /// ``` + /// See [PostgreSQL](https://www.postgresql.org/docs/current/sql-createpolicy.html) + CreatePolicy { + name: Ident, + #[cfg_attr(feature = "visitor", visit(with = "visit_relation"))] + table_name: ObjectName, + policy_type: Option, + command: Option, + to: Option>, + using: Option, + with_check: Option, + }, + /// ```sql /// ALTER TABLE /// ``` AlterTable { @@ -4052,6 +4095,48 @@ impl fmt::Display for Statement { write!(f, " )")?; Ok(()) } + Statement::CreatePolicy { + name, + table_name, + policy_type, + command, + to, + using, + with_check, + } => { + write!(f, "CREATE POLICY {name} ON {table_name}")?; + + if let Some(policy_type) = policy_type { + match policy_type { + CreatePolicyType::Permissive => write!(f, " AS PERMISSIVE")?, + CreatePolicyType::Restrictive => write!(f, " AS RESTRICTIVE")?, + } + } + + if let Some(command) = command { + match command { + CreatePolicyCommand::All => write!(f, " FOR ALL")?, + CreatePolicyCommand::Select => write!(f, " FOR SELECT")?, + CreatePolicyCommand::Insert => write!(f, " FOR INSERT")?, + CreatePolicyCommand::Update => write!(f, " FOR UPDATE")?, + CreatePolicyCommand::Delete => write!(f, " FOR DELETE")?, + } + } + + if let Some(to) = to { + write!(f, " TO {}", display_comma_separated(to))?; + } + + if let Some(using) = using { + write!(f, " USING ({using})")?; + } + + if let Some(with_check) = with_check { + write!(f, " WITH CHECK ({with_check})")?; + } + + Ok(()) + } Statement::AlterTable { name, if_exists, diff --git a/src/keywords.rs b/src/keywords.rs index d384062f2..49c6ce20f 100644 --- a/src/keywords.rs +++ b/src/keywords.rs @@ -568,6 +568,7 @@ define_keywords!( PERCENTILE_DISC, PERCENT_RANK, PERIOD, + PERMISSIVE, PERSISTENT, PIVOT, PLACING, @@ -634,6 +635,7 @@ define_keywords!( RESTART, RESTRICT, RESTRICTED, + RESTRICTIVE, RESULT, RESULTSET, RETAIN, diff --git a/src/parser/mod.rs b/src/parser/mod.rs index 5d57347cf..4c3f8788d 100644 --- a/src/parser/mod.rs +++ b/src/parser/mod.rs @@ -32,6 +32,7 @@ use IsLateral::*; use IsOptional::*; use crate::ast::helpers::stmt_create_table::{CreateTableBuilder, CreateTableConfiguration}; +use crate::ast::Statement::CreatePolicy; use crate::ast::*; use crate::dialect::*; use crate::keywords::{Keyword, ALL_KEYWORDS}; @@ -3569,6 +3570,8 @@ impl<'a> Parser<'a> { } else if self.parse_keyword(Keyword::MATERIALIZED) || self.parse_keyword(Keyword::VIEW) { self.prev_token(); self.parse_create_view(or_replace, temporary) + } else if self.parse_keyword(Keyword::POLICY) { + self.parse_create_policy() } else if self.parse_keyword(Keyword::EXTERNAL) { self.parse_create_external_table(or_replace) } else if self.parse_keyword(Keyword::FUNCTION) { @@ -4762,6 +4765,105 @@ impl<'a> Parser<'a> { }) } + pub fn parse_owner(&mut self) -> Result { + let owner = match self.parse_one_of_keywords(&[Keyword::CURRENT_USER, Keyword::CURRENT_ROLE, Keyword::SESSION_USER]) { + Some(Keyword::CURRENT_USER) => Owner::CurrentUser, + Some(Keyword::CURRENT_ROLE) => Owner::CurrentRole, + Some(Keyword::SESSION_USER) => Owner::SessionUser, + Some(_) => unreachable!(), + None => { + match self.parse_identifier(false) { + Ok(ident) => Owner::Ident(ident), + Err(e) => { + return Err(ParserError::ParserError(format!("Expected: CURRENT_USER, CURRENT_ROLE, SESSION_USER or identifier after OWNER TO. {e}"))) + } + } + }, + }; + Ok(owner) + } + + /// ```sql + /// CREATE POLICY name ON table_name [ AS { PERMISSIVE | RESTRICTIVE } ] + /// [ FOR { ALL | SELECT | INSERT | UPDATE | DELETE } ] + /// [ TO { role_name | PUBLIC | CURRENT_USER | CURRENT_ROLE | SESSION_USER } [, ...] ] + /// [ USING ( using_expression ) ] + /// [ WITH CHECK ( with_check_expression ) ] + /// ``` + /// + /// [PostgreSQL Documentation](https://www.postgresql.org/docs/current/sql-createpolicy.html) + pub fn parse_create_policy(&mut self) -> Result { + let name = self.parse_identifier(false)?; + self.expect_keyword(Keyword::ON)?; + let table_name = self.parse_object_name(false)?; + + let policy_type = if self.parse_keyword(Keyword::AS) { + let keyword = + self.expect_one_of_keywords(&[Keyword::PERMISSIVE, Keyword::RESTRICTIVE])?; + Some(match keyword { + Keyword::PERMISSIVE => CreatePolicyType::Permissive, + Keyword::RESTRICTIVE => CreatePolicyType::Restrictive, + _ => unreachable!(), + }) + } else { + None + }; + + let command = if self.parse_keyword(Keyword::FOR) { + let keyword = self.expect_one_of_keywords(&[ + Keyword::ALL, + Keyword::SELECT, + Keyword::INSERT, + Keyword::UPDATE, + Keyword::DELETE, + ])?; + Some(match keyword { + Keyword::ALL => CreatePolicyCommand::All, + Keyword::SELECT => CreatePolicyCommand::Select, + Keyword::INSERT => CreatePolicyCommand::Insert, + Keyword::UPDATE => CreatePolicyCommand::Update, + Keyword::DELETE => CreatePolicyCommand::Delete, + _ => unreachable!(), + }) + } else { + None + }; + + let to = if self.parse_keyword(Keyword::TO) { + Some(self.parse_comma_separated(|p| p.parse_owner())?) + } else { + None + }; + + let using = if self.parse_keyword(Keyword::USING) { + self.expect_token(&Token::LParen)?; + let expr = self.parse_expr()?; + self.expect_token(&Token::RParen)?; + Some(expr) + } else { + None + }; + + let with_check = if self.parse_keywords(&[Keyword::WITH, Keyword::CHECK]) { + self.expect_token(&Token::LParen)?; + let expr = self.parse_expr()?; + self.expect_token(&Token::RParen)?; + Some(expr) + } else { + None + }; + + Ok(CreatePolicy { + name, + table_name, + policy_type, + command, + to, + using, + with_check, + }) + } + pub fn parse_drop(&mut self) -> Result { // MySQL dialect supports `TEMPORARY` let temporary = dialect_of!(self is MySqlDialect | GenericDialect | DuckDbDialect) @@ -6941,21 +7043,7 @@ impl<'a> Parser<'a> { } else if dialect_of!(self is PostgreSqlDialect | GenericDialect) && self.parse_keywords(&[Keyword::OWNER, Keyword::TO]) { - let new_owner = match self.parse_one_of_keywords(&[Keyword::CURRENT_USER, Keyword::CURRENT_ROLE, Keyword::SESSION_USER]) { - Some(Keyword::CURRENT_USER) => Owner::CurrentUser, - Some(Keyword::CURRENT_ROLE) => Owner::CurrentRole, - Some(Keyword::SESSION_USER) => Owner::SessionUser, - Some(_) => unreachable!(), - None => { - match self.parse_identifier(false) { - Ok(ident) => Owner::Ident(ident), - Err(e) => { - return Err(ParserError::ParserError(format!("Expected: CURRENT_USER, CURRENT_ROLE, SESSION_USER or identifier after OWNER TO. {e}"))) - } - } - }, - }; - + let new_owner = self.parse_owner()?; AlterTableOperation::OwnerTo { new_owner } } else if dialect_of!(self is ClickHouseDialect|GenericDialect) && self.parse_keyword(Keyword::ATTACH) diff --git a/tests/sqlparser_common.rs b/tests/sqlparser_common.rs index 9aa76882a..711070034 100644 --- a/tests/sqlparser_common.rs +++ b/tests/sqlparser_common.rs @@ -10987,3 +10987,105 @@ fn parse_explain_with_option_list() { Some(utility_options), ); } + +#[test] +fn test_create_policy() { + let sql = concat!( + "CREATE POLICY my_policy ON my_table ", + "AS PERMISSIVE FOR SELECT ", + "TO my_role, CURRENT_USER ", + "USING (c0 = 1) ", + "WITH CHECK (true)" + ); + + match all_dialects().verified_stmt(sql) { + Statement::CreatePolicy { + name, + table_name, + to, + using, + with_check, + .. + } => { + assert_eq!(name.to_string(), "my_policy"); + assert_eq!(table_name.to_string(), "my_table"); + assert_eq!( + to, + Some(vec![ + Owner::Ident(Ident::new("my_role")), + Owner::CurrentUser + ]) + ); + assert_eq!( + using, + Some(Expr::BinaryOp { + left: Box::new(Expr::Identifier(Ident::new("c0"))), + op: BinaryOperator::Eq, + right: Box::new(Expr::Value(Value::Number("1".parse().unwrap(), false))), + }) + ); + assert_eq!(with_check, Some(Expr::Value(Value::Boolean(true)))); + } + _ => unreachable!(), + } + + // USING with SELECT query + all_dialects().verified_stmt(concat!( + "CREATE POLICY my_policy ON my_table ", + "AS PERMISSIVE FOR SELECT ", + "TO my_role, CURRENT_USER ", + "USING (c0 IN (SELECT column FROM t0)) ", + "WITH CHECK (true)" + )); + // omit AS / FOR / TO / USING / WITH CHECK clauses is allowed + all_dialects().verified_stmt("CREATE POLICY my_policy ON my_table"); + + // missing table name + assert_eq!( + all_dialects() + .parse_sql_statements("CREATE POLICY my_policy") + .unwrap_err() + .to_string(), + "sql parser error: Expected: ON, found: EOF" + ); + // missing policy type + assert_eq!( + all_dialects() + .parse_sql_statements("CREATE POLICY my_policy ON my_table AS") + .unwrap_err() + .to_string(), + "sql parser error: Expected: one of PERMISSIVE or RESTRICTIVE, found: EOF" + ); + // missing FOR command + assert_eq!( + all_dialects() + .parse_sql_statements("CREATE POLICY my_policy ON my_table FOR") + .unwrap_err() + .to_string(), + "sql parser error: Expected: one of ALL or SELECT or INSERT or UPDATE or DELETE, found: EOF" + ); + // missing TO owners + assert_eq!( + all_dialects() + .parse_sql_statements("CREATE POLICY my_policy ON my_table TO") + .unwrap_err() + .to_string(), + "sql parser error: Expected: CURRENT_USER, CURRENT_ROLE, SESSION_USER or identifier after OWNER TO. sql parser error: Expected: identifier, found: EOF" + ); + // missing USING expression + assert_eq!( + all_dialects() + .parse_sql_statements("CREATE POLICY my_policy ON my_table USING") + .unwrap_err() + .to_string(), + "sql parser error: Expected: (, found: EOF" + ); + // missing WITH CHECK expression + assert_eq!( + all_dialects() + .parse_sql_statements("CREATE POLICY my_policy ON my_table WITH CHECK") + .unwrap_err() + .to_string(), + "sql parser error: Expected: (, found: EOF" + ); +}