Skip to content

Commit

Permalink
Omit table names in idents when there is only one table in SELECT
Browse files Browse the repository at this point in the history
  • Loading branch information
MarinPostma committed Oct 29, 2022
1 parent 672bff8 commit 40a1360
Show file tree
Hide file tree
Showing 6 changed files with 72 additions and 36 deletions.
2 changes: 2 additions & 0 deletions prql-compiler/src/ir/mod.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
mod expr;
mod id_gen;
mod ir_fold;
mod table_counter;

pub use expr::{Expr, ExprKind, UnOp};
pub use id_gen::IdGenerator;
pub use ir_fold::*;
pub use table_counter::TableCounter;

use serde::{Deserialize, Serialize};

Expand Down
27 changes: 27 additions & 0 deletions prql-compiler/src/ir/table_counter.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
use std::collections::HashSet;

use super::{IrFold, TId, Transform};

/// Folder that counts the number of table referenced in a PRQL query.
#[derive(Debug, Default)]
pub struct TableCounter {
tables: HashSet<TId>,
}

impl TableCounter {
pub fn count(&self) -> usize {
self.tables.len()
}
}

impl IrFold for TableCounter {
fn fold_transforms(&mut self, transforms: Vec<Transform>) -> anyhow::Result<Vec<Transform>> {
for transform in &transforms {
if let Transform::Join { with: tid, .. } | Transform::From(tid) = transform {
self.tables.insert(*tid);
}
}

Ok(transforms)
}
}
48 changes: 21 additions & 27 deletions prql-compiler/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -225,7 +225,6 @@ mod test {
}

#[test]
#[ignore]
fn test_quoting() {
// GH-#822
assert_display_snapshot!((compile(r###"
Expand All @@ -238,7 +237,7 @@ join some_schema.tablename [~id]
"###).unwrap()), @r###"
WITH "UPPER" AS (
SELECT
lower.*
*
FROM
lower
)
Expand Down Expand Up @@ -287,7 +286,7 @@ select `first name`

assert_display_snapshot!((compile(query).unwrap()), @r###"
SELECT
invoices.*
*
FROM
invoices
ORDER BY
Expand All @@ -298,7 +297,6 @@ select `first name`
}

#[test]
#[ignore]
fn test_ranges() {
let query = r###"
from employees
Expand All @@ -307,7 +305,7 @@ select `first name`

assert_display_snapshot!((compile(query).unwrap()), @r###"
SELECT
employees.*
*
FROM
employees
WHERE
Expand All @@ -329,7 +327,7 @@ select `first name`

assert_display_snapshot!((compile(query).unwrap()), @r###"
SELECT
events.*
*
FROM
events
WHERE
Expand All @@ -339,7 +337,6 @@ select `first name`
}

#[test]
#[ignore]
fn test_interval() {
let query = r###"
from projects
Expand All @@ -348,7 +345,7 @@ select `first name`

assert_display_snapshot!((compile(query).unwrap()), @r###"
SELECT
projects.*,
*,
start + INTERVAL 10 DAY AS first_check_in
FROM
projects
Expand All @@ -367,7 +364,7 @@ select `first name`
]
"###).unwrap()), @r###"
SELECT
to_do_empty_table.*,
*,
DATE '2011-02-01' AS date,
TIMESTAMP '2011-02-01T10:00' AS timestamp,
TIME '14:00' AS time
Expand Down Expand Up @@ -642,7 +639,6 @@ select `first name`
}

#[test]
#[ignore]
fn test_filter() {
// https://github.com/prql/prql/issues/469
let query = r###"
Expand All @@ -657,7 +653,7 @@ select `first name`
filter age > 25 and age < 40
"###).unwrap()), @r###"
SELECT
employees.*
*
FROM
employees
WHERE
Expand All @@ -671,7 +667,7 @@ select `first name`
filter age < 40
"###).unwrap()), @r###"
SELECT
employees.*
*
FROM
employees
WHERE
Expand All @@ -681,7 +677,6 @@ select `first name`
}

#[test]
#[ignore]
fn test_nulls() {
assert_display_snapshot!((compile(r###"
from employees
Expand All @@ -699,7 +694,7 @@ select `first name`
derive amount = amount + 2 ?? 3 * 5
"###).unwrap()), @r###"
SELECT
employees.*,
*,
COALESCE(amount + 2, 3 * 5) AS amount
FROM
employees
Expand All @@ -711,7 +706,7 @@ select `first name`
filter first_name == null and null == last_name
"###).unwrap()), @r###"
SELECT
employees.*
*
FROM
employees
WHERE
Expand All @@ -725,7 +720,7 @@ select `first name`
filter first_name != null and null != last_name
"###).unwrap()), @r###"
SELECT
employees.*
*
FROM
employees
WHERE
Expand All @@ -742,7 +737,7 @@ select `first name`
take ..10
"###).unwrap()), @r###"
SELECT
employees.*
*
FROM
employees
LIMIT
Expand All @@ -754,7 +749,7 @@ select `first name`
take 5..10
"###).unwrap()), @r###"
SELECT
employees.*
*
FROM
employees
LIMIT
Expand All @@ -766,7 +761,7 @@ select `first name`
take 5..
"###).unwrap()), @r###"
SELECT
employees.*
*
FROM
employees OFFSET 4
"###);
Expand All @@ -776,7 +771,7 @@ select `first name`
take 5..5
"###).unwrap()), @r###"
SELECT
employees.*
*
FROM
employees
LIMIT
Expand All @@ -790,7 +785,7 @@ select `first name`
take 1..5
"###).unwrap()), @r###"
SELECT
employees.*
*
FROM
employees
LIMIT
Expand Down Expand Up @@ -1434,15 +1429,14 @@ take 20

assert_display_snapshot!((compile(query).unwrap()), @r###"
SELECT
github_json.*,
github_json.`event.type` AS event_type_dotted
*,
`event.type` AS event_type_dotted
FROM
github_json
"###);
}

#[test]
#[ignore]
fn test_ident_escaping() {
// Generic
let query = r###"
Expand All @@ -1452,7 +1446,7 @@ take 20

assert_display_snapshot!((compile(query).unwrap()), @r###"
SELECT
"anim""ls".*,
*,
"BeeName" AS "čebela",
"bear's_name" AS medved
FROM
Expand All @@ -1469,7 +1463,7 @@ take 20

assert_display_snapshot!((compile(query).unwrap()), @r###"
SELECT
`anim"ls`.*,
*,
`BeeName` AS `čebela`,
`bear's_name` AS medved
FROM
Expand All @@ -1488,7 +1482,7 @@ take 20
assert_display_snapshot!(sql,
@r###"
SELECT
employees.*,
*,
true AS always_true
FROM
employees
Expand Down
2 changes: 1 addition & 1 deletion prql-compiler/src/sql/anchor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -172,7 +172,7 @@ impl AnchorContext {
match transform {
Transform::From(tid) => {
let table_def = &self.table_defs.get(tid).unwrap();
columns = table_def.columns.iter().map(|c| c.id).collect();
columns = table_def.columns.iter().map(|c| dbg!(c.id)).collect();
}
Transform::Select(cols) => columns = cols.clone(),
Transform::Aggregate(cols) => columns = cols.clone(),
Expand Down
15 changes: 9 additions & 6 deletions prql-compiler/src/sql/codegen.rs
Original file line number Diff line number Diff line change
Expand Up @@ -414,14 +414,17 @@ pub(super) fn translate_ident(
context: &Context,
) -> Vec<sql_ast::Ident> {
let mut parts = Vec::with_capacity(4);
if let Some(relation) = relation_name {
// Special-case this for BigQuery, Ref #852
if matches!(context.dialect.dialect(), Dialect::BigQuery) {
parts.push(relation);
} else {
parts.extend(relation.split('.').map(|s| s.to_string()));
if !context.omit_ident_prefix || column.is_none() {
if let Some(relation) = relation_name {
// Special-case this for BigQuery, Ref #852
if matches!(context.dialect.dialect(), Dialect::BigQuery) {
parts.push(relation);
} else {
parts.extend(relation.split('.').map(|s| s.to_string()));
}
}
}

parts.extend(column);

parts
Expand Down
14 changes: 12 additions & 2 deletions prql-compiler/src/sql/translator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,14 +15,15 @@ use sqlparser::ast::{self as sql_ast, Select, SetExpr, TableWithJoins};
use std::collections::HashMap;

use crate::ast::{DialectHandler, Literal};
use crate::ir::{Expr, ExprKind, Query, Table, TableExpr, Transform};
use crate::ir::{Expr, ExprKind, IrFold, Query, Table, TableCounter, TableExpr, Transform};

use super::anchor::AnchorContext;
use super::codegen::*;

pub(super) struct Context {
pub dialect: Box<dyn DialectHandler>,
pub anchor: AnchorContext,
pub omit_ident_prefix: bool,
}

/// Translate a PRQL AST into a SQL string.
Expand Down Expand Up @@ -50,7 +51,11 @@ pub fn translate_query(query: Query) -> Result<sql_ast::Query> {

let (anchor, query) = AnchorContext::of(query);

let mut context = Context { dialect, anchor };
let mut context = Context {
dialect,
anchor,
omit_ident_prefix: false,
};

// extract tables and the pipeline
let tables = into_tables(query.expr, query.tables, &mut context)?;
Expand Down Expand Up @@ -121,6 +126,10 @@ fn sql_query_of_atomic_query(
pipeline: Vec<Transform>,
context: &mut Context,
) -> Result<sql_ast::Query> {
let mut counter = TableCounter::default();
let pipeline = counter.fold_transforms(pipeline)?;
context.omit_ident_prefix = counter.count() == 1;

let select = context.anchor.determine_select_columns(&pipeline);

let mut from = pipeline
Expand Down Expand Up @@ -372,6 +381,7 @@ mod test {
let mut context = Context {
dialect: Box::new(GenericDialect {}),
anchor,
omit_ident_prefix: false,
};

let table = Table {
Expand Down

0 comments on commit 40a1360

Please sign in to comment.