diff --git a/src/parser/mod.rs b/src/parser/mod.rs index cd9be1d8f..b4c0487b4 100644 --- a/src/parser/mod.rs +++ b/src/parser/mod.rs @@ -9416,27 +9416,35 @@ impl<'a> Parser<'a> { } } + /// Parse a `SET ROLE` statement. Expects SET to be consumed already. + fn parse_set_role(&mut self, modifier: Option) -> Result { + self.expect_keyword(Keyword::ROLE)?; + let context_modifier = match modifier { + Some(Keyword::LOCAL) => ContextModifier::Local, + Some(Keyword::SESSION) => ContextModifier::Session, + _ => ContextModifier::None, + }; + + let role_name = if self.parse_keyword(Keyword::NONE) { + None + } else { + Some(self.parse_identifier(false)?) + }; + Ok(Statement::SetRole { + context_modifier, + role_name, + }) + } + pub fn parse_set(&mut self) -> Result { let modifier = self.parse_one_of_keywords(&[Keyword::SESSION, Keyword::LOCAL, Keyword::HIVEVAR]); if let Some(Keyword::HIVEVAR) = modifier { self.expect_token(&Token::Colon)?; - } else if self.parse_keyword(Keyword::ROLE) { - let context_modifier = match modifier { - Some(Keyword::LOCAL) => ContextModifier::Local, - Some(Keyword::SESSION) => ContextModifier::Session, - _ => ContextModifier::None, - }; - - let role_name = if self.parse_keyword(Keyword::NONE) { - None - } else { - Some(self.parse_identifier(false)?) - }; - return Ok(Statement::SetRole { - context_modifier, - role_name, - }); + } else if let Some(set_role_stmt) = + self.maybe_parse(|parser| parser.parse_set_role(modifier)) + { + return Ok(set_role_stmt); } let variables = if self.parse_keywords(&[Keyword::TIME, Keyword::ZONE]) { diff --git a/tests/sqlparser_common.rs b/tests/sqlparser_common.rs index 5327880a4..447fea318 100644 --- a/tests/sqlparser_common.rs +++ b/tests/sqlparser_common.rs @@ -7665,6 +7665,30 @@ fn parse_set_variable() { one_statement_parses_to("SET SOMETHING TO '1'", "SET SOMETHING = '1'"); } +#[test] +fn parse_set_role_as_variable() { + match verified_stmt("SET role = 'foobar'") { + Statement::SetVariable { + local, + hivevar, + variables, + value, + } => { + assert!(!local); + assert!(!hivevar); + assert_eq!( + variables, + OneOrManyWithParens::One(ObjectName(vec!["role".into()])) + ); + assert_eq!( + value, + vec![Expr::Value(Value::SingleQuotedString("foobar".into()))] + ); + } + _ => unreachable!(), + } +} + #[test] fn parse_double_colon_cast_at_timezone() { let sql = "SELECT '2001-01-01T00:00:00.000Z'::TIMESTAMP AT TIME ZONE 'Europe/Brussels' FROM t";