From 40a1360146728def118e7249f737d9113557ff57 Mon Sep 17 00:00:00 2001 From: ad hoc Date: Sat, 29 Oct 2022 14:49:28 +0200 Subject: [PATCH] Omit table names in idents when there is only one table in SELECT --- prql-compiler/src/ir/mod.rs | 2 ++ prql-compiler/src/ir/table_counter.rs | 27 +++++++++++++++ prql-compiler/src/lib.rs | 48 ++++++++++++--------------- prql-compiler/src/sql/anchor.rs | 2 +- prql-compiler/src/sql/codegen.rs | 15 +++++---- prql-compiler/src/sql/translator.rs | 14 ++++++-- 6 files changed, 72 insertions(+), 36 deletions(-) create mode 100644 prql-compiler/src/ir/table_counter.rs diff --git a/prql-compiler/src/ir/mod.rs b/prql-compiler/src/ir/mod.rs index 13c3e14b9be1..7ea5f62191fb 100644 --- a/prql-compiler/src/ir/mod.rs +++ b/prql-compiler/src/ir/mod.rs @@ -1,10 +1,12 @@ mod expr; mod id_gen; mod ir_fold; +mod table_counter; pub use expr::{Expr, ExprKind, UnOp}; pub use id_gen::IdGenerator; pub use ir_fold::*; +pub use table_counter::TableCounter; use serde::{Deserialize, Serialize}; diff --git a/prql-compiler/src/ir/table_counter.rs b/prql-compiler/src/ir/table_counter.rs new file mode 100644 index 000000000000..91ca9117e7d5 --- /dev/null +++ b/prql-compiler/src/ir/table_counter.rs @@ -0,0 +1,27 @@ +use std::collections::HashSet; + +use super::{IrFold, TId, Transform}; + +/// Folder that counts the number of table referenced in a PRQL query. +#[derive(Debug, Default)] +pub struct TableCounter { + tables: HashSet, +} + +impl TableCounter { + pub fn count(&self) -> usize { + self.tables.len() + } +} + +impl IrFold for TableCounter { + fn fold_transforms(&mut self, transforms: Vec) -> anyhow::Result> { + for transform in &transforms { + if let Transform::Join { with: tid, .. } | Transform::From(tid) = transform { + self.tables.insert(*tid); + } + } + + Ok(transforms) + } +} diff --git a/prql-compiler/src/lib.rs b/prql-compiler/src/lib.rs index 7dd83bf7ceff..54f481d78e59 100644 --- a/prql-compiler/src/lib.rs +++ b/prql-compiler/src/lib.rs @@ -225,7 +225,6 @@ mod test { } #[test] - #[ignore] fn test_quoting() { // GH-#822 assert_display_snapshot!((compile(r###" @@ -238,7 +237,7 @@ join some_schema.tablename [~id] "###).unwrap()), @r###" WITH "UPPER" AS ( SELECT - lower.* + * FROM lower ) @@ -287,7 +286,7 @@ select `first name` assert_display_snapshot!((compile(query).unwrap()), @r###" SELECT - invoices.* + * FROM invoices ORDER BY @@ -298,7 +297,6 @@ select `first name` } #[test] - #[ignore] fn test_ranges() { let query = r###" from employees @@ -307,7 +305,7 @@ select `first name` assert_display_snapshot!((compile(query).unwrap()), @r###" SELECT - employees.* + * FROM employees WHERE @@ -329,7 +327,7 @@ select `first name` assert_display_snapshot!((compile(query).unwrap()), @r###" SELECT - events.* + * FROM events WHERE @@ -339,7 +337,6 @@ select `first name` } #[test] - #[ignore] fn test_interval() { let query = r###" from projects @@ -348,7 +345,7 @@ select `first name` assert_display_snapshot!((compile(query).unwrap()), @r###" SELECT - projects.*, + *, start + INTERVAL 10 DAY AS first_check_in FROM projects @@ -367,7 +364,7 @@ select `first name` ] "###).unwrap()), @r###" SELECT - to_do_empty_table.*, + *, DATE '2011-02-01' AS date, TIMESTAMP '2011-02-01T10:00' AS timestamp, TIME '14:00' AS time @@ -642,7 +639,6 @@ select `first name` } #[test] - #[ignore] fn test_filter() { // https://github.com/prql/prql/issues/469 let query = r###" @@ -657,7 +653,7 @@ select `first name` filter age > 25 and age < 40 "###).unwrap()), @r###" SELECT - employees.* + * FROM employees WHERE @@ -671,7 +667,7 @@ select `first name` filter age < 40 "###).unwrap()), @r###" SELECT - employees.* + * FROM employees WHERE @@ -681,7 +677,6 @@ select `first name` } #[test] - #[ignore] fn test_nulls() { assert_display_snapshot!((compile(r###" from employees @@ -699,7 +694,7 @@ select `first name` derive amount = amount + 2 ?? 3 * 5 "###).unwrap()), @r###" SELECT - employees.*, + *, COALESCE(amount + 2, 3 * 5) AS amount FROM employees @@ -711,7 +706,7 @@ select `first name` filter first_name == null and null == last_name "###).unwrap()), @r###" SELECT - employees.* + * FROM employees WHERE @@ -725,7 +720,7 @@ select `first name` filter first_name != null and null != last_name "###).unwrap()), @r###" SELECT - employees.* + * FROM employees WHERE @@ -742,7 +737,7 @@ select `first name` take ..10 "###).unwrap()), @r###" SELECT - employees.* + * FROM employees LIMIT @@ -754,7 +749,7 @@ select `first name` take 5..10 "###).unwrap()), @r###" SELECT - employees.* + * FROM employees LIMIT @@ -766,7 +761,7 @@ select `first name` take 5.. "###).unwrap()), @r###" SELECT - employees.* + * FROM employees OFFSET 4 "###); @@ -776,7 +771,7 @@ select `first name` take 5..5 "###).unwrap()), @r###" SELECT - employees.* + * FROM employees LIMIT @@ -790,7 +785,7 @@ select `first name` take 1..5 "###).unwrap()), @r###" SELECT - employees.* + * FROM employees LIMIT @@ -1434,15 +1429,14 @@ take 20 assert_display_snapshot!((compile(query).unwrap()), @r###" SELECT - github_json.*, - github_json.`event.type` AS event_type_dotted + *, + `event.type` AS event_type_dotted FROM github_json "###); } #[test] - #[ignore] fn test_ident_escaping() { // Generic let query = r###" @@ -1452,7 +1446,7 @@ take 20 assert_display_snapshot!((compile(query).unwrap()), @r###" SELECT - "anim""ls".*, + *, "BeeName" AS "čebela", "bear's_name" AS medved FROM @@ -1469,7 +1463,7 @@ take 20 assert_display_snapshot!((compile(query).unwrap()), @r###" SELECT - `anim"ls`.*, + *, `BeeName` AS `čebela`, `bear's_name` AS medved FROM @@ -1488,7 +1482,7 @@ take 20 assert_display_snapshot!(sql, @r###" SELECT - employees.*, + *, true AS always_true FROM employees diff --git a/prql-compiler/src/sql/anchor.rs b/prql-compiler/src/sql/anchor.rs index 741cc6320eb8..a963950e2ab2 100644 --- a/prql-compiler/src/sql/anchor.rs +++ b/prql-compiler/src/sql/anchor.rs @@ -172,7 +172,7 @@ impl AnchorContext { match transform { Transform::From(tid) => { let table_def = &self.table_defs.get(tid).unwrap(); - columns = table_def.columns.iter().map(|c| c.id).collect(); + columns = table_def.columns.iter().map(|c| dbg!(c.id)).collect(); } Transform::Select(cols) => columns = cols.clone(), Transform::Aggregate(cols) => columns = cols.clone(), diff --git a/prql-compiler/src/sql/codegen.rs b/prql-compiler/src/sql/codegen.rs index 08d9ea612028..29a1ea492718 100644 --- a/prql-compiler/src/sql/codegen.rs +++ b/prql-compiler/src/sql/codegen.rs @@ -414,14 +414,17 @@ pub(super) fn translate_ident( context: &Context, ) -> Vec { let mut parts = Vec::with_capacity(4); - if let Some(relation) = relation_name { - // Special-case this for BigQuery, Ref #852 - if matches!(context.dialect.dialect(), Dialect::BigQuery) { - parts.push(relation); - } else { - parts.extend(relation.split('.').map(|s| s.to_string())); + if !context.omit_ident_prefix || column.is_none() { + if let Some(relation) = relation_name { + // Special-case this for BigQuery, Ref #852 + if matches!(context.dialect.dialect(), Dialect::BigQuery) { + parts.push(relation); + } else { + parts.extend(relation.split('.').map(|s| s.to_string())); + } } } + parts.extend(column); parts diff --git a/prql-compiler/src/sql/translator.rs b/prql-compiler/src/sql/translator.rs index c91f0f4484be..87d1369a0fe7 100644 --- a/prql-compiler/src/sql/translator.rs +++ b/prql-compiler/src/sql/translator.rs @@ -15,7 +15,7 @@ use sqlparser::ast::{self as sql_ast, Select, SetExpr, TableWithJoins}; use std::collections::HashMap; use crate::ast::{DialectHandler, Literal}; -use crate::ir::{Expr, ExprKind, Query, Table, TableExpr, Transform}; +use crate::ir::{Expr, ExprKind, IrFold, Query, Table, TableCounter, TableExpr, Transform}; use super::anchor::AnchorContext; use super::codegen::*; @@ -23,6 +23,7 @@ use super::codegen::*; pub(super) struct Context { pub dialect: Box, pub anchor: AnchorContext, + pub omit_ident_prefix: bool, } /// Translate a PRQL AST into a SQL string. @@ -50,7 +51,11 @@ pub fn translate_query(query: Query) -> Result { let (anchor, query) = AnchorContext::of(query); - let mut context = Context { dialect, anchor }; + let mut context = Context { + dialect, + anchor, + omit_ident_prefix: false, + }; // extract tables and the pipeline let tables = into_tables(query.expr, query.tables, &mut context)?; @@ -121,6 +126,10 @@ fn sql_query_of_atomic_query( pipeline: Vec, context: &mut Context, ) -> Result { + let mut counter = TableCounter::default(); + let pipeline = counter.fold_transforms(pipeline)?; + context.omit_ident_prefix = counter.count() == 1; + let select = context.anchor.determine_select_columns(&pipeline); let mut from = pipeline @@ -372,6 +381,7 @@ mod test { let mut context = Context { dialect: Box::new(GenericDialect {}), anchor, + omit_ident_prefix: false, }; let table = Table {