diff --git a/src/ast/dcl.rs b/src/ast/dcl.rs index f90de34d4..1b0a77095 100644 --- a/src/ast/dcl.rs +++ b/src/ast/dcl.rs @@ -193,3 +193,30 @@ impl fmt::Display for AlterRoleOperation { } } } + +/// A `USE` (`Statement::Use`) operation +#[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +#[cfg_attr(feature = "visitor", derive(Visit, VisitMut))] +pub enum Use { + Catalog(ObjectName), // e.g. `USE CATALOG foo.bar` + Schema(ObjectName), // e.g. `USE SCHEMA foo.bar` + Database(ObjectName), // e.g. `USE DATABASE foo.bar` + Warehouse(ObjectName), // e.g. `USE WAREHOUSE foo.bar` + Object(ObjectName), // e.g. `USE foo.bar` + Default, // e.g. `USE DEFAULT` +} + +impl fmt::Display for Use { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + f.write_str("USE ")?; + match self { + Use::Catalog(name) => write!(f, "CATALOG {}", name), + Use::Schema(name) => write!(f, "SCHEMA {}", name), + Use::Database(name) => write!(f, "DATABASE {}", name), + Use::Warehouse(name) => write!(f, "WAREHOUSE {}", name), + Use::Object(name) => write!(f, "{}", name), + Use::Default => write!(f, "DEFAULT"), + } + } +} diff --git a/src/ast/mod.rs b/src/ast/mod.rs index 4f9aac885..8a56f3158 100644 --- a/src/ast/mod.rs +++ b/src/ast/mod.rs @@ -31,7 +31,7 @@ pub use self::data_type::{ ArrayElemTypeDef, CharLengthUnits, CharacterLength, DataType, ExactNumberInfo, StructBracketKind, TimezoneInfo, }; -pub use self::dcl::{AlterRoleOperation, ResetConfig, RoleOption, SetConfigValue}; +pub use self::dcl::{AlterRoleOperation, ResetConfig, RoleOption, SetConfigValue, Use}; pub use self::ddl::{ AlterColumnOperation, AlterIndexOperation, AlterTableOperation, ColumnDef, ColumnOption, ColumnOptionDef, ConstraintCharacteristics, Deduplicate, DeferrableInitial, GeneratedAs, @@ -2515,11 +2515,9 @@ pub enum Statement { /// Note: this is a MySQL-specific statement. 
ShowCollation { filter: Option<ShowStatementFilter> }, /// ```sql - /// USE + /// `USE ...` /// ``` - /// - /// Note: This is a MySQL-specific statement. - Use { db_name: Ident }, + Use(Use), /// ```sql /// START [ TRANSACTION | WORK ] | START TRANSACTION } ... /// ``` @@ -4125,10 +4123,7 @@ impl fmt::Display for Statement { } Ok(()) } - Statement::Use { db_name } => { - write!(f, "USE {db_name}")?; - Ok(()) - } + Statement::Use(use_expr) => use_expr.fmt(f), Statement::ShowCollation { filter } => { write!(f, "SHOW COLLATION")?; if let Some(filter) = filter { diff --git a/src/keywords.rs b/src/keywords.rs index acb913d57..d2dcc57d1 100644 --- a/src/keywords.rs +++ b/src/keywords.rs @@ -137,6 +137,7 @@ define_keywords!( CASCADED, CASE, CAST, + CATALOG, CEIL, CEILING, CENTURY, @@ -804,6 +805,7 @@ define_keywords!( VIEW, VIRTUAL, VOLATILE, + WAREHOUSE, WEEK, WHEN, WHENEVER, diff --git a/src/parser/mod.rs b/src/parser/mod.rs index 1eff8f7d5..8f8c3f050 100644 --- a/src/parser/mod.rs +++ b/src/parser/mod.rs @@ -9264,8 +9264,31 @@ impl<'a> Parser<'a> { } pub fn parse_use(&mut self) -> Result<Statement, ParserError> { - let db_name = self.parse_identifier(false)?; - Ok(Statement::Use { db_name }) + // Determine which keywords are recognized by the current dialect + let parsed_keyword = if dialect_of!(self is HiveDialect) { + // HiveDialect accepts USE DEFAULT; statement without any db specified + if self.parse_keyword(Keyword::DEFAULT) { + return Ok(Statement::Use(Use::Default)); + } + None // HiveDialect doesn't expect any other specific keyword after `USE` + } else if dialect_of!(self is DatabricksDialect) { + self.parse_one_of_keywords(&[Keyword::CATALOG, Keyword::DATABASE, Keyword::SCHEMA]) + } else if dialect_of!(self is SnowflakeDialect) { + self.parse_one_of_keywords(&[Keyword::DATABASE, Keyword::SCHEMA, Keyword::WAREHOUSE]) + } else { + None // No specific keywords for other dialects, including GenericDialect + }; + + let obj_name = self.parse_object_name(false)?; + let result = match parsed_keyword { 
Some(Keyword::CATALOG) => Use::Catalog(obj_name), + Some(Keyword::DATABASE) => Use::Database(obj_name), + Some(Keyword::SCHEMA) => Use::Schema(obj_name), + Some(Keyword::WAREHOUSE) => Use::Warehouse(obj_name), + _ => Use::Object(obj_name), + }; + + Ok(Statement::Use(result)) } pub fn parse_table_and_joins(&mut self) -> Result<TableWithJoins, ParserError> { diff --git a/tests/sqlparser_clickhouse.rs b/tests/sqlparser_clickhouse.rs index fe255cda5..c8157bced 100644 --- a/tests/sqlparser_clickhouse.rs +++ b/tests/sqlparser_clickhouse.rs @@ -1232,6 +1232,39 @@ fn test_prewhere() { } } +#[test] +fn parse_use() { + let valid_object_names = [ + "mydb", + "SCHEMA", + "DATABASE", + "CATALOG", + "WAREHOUSE", + "DEFAULT", + ]; + let quote_styles = ['"', '`']; + + for object_name in &valid_object_names { + // Test single identifier without quotes + assert_eq!( + clickhouse().verified_stmt(&format!("USE {}", object_name)), + Statement::Use(Use::Object(ObjectName(vec![Ident::new( + object_name.to_string() + )]))) + ); + for &quote in &quote_styles { + // Test single identifier with different type of quotes + assert_eq!( + clickhouse().verified_stmt(&format!("USE {0}{1}{0}", quote, object_name)), + Statement::Use(Use::Object(ObjectName(vec![Ident::with_quote( + quote, + object_name.to_string(), + )]))) + ); + } + } +} + #[test] fn test_query_with_format_clause() { let format_options = vec!["TabSeparated", "JSONCompact", "NULL"]; diff --git a/tests/sqlparser_databricks.rs b/tests/sqlparser_databricks.rs index 280b97b49..ee0cf2d7d 100644 --- a/tests/sqlparser_databricks.rs +++ b/tests/sqlparser_databricks.rs @@ -189,3 +189,77 @@ fn test_values_clause() { // TODO: support this example from https://docs.databricks.com/en/sql/language-manual/sql-ref-syntax-qry-select-values.html#examples // databricks().verified_query("VALUES 1, 2, 3"); } + +#[test] +fn parse_use() { + let valid_object_names = ["mydb", "WAREHOUSE", "DEFAULT"]; + let quote_styles = ['"', '`']; + + for object_name in &valid_object_names { + // Test single 
identifier without quotes + assert_eq!( + databricks().verified_stmt(&format!("USE {}", object_name)), + Statement::Use(Use::Object(ObjectName(vec![Ident::new( + object_name.to_string() + )]))) + ); + for &quote in &quote_styles { + // Test single identifier with different type of quotes + assert_eq!( + databricks().verified_stmt(&format!("USE {0}{1}{0}", quote, object_name)), + Statement::Use(Use::Object(ObjectName(vec![Ident::with_quote( + quote, + object_name.to_string(), + )]))) + ); + } + } + + for &quote in &quote_styles { + // Test single identifier with keyword and different type of quotes + assert_eq!( + databricks().verified_stmt(&format!("USE CATALOG {0}my_catalog{0}", quote)), + Statement::Use(Use::Catalog(ObjectName(vec![Ident::with_quote( + quote, + "my_catalog".to_string(), + )]))) + ); + assert_eq!( + databricks().verified_stmt(&format!("USE DATABASE {0}my_database{0}", quote)), + Statement::Use(Use::Database(ObjectName(vec![Ident::with_quote( + quote, + "my_database".to_string(), + )]))) + ); + assert_eq!( + databricks().verified_stmt(&format!("USE SCHEMA {0}my_schema{0}", quote)), + Statement::Use(Use::Schema(ObjectName(vec![Ident::with_quote( + quote, + "my_schema".to_string(), + )]))) + ); + } + + // Test single identifier with keyword and no quotes + assert_eq!( + databricks().verified_stmt("USE CATALOG my_catalog"), + Statement::Use(Use::Catalog(ObjectName(vec![Ident::new("my_catalog")]))) + ); + assert_eq!( + databricks().verified_stmt("USE DATABASE my_schema"), + Statement::Use(Use::Database(ObjectName(vec![Ident::new("my_schema")]))) + ); + assert_eq!( + databricks().verified_stmt("USE SCHEMA my_schema"), + Statement::Use(Use::Schema(ObjectName(vec![Ident::new("my_schema")]))) + ); + + // Test invalid syntax - missing identifier + let invalid_cases = ["USE SCHEMA", "USE DATABASE", "USE CATALOG"]; + for sql in &invalid_cases { + assert_eq!( + databricks().parse_sql_statements(sql).unwrap_err(), + ParserError::ParserError("Expected: identifier, found: 
EOF".to_string()), + ); + } +} diff --git a/tests/sqlparser_duckdb.rs b/tests/sqlparser_duckdb.rs index 6e6c4e230..488fddfd3 100644 --- a/tests/sqlparser_duckdb.rs +++ b/tests/sqlparser_duckdb.rs @@ -756,3 +756,55 @@ fn test_duckdb_union_datatype() { stmt ); } + +#[test] +fn parse_use() { + let valid_object_names = [ + "mydb", + "SCHEMA", + "DATABASE", + "CATALOG", + "WAREHOUSE", + "DEFAULT", + ]; + let quote_styles = ['"', '\'']; + + for object_name in &valid_object_names { + // Test single identifier without quotes + assert_eq!( + duckdb().verified_stmt(&format!("USE {}", object_name)), + Statement::Use(Use::Object(ObjectName(vec![Ident::new( + object_name.to_string() + )]))) + ); + for &quote in &quote_styles { + // Test single identifier with different type of quotes + assert_eq!( + duckdb().verified_stmt(&format!("USE {0}{1}{0}", quote, object_name)), + Statement::Use(Use::Object(ObjectName(vec![Ident::with_quote( + quote, + object_name.to_string(), + )]))) + ); + } + } + + for &quote in &quote_styles { + // Test double identifier with different type of quotes + assert_eq!( + duckdb().verified_stmt(&format!("USE {0}CATALOG{0}.{0}my_schema{0}", quote)), + Statement::Use(Use::Object(ObjectName(vec![ + Ident::with_quote(quote, "CATALOG"), + Ident::with_quote(quote, "my_schema") + ]))) + ); + } + // Test double identifier without quotes + assert_eq!( + duckdb().verified_stmt("USE mydb.my_schema"), + Statement::Use(Use::Object(ObjectName(vec![ + Ident::new("mydb"), + Ident::new("my_schema") + ]))) + ); +} diff --git a/tests/sqlparser_hive.rs b/tests/sqlparser_hive.rs index 157dad060..bd242035e 100644 --- a/tests/sqlparser_hive.rs +++ b/tests/sqlparser_hive.rs @@ -18,7 +18,7 @@ use sqlparser::ast::{ CreateFunctionBody, CreateFunctionUsing, Expr, Function, FunctionArgumentList, FunctionArguments, Ident, ObjectName, OneOrManyWithParens, SelectItem, Statement, TableFactor, - UnaryOperator, Value, + UnaryOperator, Use, Value, }; use sqlparser::dialect::{GenericDialect, HiveDialect, 
MsSqlDialect}; use sqlparser::parser::ParserError; @@ -401,6 +401,36 @@ fn parse_delimited_identifiers() { //TODO verified_stmt(r#"UPDATE foo SET "bar" = 5"#); } +#[test] +fn parse_use() { + let valid_object_names = ["mydb", "SCHEMA", "DATABASE", "CATALOG", "WAREHOUSE"]; + let quote_styles = ['\'', '"', '`']; + for object_name in &valid_object_names { + // Test single identifier without quotes + assert_eq!( + hive().verified_stmt(&format!("USE {}", object_name)), + Statement::Use(Use::Object(ObjectName(vec![Ident::new( + object_name.to_string() + )]))) + ); + for &quote in &quote_styles { + // Test single identifier with different type of quotes + assert_eq!( + hive().verified_stmt(&format!("USE {}{}{}", quote, object_name, quote)), + Statement::Use(Use::Object(ObjectName(vec![Ident::with_quote( + quote, + object_name.to_string(), + )]))) + ); + } + } + // Test DEFAULT keyword that is special case in Hive + assert_eq!( + hive().verified_stmt("USE DEFAULT"), + Statement::Use(Use::Default) + ); +} + fn hive() -> TestedDialects { TestedDialects { dialects: vec![Box::new(HiveDialect {})], diff --git a/tests/sqlparser_mssql.rs b/tests/sqlparser_mssql.rs index 3e8b6afbf..5c2ec8763 100644 --- a/tests/sqlparser_mssql.rs +++ b/tests/sqlparser_mssql.rs @@ -621,6 +621,38 @@ fn parse_mssql_declare() { ); } +#[test] +fn parse_use() { + let valid_object_names = [ + "mydb", + "SCHEMA", + "DATABASE", + "CATALOG", + "WAREHOUSE", + "DEFAULT", + ]; + let quote_styles = ['\'', '"']; + for object_name in &valid_object_names { + // Test single identifier without quotes + assert_eq!( + ms().verified_stmt(&format!("USE {}", object_name)), + Statement::Use(Use::Object(ObjectName(vec![Ident::new( + object_name.to_string() + )]))) + ); + for &quote in &quote_styles { + // Test single identifier with different type of quotes + assert_eq!( + ms().verified_stmt(&format!("USE {}{}{}", quote, object_name, quote)), + Statement::Use(Use::Object(ObjectName(vec![Ident::with_quote( + quote, + object_name.to_string(), 
+ )]))) + ); + } + } +} + fn ms() -> TestedDialects { TestedDialects { dialects: vec![Box::new(MsSqlDialect {})], diff --git a/tests/sqlparser_mysql.rs b/tests/sqlparser_mysql.rs index 397a722b5..33587c35a 100644 --- a/tests/sqlparser_mysql.rs +++ b/tests/sqlparser_mysql.rs @@ -444,12 +444,35 @@ fn parse_show_collation() { #[test] fn parse_use() { - assert_eq!( - mysql_and_generic().verified_stmt("USE mydb"), - Statement::Use { - db_name: Ident::new("mydb") + let valid_object_names = [ + "mydb", + "SCHEMA", + "DATABASE", + "CATALOG", + "WAREHOUSE", + "DEFAULT", + ]; + let quote_styles = ['\'', '"', '`']; + for object_name in &valid_object_names { + // Test single identifier without quotes + assert_eq!( + mysql_and_generic().verified_stmt(&format!("USE {}", object_name)), + Statement::Use(Use::Object(ObjectName(vec![Ident::new( + object_name.to_string() + )]))) + ); + for &quote in &quote_styles { + // Test single identifier with different type of quotes + assert_eq!( + mysql_and_generic() + .verified_stmt(&format!("USE {}{}{}", quote, object_name, quote)), + Statement::Use(Use::Object(ObjectName(vec![Ident::with_quote( + quote, + object_name.to_string(), + )]))) + ); } - ); + } } #[test] diff --git a/tests/sqlparser_snowflake.rs b/tests/sqlparser_snowflake.rs index a4f29c04f..d0876fc50 100644 --- a/tests/sqlparser_snowflake.rs +++ b/tests/sqlparser_snowflake.rs @@ -2322,3 +2322,81 @@ fn parse_explain_table() { _ => panic!("Unexpected Statement, must be ExplainTable"), } } + +#[test] +fn parse_use() { + let valid_object_names = ["mydb", "CATALOG", "DEFAULT"]; + let quote_styles = ['\'', '"', '`']; + for object_name in &valid_object_names { + // Test single identifier without quotes + std::assert_eq!( + snowflake().verified_stmt(&format!("USE {}", object_name)), + Statement::Use(Use::Object(ObjectName(vec![Ident::new( + object_name.to_string() + )]))) + ); + for &quote in &quote_styles { + // Test single identifier with different type of quotes + std::assert_eq!( + 
snowflake().verified_stmt(&format!("USE {}{}{}", quote, object_name, quote)), + Statement::Use(Use::Object(ObjectName(vec![Ident::with_quote( + quote, + object_name.to_string(), + )]))) + ); + } + } + + for &quote in &quote_styles { + // Test double identifier with different type of quotes + std::assert_eq!( + snowflake().verified_stmt(&format!("USE {0}CATALOG{0}.{0}my_schema{0}", quote)), + Statement::Use(Use::Object(ObjectName(vec![ + Ident::with_quote(quote, "CATALOG"), + Ident::with_quote(quote, "my_schema") + ]))) + ); + } + // Test double identifier without quotes + std::assert_eq!( + snowflake().verified_stmt("USE mydb.my_schema"), + Statement::Use(Use::Object(ObjectName(vec![ + Ident::new("mydb"), + Ident::new("my_schema") + ]))) + ); + + for &quote in &quote_styles { + // Test single and double identifier with keyword and different type of quotes + std::assert_eq!( + snowflake().verified_stmt(&format!("USE DATABASE {0}my_database{0}", quote)), + Statement::Use(Use::Database(ObjectName(vec![Ident::with_quote( + quote, + "my_database".to_string(), + )]))) + ); + std::assert_eq!( + snowflake().verified_stmt(&format!("USE SCHEMA {0}my_schema{0}", quote)), + Statement::Use(Use::Schema(ObjectName(vec![Ident::with_quote( + quote, + "my_schema".to_string(), + )]))) + ); + std::assert_eq!( + snowflake().verified_stmt(&format!("USE SCHEMA {0}CATALOG{0}.{0}my_schema{0}", quote)), + Statement::Use(Use::Schema(ObjectName(vec![ + Ident::with_quote(quote, "CATALOG"), + Ident::with_quote(quote, "my_schema") + ]))) + ); + } + + // Test invalid syntax - missing identifier + let invalid_cases = ["USE SCHEMA", "USE DATABASE", "USE WAREHOUSE"]; + for sql in &invalid_cases { + std::assert_eq!( + snowflake().parse_sql_statements(sql).unwrap_err(), + ParserError::ParserError("Expected: identifier, found: EOF".to_string()), + ); + } +}