From ac4b1147dc0a965c4c23e4a309974fa7b94884c4 Mon Sep 17 00:00:00 2001 From: Alexander Beedie Date: Sat, 7 Sep 2024 17:39:37 +0400 Subject: [PATCH] feat(python): Support shortcut eval of common boolean filters in SQL interface "WHERE" clause (#18571) --- crates/polars-sql/src/context.rs | 24 +- crates/polars-sql/src/function_registry.rs | 2 +- crates/polars-sql/src/keywords.rs | 8 +- crates/polars-sql/src/lib.rs | 1 + crates/polars-sql/src/sql_expr.rs | 223 ++---------------- crates/polars-sql/src/types.rs | 208 ++++++++++++++++ .../tests/unit/sql/test_miscellaneous.py | 26 ++ 7 files changed, 279 insertions(+), 213 deletions(-) create mode 100644 crates/polars-sql/src/types.rs diff --git a/crates/polars-sql/src/context.rs b/crates/polars-sql/src/context.rs index 58120da0002c..23ffb25070fa 100644 --- a/crates/polars-sql/src/context.rs +++ b/crates/polars-sql/src/context.rs @@ -847,8 +847,28 @@ impl SQLContext { expr: &Option, ) -> PolarsResult { if let Some(expr) = expr { - let schema = Some(self.get_frame_schema(&mut lf)?); - let mut filter_expression = parse_sql_expr(expr, self, schema.as_deref())?; + let schema = self.get_frame_schema(&mut lf)?; + + // shortcut filter evaluation if given expression is just TRUE or FALSE + let (all_true, all_false) = match expr { + SQLExpr::Value(SQLValue::Boolean(b)) => (*b, !*b), + SQLExpr::BinaryOp { left, op, right } => match (&**left, &**right, op) { + (SQLExpr::Value(a), SQLExpr::Value(b), BinaryOperator::Eq) => (a == b, a != b), + (SQLExpr::Value(a), SQLExpr::Value(b), BinaryOperator::NotEq) => { + (a != b, a == b) + }, + _ => (false, false), + }, + _ => (false, false), + }; + if all_true { + return Ok(lf); + } else if all_false { + return Ok(DataFrame::empty_with_schema(schema.as_ref()).lazy()); + } + + // ...otherwise parse and apply the filter as normal + let mut filter_expression = parse_sql_expr(expr, self, Some(schema).as_deref())?; if filter_expression.clone().meta().has_multiple_outputs() { filter_expression = all_horizontal([filter_expression])?; } diff --git a/crates/polars-sql/src/function_registry.rs b/crates/polars-sql/src/function_registry.rs index c85f8307af73..aa693025b072 100644 --- a/crates/polars-sql/src/function_registry.rs +++ b/crates/polars-sql/src/function_registry.rs @@ -1,4 +1,4 @@ -//! This module defines the function registry and user defined functions. +//! This module defines a FunctionRegistry for supported SQL functions and UDFs. use polars_error::{polars_bail, PolarsResult}; use polars_plan::prelude::udf::UserDefinedFunction; diff --git a/crates/polars-sql/src/keywords.rs b/crates/polars-sql/src/keywords.rs index 1442a91cd89f..990bc046aa5b 100644 --- a/crates/polars-sql/src/keywords.rs +++ b/crates/polars-sql/src/keywords.rs @@ -1,10 +1,8 @@ -//! Keywords that are supported by Polars SQL -//! -//! This is useful for syntax highlighting +//! Keywords that are supported by the Polars SQL interface. //! //! This module defines: -//! - all Polars SQL keywords [`all_keywords`] -//! - all of polars SQL functions [`all_functions`] +//! - all recognised Polars SQL keywords [`all_keywords`] +//! - all recognised Polars SQL functions [`all_functions`] use crate::functions::PolarsSQLFunctions; use crate::table_functions::PolarsTableFunctions; diff --git a/crates/polars-sql/src/lib.rs b/crates/polars-sql/src/lib.rs index a811a4cfad9b..528f21eafaf2 100644 --- a/crates/polars-sql/src/lib.rs +++ b/crates/polars-sql/src/lib.rs @@ -7,6 +7,7 @@ mod functions; pub mod keywords; mod sql_expr; mod table_functions; +mod types; pub use context::SQLContext; pub use sql_expr::sql_expr; diff --git a/crates/polars-sql/src/sql_expr.rs b/crates/polars-sql/src/sql_expr.rs index 6d18a75c42fe..148a7fe5735e 100644 --- a/crates/polars-sql/src/sql_expr.rs +++ b/crates/polars-sql/src/sql_expr.rs @@ -1,3 +1,11 @@ +//! Expressions that are supported by the Polars SQL interface. +//! +//! This is useful for syntax highlighting +//! +//! This module defines: +//! - all Polars SQL keywords [`all_keywords`] +//! - all of polars SQL functions [`all_functions`] + use std::fmt::Display; use std::ops::Div; @@ -9,216 +17,39 @@ use polars_plan::prelude::LiteralValue::Null; use polars_time::Duration; use rand::distributions::Alphanumeric; use rand::{thread_rng, Rng}; -use regex::{Regex, RegexBuilder}; #[cfg(feature = "serde")] use serde::{Deserialize, Serialize}; -#[cfg(feature = "dtype-decimal")] -use sqlparser::ast::ExactNumberInfo; use sqlparser::ast::{ - ArrayElemTypeDef, BinaryOperator as SQLBinaryOperator, BinaryOperator, CastFormat, CastKind, + BinaryOperator as SQLBinaryOperator, BinaryOperator, CastFormat, CastKind, DataType as SQLDataType, DateTimeField, Expr as SQLExpr, Function as SQLFunction, Ident, - Interval, ObjectName, Query as Subquery, SelectItem, Subscript, TimezoneInfo, TrimWhereField, + Interval, Query as Subquery, SelectItem, Subscript, TimezoneInfo, TrimWhereField, UnaryOperator, Value as SQLValue, }; use sqlparser::dialect::GenericDialect; use sqlparser::parser::{Parser, ParserOptions}; use crate::functions::SQLFunctionVisitor; +use crate::types::{ + bitstring_to_bytes_literal, is_iso_date, is_iso_datetime, is_iso_time, map_sql_dtype_to_polars, +}; use crate::SQLContext; -static DATETIME_LITERAL_RE: std::sync::OnceLock = std::sync::OnceLock::new(); -static DATE_LITERAL_RE: std::sync::OnceLock = std::sync::OnceLock::new(); -static TIME_LITERAL_RE: std::sync::OnceLock = std::sync::OnceLock::new(); - -fn is_iso_datetime(value: &str) -> bool { - let dtm_regex = DATETIME_LITERAL_RE.get_or_init(|| { - RegexBuilder::new( - r"^\d{4}-[01]\d-[0-3]\d[ T](?:[01][0-9]|2[0-3]):[0-5][0-9]:[0-5][0-9](\.\d{1,9})?$", - ) - .build() - .unwrap() - }); - dtm_regex.is_match(value) -} - -fn is_iso_date(value: &str) -> bool { - let dt_regex = DATE_LITERAL_RE.get_or_init(|| { - RegexBuilder::new(r"^\d{4}-[01]\d-[0-3]\d$") - .build() - .unwrap() - }); - dt_regex.is_match(value) -} - -fn is_iso_time(value: &str) -> bool { - let tm_regex = TIME_LITERAL_RE.get_or_init(|| { - RegexBuilder::new(r"^(?:[01][0-9]|2[0-3]):[0-5][0-9]:[0-5][0-9](\.\d{1,9})?$") - .build() - .unwrap() - }); - tm_regex.is_match(value) -} - #[inline] #[cold] #[must_use] +/// Convert a Display-able error to PolarsError::SQLInterface pub fn to_sql_interface_err(err: impl Display) -> PolarsError { PolarsError::SQLInterface(err.to_string().into()) } -fn timeunit_from_precision(prec: &Option) -> PolarsResult { - Ok(match prec { - None => TimeUnit::Microseconds, - Some(n) if (1u64..=3u64).contains(n) => TimeUnit::Milliseconds, - Some(n) if (4u64..=6u64).contains(n) => TimeUnit::Microseconds, - Some(n) if (7u64..=9u64).contains(n) => TimeUnit::Nanoseconds, - Some(n) => { - polars_bail!(SQLSyntax: "invalid temporal type precision (expected 1-9, found {})", n) - }, - }) -} - -pub(crate) fn map_sql_polars_datatype(dtype: &SQLDataType) -> PolarsResult { - Ok(match dtype { - // --------------------------------- - // array/list - // --------------------------------- - SQLDataType::Array(ArrayElemTypeDef::AngleBracket(inner_type)) - | SQLDataType::Array(ArrayElemTypeDef::SquareBracket(inner_type, _)) => { - DataType::List(Box::new(map_sql_polars_datatype(inner_type)?)) - }, - - // --------------------------------- - // binary - // --------------------------------- - SQLDataType::Bytea - | SQLDataType::Bytes(_) - | SQLDataType::Binary(_) - | SQLDataType::Blob(_) - | SQLDataType::Varbinary(_) => DataType::Binary, - - // --------------------------------- - // boolean - // --------------------------------- - SQLDataType::Boolean | SQLDataType::Bool => DataType::Boolean, - - // --------------------------------- - // signed integer - // --------------------------------- - SQLDataType::Int(_) | SQLDataType::Integer(_) => DataType::Int32, - SQLDataType::Int2(_) | SQLDataType::SmallInt(_) => DataType::Int16, - SQLDataType::Int4(_) | SQLDataType::MediumInt(_) => DataType::Int32, - SQLDataType::Int8(_) | SQLDataType::BigInt(_) => DataType::Int64, - SQLDataType::TinyInt(_) => DataType::Int8, - - // --------------------------------- - // unsigned integer: the following do not map to PostgreSQL types/syntax, but - // are enabled for wider compatibility (eg: "CAST(col AS BIGINT UNSIGNED)"). - // --------------------------------- - SQLDataType::UnsignedTinyInt(_) => DataType::UInt8, // see also: "custom" types below - SQLDataType::UnsignedInt(_) | SQLDataType::UnsignedInteger(_) => DataType::UInt32, - SQLDataType::UnsignedInt2(_) | SQLDataType::UnsignedSmallInt(_) => DataType::UInt16, - SQLDataType::UnsignedInt4(_) | SQLDataType::UnsignedMediumInt(_) => DataType::UInt32, - SQLDataType::UnsignedInt8(_) | SQLDataType::UnsignedBigInt(_) | SQLDataType::UInt8 => { - DataType::UInt64 - }, - - // --------------------------------- - // float - // --------------------------------- - SQLDataType::Double | SQLDataType::DoublePrecision | SQLDataType::Float8 => { - DataType::Float64 - }, - SQLDataType::Float(n_bytes) => match n_bytes { - Some(n) if (1u64..=24u64).contains(n) => DataType::Float32, - Some(n) if (25u64..=53u64).contains(n) => DataType::Float64, - Some(n) => { - polars_bail!(SQLSyntax: "unsupported `float` size (expected a value between 1 and 53, found {})", n) - }, - None => DataType::Float64, - }, - SQLDataType::Float4 | SQLDataType::Real => DataType::Float32, - - // --------------------------------- - // decimal - // --------------------------------- - #[cfg(feature = "dtype-decimal")] - SQLDataType::Dec(info) | SQLDataType::Decimal(info) | SQLDataType::Numeric(info) => { - match *info { - ExactNumberInfo::PrecisionAndScale(p, s) => { - DataType::Decimal(Some(p as usize), Some(s as usize)) - }, - ExactNumberInfo::Precision(p) => DataType::Decimal(Some(p as usize), Some(0)), - ExactNumberInfo::None => DataType::Decimal(Some(38), Some(9)), - } - }, - - // --------------------------------- - // temporal - // --------------------------------- - SQLDataType::Date => DataType::Date, - SQLDataType::Interval => DataType::Duration(TimeUnit::Microseconds), - SQLDataType::Time(_, tz) => match tz { - TimezoneInfo::None => DataType::Time, - _ => { - polars_bail!(SQLInterface: "`time` with timezone is not supported; found tz={}", tz) - }, - }, - SQLDataType::Datetime(prec) => DataType::Datetime(timeunit_from_precision(prec)?, None), - SQLDataType::Timestamp(prec, tz) => match tz { - TimezoneInfo::None => DataType::Datetime(timeunit_from_precision(prec)?, None), - _ => { - polars_bail!(SQLInterface: "`timestamp` with timezone is not (yet) supported") - }, - }, - - // --------------------------------- - // string - // --------------------------------- - SQLDataType::Char(_) - | SQLDataType::CharVarying(_) - | SQLDataType::Character(_) - | SQLDataType::CharacterVarying(_) - | SQLDataType::Clob(_) - | SQLDataType::String(_) - | SQLDataType::Text - | SQLDataType::Uuid - | SQLDataType::Varchar(_) => DataType::String, - - // --------------------------------- - // custom - // --------------------------------- - SQLDataType::Custom(ObjectName(idents), _) => match idents.as_slice() { - [Ident { value, .. }] => match value.to_lowercase().as_str() { - // these integer types are not supported by the PostgreSQL core distribution, - // but they ARE available via `pguint` (https://github.com/petere/pguint), an - // extension maintained by one of the PostgreSQL core developers. - "uint1" => DataType::UInt8, - "uint2" => DataType::UInt16, - "uint4" | "uint" => DataType::UInt32, - "uint8" => DataType::UInt64, - // `pguint` also provides a 1 byte (8bit) integer type alias - "int1" => DataType::Int8, - _ => { - polars_bail!(SQLInterface: "datatype {:?} is not currently supported", value) - }, - }, - _ => { - polars_bail!(SQLInterface: "datatype {:?} is not currently supported", idents) - }, - }, - _ => { - polars_bail!(SQLInterface: "datatype {:?} is not currently supported", dtype) - }, - }) -} - #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] #[derive(Clone, Copy, PartialEq, Debug, Eq, Hash)] +/// Categorises the type of (allowed) subquery constraint pub enum SubqueryRestriction { - // SingleValue, + /// Subquery must return a single column SingleColumn, // SingleRow, + // SingleValue, // Any } @@ -889,7 +720,7 @@ impl SQLExprVisitor<'_> { if dtype == &SQLDataType::JSON { return Ok(expr.str().json_decode(None, None)); } - let polars_type = map_sql_polars_datatype(dtype)?; + let polars_type = map_sql_dtype_to_polars(dtype)?; Ok(match cast_kind { CastKind::Cast | CastKind::DoubleColon => expr.strict_cast(polars_type), CastKind::TryCast | CastKind::SafeCast => expr.cast(polars_type), @@ -1319,24 +1150,6 @@ pub(crate) fn adjust_one_indexed_param(idx: Expr, null_if_zero: bool) -> Expr { } } -fn bitstring_to_bytes_literal(b: &String) -> PolarsResult { - let n_bits = b.len(); - if !b.chars().all(|c| c == '0' || c == '1') || n_bits > 64 { - polars_bail!( - SQLSyntax: - "bit string literal should contain only 0s and 1s and have length <= 64; found '{}' with length {}", b, n_bits - ) - } - let s = b.as_str(); - Ok(lit(match n_bits { - 0 => b"".to_vec(), - 1..=8 => u8::from_str_radix(s, 2).unwrap().to_be_bytes().to_vec(), - 9..=16 => u16::from_str_radix(s, 2).unwrap().to_be_bytes().to_vec(), - 17..=32 => u32::from_str_radix(s, 2).unwrap().to_be_bytes().to_vec(), - _ => u64::from_str_radix(s, 2).unwrap().to_be_bytes().to_vec(), - })) -} - pub(crate) fn resolve_compound_identifier( ctx: &mut SQLContext, idents: &[Ident], diff --git a/crates/polars-sql/src/types.rs b/crates/polars-sql/src/types.rs new file mode 100644 index 000000000000..800ead8c233e --- /dev/null +++ b/crates/polars-sql/src/types.rs @@ -0,0 +1,208 @@ +//! This module supports mapping SQL datatypes to Polars datatypes. +//! +//! It also provides utility functions for working with SQL datatypes. +use polars_core::datatypes::{DataType, TimeUnit}; +use polars_core::export::regex::{Regex, RegexBuilder}; +use polars_error::{polars_bail, PolarsResult}; +use polars_plan::dsl::{lit, Expr}; +use sqlparser::ast::{ + ArrayElemTypeDef, DataType as SQLDataType, ExactNumberInfo, Ident, ObjectName, TimezoneInfo, +}; + +static DATETIME_LITERAL_RE: std::sync::OnceLock = std::sync::OnceLock::new(); +static DATE_LITERAL_RE: std::sync::OnceLock = std::sync::OnceLock::new(); +static TIME_LITERAL_RE: std::sync::OnceLock = std::sync::OnceLock::new(); + +pub fn bitstring_to_bytes_literal(b: &String) -> PolarsResult { + let n_bits = b.len(); + if !b.chars().all(|c| c == '0' || c == '1') || n_bits > 64 { + polars_bail!( + SQLSyntax: + "bit string literal should contain only 0s and 1s and have length <= 64; found '{}' with length {}", b, n_bits + ) + } + let s = b.as_str(); + Ok(lit(match n_bits { + 0 => b"".to_vec(), + 1..=8 => u8::from_str_radix(s, 2).unwrap().to_be_bytes().to_vec(), + 9..=16 => u16::from_str_radix(s, 2).unwrap().to_be_bytes().to_vec(), + 17..=32 => u32::from_str_radix(s, 2).unwrap().to_be_bytes().to_vec(), + _ => u64::from_str_radix(s, 2).unwrap().to_be_bytes().to_vec(), + })) +} + +pub fn is_iso_datetime(value: &str) -> bool { + let dtm_regex = DATETIME_LITERAL_RE.get_or_init(|| { + RegexBuilder::new( + r"^\d{4}-[01]\d-[0-3]\d[ T](?:[01][0-9]|2[0-3]):[0-5][0-9]:[0-5][0-9](\.\d{1,9})?$", + ) + .build() + .unwrap() + }); + dtm_regex.is_match(value) +} + +pub fn is_iso_date(value: &str) -> bool { + let dt_regex = DATE_LITERAL_RE.get_or_init(|| { + RegexBuilder::new(r"^\d{4}-[01]\d-[0-3]\d$") + .build() + .unwrap() + }); + dt_regex.is_match(value) +} + +pub fn is_iso_time(value: &str) -> bool { + let tm_regex = TIME_LITERAL_RE.get_or_init(|| { + RegexBuilder::new(r"^(?:[01][0-9]|2[0-3]):[0-5][0-9]:[0-5][0-9](\.\d{1,9})?$") + .build() + .unwrap() + }); + tm_regex.is_match(value) +} + +fn timeunit_from_precision(prec: &Option) -> PolarsResult { + Ok(match prec { + None => TimeUnit::Microseconds, + Some(n) if (1u64..=3u64).contains(n) => TimeUnit::Milliseconds, + Some(n) if (4u64..=6u64).contains(n) => TimeUnit::Microseconds, + Some(n) if (7u64..=9u64).contains(n) => TimeUnit::Nanoseconds, + Some(n) => { + polars_bail!(SQLSyntax: "invalid temporal type precision (expected 1-9, found {})", n) + }, + }) +} + +pub(crate) fn map_sql_dtype_to_polars(dtype: &SQLDataType) -> PolarsResult { + Ok(match dtype { + // --------------------------------- + // array/list + // --------------------------------- + SQLDataType::Array(ArrayElemTypeDef::AngleBracket(inner_type)) + | SQLDataType::Array(ArrayElemTypeDef::SquareBracket(inner_type, _)) => { + DataType::List(Box::new(map_sql_dtype_to_polars(inner_type)?)) + }, + + // --------------------------------- + // binary + // --------------------------------- + SQLDataType::Bytea + | SQLDataType::Bytes(_) + | SQLDataType::Binary(_) + | SQLDataType::Blob(_) + | SQLDataType::Varbinary(_) => DataType::Binary, + + // --------------------------------- + // boolean + // --------------------------------- + SQLDataType::Boolean | SQLDataType::Bool => DataType::Boolean, + + // --------------------------------- + // signed integer + // --------------------------------- + SQLDataType::Int(_) | SQLDataType::Integer(_) => DataType::Int32, + SQLDataType::Int2(_) | SQLDataType::SmallInt(_) => DataType::Int16, + SQLDataType::Int4(_) | SQLDataType::MediumInt(_) => DataType::Int32, + SQLDataType::Int8(_) | SQLDataType::BigInt(_) => DataType::Int64, + SQLDataType::TinyInt(_) => DataType::Int8, + + // --------------------------------- + // unsigned integer: the following do not map to PostgreSQL types/syntax, but + // are enabled for wider compatibility (eg: "CAST(col AS BIGINT UNSIGNED)"). + // --------------------------------- + SQLDataType::UnsignedTinyInt(_) => DataType::UInt8, // see also: "custom" types below + SQLDataType::UnsignedInt(_) | SQLDataType::UnsignedInteger(_) => DataType::UInt32, + SQLDataType::UnsignedInt2(_) | SQLDataType::UnsignedSmallInt(_) => DataType::UInt16, + SQLDataType::UnsignedInt4(_) | SQLDataType::UnsignedMediumInt(_) => DataType::UInt32, + SQLDataType::UnsignedInt8(_) | SQLDataType::UnsignedBigInt(_) | SQLDataType::UInt8 => { + DataType::UInt64 + }, + + // --------------------------------- + // float + // --------------------------------- + SQLDataType::Double | SQLDataType::DoublePrecision | SQLDataType::Float8 => { + DataType::Float64 + }, + SQLDataType::Float(n_bytes) => match n_bytes { + Some(n) if (1u64..=24u64).contains(n) => DataType::Float32, + Some(n) if (25u64..=53u64).contains(n) => DataType::Float64, + Some(n) => { + polars_bail!(SQLSyntax: "unsupported `float` size (expected a value between 1 and 53, found {})", n) + }, + None => DataType::Float64, + }, + SQLDataType::Float4 | SQLDataType::Real => DataType::Float32, + + // --------------------------------- + // decimal + // --------------------------------- + #[cfg(feature = "dtype-decimal")] + SQLDataType::Dec(info) | SQLDataType::Decimal(info) | SQLDataType::Numeric(info) => { + match *info { + ExactNumberInfo::PrecisionAndScale(p, s) => { + DataType::Decimal(Some(p as usize), Some(s as usize)) + }, + ExactNumberInfo::Precision(p) => DataType::Decimal(Some(p as usize), Some(0)), + ExactNumberInfo::None => DataType::Decimal(Some(38), Some(9)), + } + }, + + // --------------------------------- + // temporal + // --------------------------------- + SQLDataType::Date => DataType::Date, + SQLDataType::Interval => DataType::Duration(TimeUnit::Microseconds), + SQLDataType::Time(_, tz) => match tz { + TimezoneInfo::None => DataType::Time, + _ => { + polars_bail!(SQLInterface: "`time` with timezone is not supported; found tz={}", tz) + }, + }, + SQLDataType::Datetime(prec) => DataType::Datetime(timeunit_from_precision(prec)?, None), + SQLDataType::Timestamp(prec, tz) => match tz { + TimezoneInfo::None => DataType::Datetime(timeunit_from_precision(prec)?, None), + _ => { + polars_bail!(SQLInterface: "`timestamp` with timezone is not (yet) supported") + }, + }, + + // --------------------------------- + // string + // --------------------------------- + SQLDataType::Char(_) + | SQLDataType::CharVarying(_) + | SQLDataType::Character(_) + | SQLDataType::CharacterVarying(_) + | SQLDataType::Clob(_) + | SQLDataType::String(_) + | SQLDataType::Text + | SQLDataType::Uuid + | SQLDataType::Varchar(_) => DataType::String, + + // --------------------------------- + // custom + // --------------------------------- + SQLDataType::Custom(ObjectName(idents), _) => match idents.as_slice() { + [Ident { value, .. }] => match value.to_lowercase().as_str() { + // these integer types are not supported by the PostgreSQL core distribution, + // but they ARE available via `pguint` (https://github.com/petere/pguint), an + // extension maintained by one of the PostgreSQL core developers. + "uint1" => DataType::UInt8, + "uint2" => DataType::UInt16, + "uint4" | "uint" => DataType::UInt32, + "uint8" => DataType::UInt64, + // `pguint` also provides a 1 byte (8bit) integer type alias + "int1" => DataType::Int8, + _ => { + polars_bail!(SQLInterface: "datatype {:?} is not currently supported", value) + }, + }, + _ => { + polars_bail!(SQLInterface: "datatype {:?} is not currently supported", idents) + }, + }, + _ => { + polars_bail!(SQLInterface: "datatype {:?} is not currently supported", dtype) + }, + }) +} diff --git a/py-polars/tests/unit/sql/test_miscellaneous.py b/py-polars/tests/unit/sql/test_miscellaneous.py index 77aa60e08af8..95ba8461bebe 100644 --- a/py-polars/tests/unit/sql/test_miscellaneous.py +++ b/py-polars/tests/unit/sql/test_miscellaneous.py @@ -2,6 +2,7 @@ from datetime import date from pathlib import Path +from typing import TYPE_CHECKING, Any import pytest @@ -9,6 +10,9 @@ from polars.exceptions import SQLInterfaceError, SQLSyntaxError from polars.testing import assert_frame_equal +if TYPE_CHECKING: + from polars.datatypes import DataType + @pytest.fixture def foods_ipc_path() -> Path: @@ -53,6 +57,28 @@ def test_any_all() -> None: } +@pytest.mark.parametrize( + ("data", "schema"), + [ + ({"x": [1, 2, 3, 4]}, None), + ({"x": [9, 8, 7, 6]}, {"x": pl.Int8}), + ({"x": ["aa", "bb"]}, {"x": pl.Struct}), + ({"x": [None, None], "y": [None, None]}, {"x": pl.Date, "y": pl.Float64}), + ], +) +def test_boolean_where_clauses( + data: dict[str, Any], schema: dict[str, DataType] | None +) -> None: + df = pl.DataFrame(data=data, schema=schema) + empty_df = df.clear() + + for true in ("TRUE", "1=1", "2 == 2", "'xx' = 'xx'", "TRUE AND 1=1"): + assert_frame_equal(df, df.sql(f"SELECT * FROM self WHERE {true}")) + + for false in ("false", "1!=1", "2 != 2", "'xx' != 'xx'", "FALSE OR 1!=1"): + assert_frame_equal(empty_df, df.sql(f"SELECT * FROM self WHERE {false}")) + + def test_count() -> None: df = pl.DataFrame( {