Skip to content

Commit

Permalink
feat: support different USE statement syntaxes (#1387)
Browse files Browse the repository at this point in the history
  • Loading branch information
kacpermuda authored Aug 23, 2024
1 parent 19e694a commit 7282ce2
Show file tree
Hide file tree
Showing 11 changed files with 386 additions and 17 deletions.
27 changes: 27 additions & 0 deletions src/ast/dcl.rs
Original file line number Diff line number Diff line change
Expand Up @@ -193,3 +193,30 @@ impl fmt::Display for AlterRoleOperation {
}
}
}

/// A `USE` (`Statement::Use`) operation
#[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "visitor", derive(Visit, VisitMut))]
pub enum Use {
Catalog(ObjectName), // e.g. `USE CATALOG foo.bar`
Schema(ObjectName), // e.g. `USE SCHEMA foo.bar`
Database(ObjectName), // e.g. `USE DATABASE foo.bar`
Warehouse(ObjectName), // e.g. `USE WAREHOUSE foo.bar`
Object(ObjectName), // e.g. `USE foo.bar`
Default, // e.g. `USE DEFAULT`
}

impl fmt::Display for Use {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
f.write_str("USE ")?;
match self {
Use::Catalog(name) => write!(f, "CATALOG {}", name),
Use::Schema(name) => write!(f, "SCHEMA {}", name),
Use::Database(name) => write!(f, "DATABASE {}", name),
Use::Warehouse(name) => write!(f, "WAREHOUSE {}", name),
Use::Object(name) => write!(f, "{}", name),
Use::Default => write!(f, "DEFAULT"),
}
}
}
13 changes: 4 additions & 9 deletions src/ast/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ pub use self::data_type::{
ArrayElemTypeDef, CharLengthUnits, CharacterLength, DataType, ExactNumberInfo,
StructBracketKind, TimezoneInfo,
};
pub use self::dcl::{AlterRoleOperation, ResetConfig, RoleOption, SetConfigValue};
pub use self::dcl::{AlterRoleOperation, ResetConfig, RoleOption, SetConfigValue, Use};
pub use self::ddl::{
AlterColumnOperation, AlterIndexOperation, AlterTableOperation, ColumnDef, ColumnOption,
ColumnOptionDef, ConstraintCharacteristics, Deduplicate, DeferrableInitial, GeneratedAs,
Expand Down Expand Up @@ -2515,11 +2515,9 @@ pub enum Statement {
/// Note: this is a MySQL-specific statement.
ShowCollation { filter: Option<ShowStatementFilter> },
/// ```sql
/// USE
/// `USE ...`
/// ```
///
/// Note: This is a MySQL-specific statement.
Use { db_name: Ident },
Use(Use),
/// ```sql
/// START [ TRANSACTION | WORK ] | START TRANSACTION } ...
/// ```
Expand Down Expand Up @@ -4125,10 +4123,7 @@ impl fmt::Display for Statement {
}
Ok(())
}
Statement::Use { db_name } => {
write!(f, "USE {db_name}")?;
Ok(())
}
Statement::Use(use_expr) => use_expr.fmt(f),
Statement::ShowCollation { filter } => {
write!(f, "SHOW COLLATION")?;
if let Some(filter) = filter {
Expand Down
2 changes: 2 additions & 0 deletions src/keywords.rs
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,7 @@ define_keywords!(
CASCADED,
CASE,
CAST,
CATALOG,
CEIL,
CEILING,
CENTURY,
Expand Down Expand Up @@ -804,6 +805,7 @@ define_keywords!(
VIEW,
VIRTUAL,
VOLATILE,
WAREHOUSE,
WEEK,
WHEN,
WHENEVER,
Expand Down
27 changes: 25 additions & 2 deletions src/parser/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9264,8 +9264,31 @@ impl<'a> Parser<'a> {
}

pub fn parse_use(&mut self) -> Result<Statement, ParserError> {
let db_name = self.parse_identifier(false)?;
Ok(Statement::Use { db_name })
// Determine which keywords are recognized by the current dialect
let parsed_keyword = if dialect_of!(self is HiveDialect) {
// HiveDialect accepts USE DEFAULT; statement without any db specified
if self.parse_keyword(Keyword::DEFAULT) {
return Ok(Statement::Use(Use::Default));
}
None // HiveDialect doesn't expect any other specific keyword after `USE`
} else if dialect_of!(self is DatabricksDialect) {
self.parse_one_of_keywords(&[Keyword::CATALOG, Keyword::DATABASE, Keyword::SCHEMA])
} else if dialect_of!(self is SnowflakeDialect) {
self.parse_one_of_keywords(&[Keyword::DATABASE, Keyword::SCHEMA, Keyword::WAREHOUSE])
} else {
None // No specific keywords for other dialects, including GenericDialect
};

let obj_name = self.parse_object_name(false)?;
let result = match parsed_keyword {
Some(Keyword::CATALOG) => Use::Catalog(obj_name),
Some(Keyword::DATABASE) => Use::Database(obj_name),
Some(Keyword::SCHEMA) => Use::Schema(obj_name),
Some(Keyword::WAREHOUSE) => Use::Warehouse(obj_name),
_ => Use::Object(obj_name),
};

Ok(Statement::Use(result))
}

pub fn parse_table_and_joins(&mut self) -> Result<TableWithJoins, ParserError> {
Expand Down
33 changes: 33 additions & 0 deletions tests/sqlparser_clickhouse.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1232,6 +1232,39 @@ fn test_prewhere() {
}
}

#[test]
fn parse_use() {
let valid_object_names = [
"mydb",
"SCHEMA",
"DATABASE",
"CATALOG",
"WAREHOUSE",
"DEFAULT",
];
let quote_styles = ['"', '`'];

for object_name in &valid_object_names {
// Test single identifier without quotes
assert_eq!(
clickhouse().verified_stmt(&format!("USE {}", object_name)),
Statement::Use(Use::Object(ObjectName(vec![Ident::new(
object_name.to_string()
)])))
);
for &quote in &quote_styles {
// Test single identifier with different type of quotes
assert_eq!(
clickhouse().verified_stmt(&format!("USE {0}{1}{0}", quote, object_name)),
Statement::Use(Use::Object(ObjectName(vec![Ident::with_quote(
quote,
object_name.to_string(),
)])))
);
}
}
}

#[test]
fn test_query_with_format_clause() {
let format_options = vec!["TabSeparated", "JSONCompact", "NULL"];
Expand Down
74 changes: 74 additions & 0 deletions tests/sqlparser_databricks.rs
Original file line number Diff line number Diff line change
Expand Up @@ -189,3 +189,77 @@ fn test_values_clause() {
// TODO: support this example from https://docs.databricks.com/en/sql/language-manual/sql-ref-syntax-qry-select-values.html#examples
// databricks().verified_query("VALUES 1, 2, 3");
}

#[test]
fn parse_use() {
let valid_object_names = ["mydb", "WAREHOUSE", "DEFAULT"];
let quote_styles = ['"', '`'];

for object_name in &valid_object_names {
// Test single identifier without quotes
assert_eq!(
databricks().verified_stmt(&format!("USE {}", object_name)),
Statement::Use(Use::Object(ObjectName(vec![Ident::new(
object_name.to_string()
)])))
);
for &quote in &quote_styles {
// Test single identifier with different type of quotes
assert_eq!(
databricks().verified_stmt(&format!("USE {0}{1}{0}", quote, object_name)),
Statement::Use(Use::Object(ObjectName(vec![Ident::with_quote(
quote,
object_name.to_string(),
)])))
);
}
}

for &quote in &quote_styles {
// Test single identifier with keyword and different type of quotes
assert_eq!(
databricks().verified_stmt(&format!("USE CATALOG {0}my_catalog{0}", quote)),
Statement::Use(Use::Catalog(ObjectName(vec![Ident::with_quote(
quote,
"my_catalog".to_string(),
)])))
);
assert_eq!(
databricks().verified_stmt(&format!("USE DATABASE {0}my_database{0}", quote)),
Statement::Use(Use::Database(ObjectName(vec![Ident::with_quote(
quote,
"my_database".to_string(),
)])))
);
assert_eq!(
databricks().verified_stmt(&format!("USE SCHEMA {0}my_schema{0}", quote)),
Statement::Use(Use::Schema(ObjectName(vec![Ident::with_quote(
quote,
"my_schema".to_string(),
)])))
);
}

// Test single identifier with keyword and no quotes
assert_eq!(
databricks().verified_stmt("USE CATALOG my_catalog"),
Statement::Use(Use::Catalog(ObjectName(vec![Ident::new("my_catalog")])))
);
assert_eq!(
databricks().verified_stmt("USE DATABASE my_schema"),
Statement::Use(Use::Database(ObjectName(vec![Ident::new("my_schema")])))
);
assert_eq!(
databricks().verified_stmt("USE SCHEMA my_schema"),
Statement::Use(Use::Schema(ObjectName(vec![Ident::new("my_schema")])))
);

// Test invalid syntax - missing identifier
let invalid_cases = ["USE SCHEMA", "USE DATABASE", "USE CATALOG"];
for sql in &invalid_cases {
assert_eq!(
databricks().parse_sql_statements(sql).unwrap_err(),
ParserError::ParserError("Expected: identifier, found: EOF".to_string()),
);
}
}
52 changes: 52 additions & 0 deletions tests/sqlparser_duckdb.rs
Original file line number Diff line number Diff line change
Expand Up @@ -756,3 +756,55 @@ fn test_duckdb_union_datatype() {
stmt
);
}

#[test]
fn parse_use() {
let valid_object_names = [
"mydb",
"SCHEMA",
"DATABASE",
"CATALOG",
"WAREHOUSE",
"DEFAULT",
];
let quote_styles = ['"', '\''];

for object_name in &valid_object_names {
// Test single identifier without quotes
assert_eq!(
duckdb().verified_stmt(&format!("USE {}", object_name)),
Statement::Use(Use::Object(ObjectName(vec![Ident::new(
object_name.to_string()
)])))
);
for &quote in &quote_styles {
// Test single identifier with different type of quotes
assert_eq!(
duckdb().verified_stmt(&format!("USE {0}{1}{0}", quote, object_name)),
Statement::Use(Use::Object(ObjectName(vec![Ident::with_quote(
quote,
object_name.to_string(),
)])))
);
}
}

for &quote in &quote_styles {
// Test double identifier with different type of quotes
assert_eq!(
duckdb().verified_stmt(&format!("USE {0}CATALOG{0}.{0}my_schema{0}", quote)),
Statement::Use(Use::Object(ObjectName(vec![
Ident::with_quote(quote, "CATALOG"),
Ident::with_quote(quote, "my_schema")
])))
);
}
// Test double identifier without quotes
assert_eq!(
duckdb().verified_stmt("USE mydb.my_schema"),
Statement::Use(Use::Object(ObjectName(vec![
Ident::new("mydb"),
Ident::new("my_schema")
])))
);
}
32 changes: 31 additions & 1 deletion tests/sqlparser_hive.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
use sqlparser::ast::{
CreateFunctionBody, CreateFunctionUsing, Expr, Function, FunctionArgumentList,
FunctionArguments, Ident, ObjectName, OneOrManyWithParens, SelectItem, Statement, TableFactor,
UnaryOperator, Value,
UnaryOperator, Use, Value,
};
use sqlparser::dialect::{GenericDialect, HiveDialect, MsSqlDialect};
use sqlparser::parser::ParserError;
Expand Down Expand Up @@ -401,6 +401,36 @@ fn parse_delimited_identifiers() {
//TODO verified_stmt(r#"UPDATE foo SET "bar" = 5"#);
}

#[test]
fn parse_use() {
let valid_object_names = ["mydb", "SCHEMA", "DATABASE", "CATALOG", "WAREHOUSE"];
let quote_styles = ['\'', '"', '`'];
for object_name in &valid_object_names {
// Test single identifier without quotes
assert_eq!(
hive().verified_stmt(&format!("USE {}", object_name)),
Statement::Use(Use::Object(ObjectName(vec![Ident::new(
object_name.to_string()
)])))
);
for &quote in &quote_styles {
// Test single identifier with different type of quotes
assert_eq!(
hive().verified_stmt(&format!("USE {}{}{}", quote, object_name, quote)),
Statement::Use(Use::Object(ObjectName(vec![Ident::with_quote(
quote,
object_name.to_string(),
)])))
);
}
}
// Test DEFAULT keyword that is special case in Hive
assert_eq!(
hive().verified_stmt("USE DEFAULT"),
Statement::Use(Use::Default)
);
}

fn hive() -> TestedDialects {
TestedDialects {
dialects: vec![Box::new(HiveDialect {})],
Expand Down
32 changes: 32 additions & 0 deletions tests/sqlparser_mssql.rs
Original file line number Diff line number Diff line change
Expand Up @@ -621,6 +621,38 @@ fn parse_mssql_declare() {
);
}

#[test]
fn parse_use() {
let valid_object_names = [
"mydb",
"SCHEMA",
"DATABASE",
"CATALOG",
"WAREHOUSE",
"DEFAULT",
];
let quote_styles = ['\'', '"'];
for object_name in &valid_object_names {
// Test single identifier without quotes
assert_eq!(
ms().verified_stmt(&format!("USE {}", object_name)),
Statement::Use(Use::Object(ObjectName(vec![Ident::new(
object_name.to_string()
)])))
);
for &quote in &quote_styles {
// Test single identifier with different type of quotes
assert_eq!(
ms().verified_stmt(&format!("USE {}{}{}", quote, object_name, quote)),
Statement::Use(Use::Object(ObjectName(vec![Ident::with_quote(
quote,
object_name.to_string(),
)])))
);
}
}
}

fn ms() -> TestedDialects {
TestedDialects {
dialects: vec![Box::new(MsSqlDialect {})],
Expand Down
Loading

0 comments on commit 7282ce2

Please sign in to comment.