From c01850d0c4ba3a99ffb6e70b216d95186349408b Mon Sep 17 00:00:00 2001 From: Jon Mease Date: Fri, 11 Aug 2023 10:41:56 -0400 Subject: [PATCH] Add SnowparkDataset and date_to_utc_timestamp support across dialects (#374) * Update SqlDataFrame to handle source tables with quotes and dots like "DEMO_DATA"."DEMOS"."MOVIES" * Add date_to_utc_timestamp support across dialects * Add SnowparkDataset --- Cargo.lock | 1 + .../vegafusion/vegafusion/dataset/snowpark.py | 101 +++++++ vegafusion-sql/Cargo.toml | 1 + vegafusion-sql/src/compile/scalar.rs | 24 +- vegafusion-sql/src/dataframe/mod.rs | 17 +- vegafusion-sql/src/dialect/mod.rs | 48 +++- .../transforms/date_to_utc_timestamp.rs | 272 +++++++++++++++++- vegafusion-sql/tests/expected/select.toml | 42 +++ vegafusion-sql/tests/test_select.rs | 73 +++++ vegafusion-sql/tests/utils/mod.rs | 2 +- 10 files changed, 572 insertions(+), 9 deletions(-) create mode 100644 python/vegafusion/vegafusion/dataset/snowpark.py diff --git a/Cargo.lock b/Cargo.lock index 1dd6c3ec1..440f4db37 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -4477,6 +4477,7 @@ dependencies = [ "arrow", "async-std", "async-trait", + "chrono", "datafusion", "datafusion-common", "datafusion-expr", diff --git a/python/vegafusion/vegafusion/dataset/snowpark.py b/python/vegafusion/vegafusion/dataset/snowpark.py new file mode 100644 index 000000000..e20df8c8c --- /dev/null +++ b/python/vegafusion/vegafusion/dataset/snowpark.py @@ -0,0 +1,101 @@ +import logging +import pyarrow as pa +from .sql import SqlDataset +from snowflake.snowpark import Table as SnowparkTable +from snowflake.snowpark.types import DataType as SnowparkDataType +from typing import Dict + +from ..transformer import to_arrow_table + +SNOWPARK_TO_PYARROW_TYPES: Dict[SnowparkDataType, pa.DataType] = {} + + +def get_snowpark_to_pyarrow_types(): + if not SNOWPARK_TO_PYARROW_TYPES: + import snowflake.snowpark.types as sp_types + + SNOWPARK_TO_PYARROW_TYPES.update( + { + sp_types.LongType: pa.int64(), + sp_types.BinaryType: pa.binary(), + sp_types.BooleanType: pa.bool_(), + sp_types.ByteType: pa.int8(), + sp_types.StringType: pa.string(), + sp_types.DateType: pa.date32(), + sp_types.DoubleType: pa.float64(), + sp_types.FloatType: pa.float32(), + sp_types.IntegerType: pa.int32(), + sp_types.ShortType: pa.int16(), + sp_types.TimestampType: pa.timestamp("ms"), + } + ) + return SNOWPARK_TO_PYARROW_TYPES + + +def snowflake_field_to_pyarrow_type(provided_type: SnowparkDataType) -> pa.DataType: + """ + Converts Snowflake types to PyArrow equivalent types, raising a ValueError if they aren't comparable. + See https://docs.snowflake.com/en/sql-reference/intro-summary-data-types + """ + from snowflake.snowpark.types import DecimalType as SnowparkDecimalType + + type_map = get_snowpark_to_pyarrow_types() + if provided_type.__class__ in type_map: + return type_map[provided_type.__class__] + + if isinstance(provided_type, SnowparkDecimalType): + return pa.decimal128(provided_type.precision, provided_type.scale) + else: + raise ValueError(f"Unsupported Snowpark type: {provided_type}") + + +def snowpark_table_to_pyarrow_schema(table: SnowparkTable) -> pa.Schema: + schema_fields = {} + for name, field in zip(table.schema.names, table.schema.fields): + normalised_name = name.strip('"') + schema_fields[normalised_name] = snowflake_field_to_pyarrow_type(field.datatype) + return pa.schema(schema_fields) + + +class SnowparkDataset(SqlDataset): + def dialect(self) -> str: + return "snowflake" + + def __init__( + self, table: SnowparkTable, fallback: bool = True, verbose: bool = False + ): + if not isinstance(table, SnowparkTable): + raise ValueError( + f"SnowparkDataset accepts a snowpark Table. Received: {type(table)}" + ) + self._table = table + self._session = table._session + + self._fallback = fallback + self._verbose = verbose + self._table_name = table.table_name + self._table_schema = snowpark_table_to_pyarrow_schema(self._table) + + self.logger = logging.getLogger("SnowparkDataset") + + def table_name(self) -> str: + return self._table_name + + def table_schema(self) -> pa.Schema: + return self._table_schema + + def fetch_query(self, query: str, schema: pa.Schema) -> pa.Table: + self.logger.info(f"Snowflake Query:\n{query}\n") + if self._verbose: + print(f"Snowflake Query:\n{query}\n") + + sp_df = self._session.sql(query) + batches = [] + for pd_batch in sp_df.to_pandas_batches(): + pa_tbl = to_arrow_table(pd_batch).cast(schema) + batches.extend(pa_tbl.to_batches()) + + return pa.Table.from_batches(batches, schema) + + def fallback(self) -> bool: + return self._fallback diff --git a/vegafusion-sql/Cargo.toml b/vegafusion-sql/Cargo.toml index 1d0d89a9c..92a2eb472 100644 --- a/vegafusion-sql/Cargo.toml +++ b/vegafusion-sql/Cargo.toml @@ -12,6 +12,7 @@ datafusion-conn = [ "datafusion", "tempfile", "reqwest", "reqwest-retry", "reqwe async-trait = "0.1.53" deterministic-hash = "1.0.1" log = "0.4.17" +chrono = "0.4.23" [dev-dependencies] rstest = "0.17.0" diff --git a/vegafusion-sql/src/compile/scalar.rs b/vegafusion-sql/src/compile/scalar.rs index 640d3652e..6051a3778 100644 --- a/vegafusion-sql/src/compile/scalar.rs +++ b/vegafusion-sql/src/compile/scalar.rs @@ -1,3 +1,4 @@ +use crate::compile::data_type::ToSqlDataType; use crate::dialect::Dialect; use arrow::datatypes::DataType; use datafusion_common::scalar::ScalarValue; @@ -5,6 +6,7 @@ use sqlparser::ast::{ Expr as SqlExpr, Function as SqlFunction, FunctionArg as SqlFunctionArg, FunctionArgExpr, Ident, ObjectName as SqlObjectName, Value as SqlValue, }; +use std::ops::Add; use vegafusion_common::error::{Result, VegaFusionError}; pub trait ToSqlScalar { @@ -159,9 +161,7 @@ impl ToSqlScalar for ScalarValue { order_by: Default::default(), })) } - ScalarValue::Date32(_) => Err(VegaFusionError::internal( - "Date32 cannot be converted to SQL", - )), + ScalarValue::Date32(v) => date32_to_date(v, dialect), ScalarValue::Date64(_) => Err(VegaFusionError::internal( "Date64 cannot be converted to SQL", )), @@ -246,3 +246,21 @@ fn ms_to_timestamp(v: i64) -> SqlExpr { order_by: Default::default(), }) } + +fn date32_to_date(days: &Option, dialect: &Dialect) -> Result { + let epoch = chrono::NaiveDate::from_ymd_opt(1970, 1, 1).unwrap(); + match days { + None => Ok(SqlExpr::Cast { + expr: Box::new(ScalarValue::Utf8(None).to_sql(dialect)?), + data_type: DataType::Date32.to_sql(dialect)?, + }), + Some(days) => { + let date = epoch.add(chrono::Duration::days(*days as i64)); + let date_str = date.format("%F").to_string(); + Ok(SqlExpr::Cast { + expr: Box::new(ScalarValue::from(date_str.as_str()).to_sql(dialect)?), + data_type: DataType::Date32.to_sql(dialect)?, + }) + } + } +} diff --git a/vegafusion-sql/src/dataframe/mod.rs b/vegafusion-sql/src/dataframe/mod.rs index 5fbef8b9b..5ee92268d 100644 --- a/vegafusion-sql/src/dataframe/mod.rs +++ b/vegafusion-sql/src/dataframe/mod.rs @@ -206,7 +206,20 @@ impl SqlDataFrame { .collect(); let select_items = columns.join(", "); - let table_ident = Ident::with_quote(conn.dialect().quote_style, table).to_string(); + // Replace special characters with underscores + let mut clean_table = table.to_string(); + for c in &['"', '\'', '.', '-'] { + clean_table = clean_table.replace(*c, "_"); + } + + let quote_style = conn.dialect().quote_style; + let table_ident = if !table.starts_with(quote_style) { + // Quote table + Ident::with_quote(conn.dialect().quote_style, table).to_string() + } else { + // If table name starts with the quote character, assume already quoted + table.to_string() + }; let query = parse_sql_query( &format!("select {select_items} from {table_ident}"), @@ -214,7 +227,7 @@ impl SqlDataFrame { )?; Ok(Self { - prefix: format!("{table}_"), + prefix: format!("{clean_table}_"), ctes: vec![query], schema: Arc::new(schema), conn, diff --git a/vegafusion-sql/src/dialect/mod.rs b/vegafusion-sql/src/dialect/mod.rs index cd3d0bd6c..4de066f1b 100644 --- a/vegafusion-sql/src/dialect/mod.rs +++ b/vegafusion-sql/src/dialect/mod.rs @@ -10,7 +10,11 @@ use crate::dialect::transforms::date_part_tz::{ DatePartTzWithDatePartAndAtTimezoneTransformer, DatePartTzWithExtractAndAtTimezoneTransformer, DatePartTzWithFromUtcAndDatePartTransformer, }; -use crate::dialect::transforms::date_to_utc_timestamp::DateToUtcTimestampWithCastAndAtTimeZoneTransformer; +use crate::dialect::transforms::date_to_utc_timestamp::{ + DateToUtcTimestampClickhouseTransformer, DateToUtcTimestampMySqlTransformer, + DateToUtcTimestampSnowflakeTransform, DateToUtcTimestampWithCastAndAtTimeZoneTransformer, + DateToUtcTimestampWithCastFunctionAtTransformer, DateToUtcTimestampWithFunctionTransformer, +}; use crate::dialect::transforms::date_trunc_tz::{ DateTruncTzClickhouseTransformer, DateTruncTzSnowflakeTransformer, DateTruncTzWithDateTruncAndAtTimezoneTransformer, @@ -321,6 +325,10 @@ impl Dialect { "date_trunc_tz", DateTruncTzWithDateTruncAndAtTimezoneTransformer::new_dyn(false), ), + ( + "date_to_utc_timestamp", + DateToUtcTimestampWithCastAndAtTimeZoneTransformer::new_dyn(), + ), ] .into_iter() .map(|(name, v)| (name.to_string(), v)) @@ -353,6 +361,7 @@ impl Dialect { DataType::Timestamp(TimeUnit::Millisecond, None), SqlDataType::Timestamp(None, TimezoneInfo::None), ), + (DataType::Date32, SqlDataType::Date), ] .into_iter() .collect(), @@ -450,6 +459,10 @@ impl Dialect { "utc_timestamp_to_str", UtcTimestampToStrBigQueryTransformer::new_dyn(), ), + ( + "date_to_utc_timestamp", + DateToUtcTimestampWithFunctionTransformer::new_dyn("timestamp"), + ), ] .into_iter() .map(|(name, v)| (name.to_string(), v)) @@ -480,6 +493,7 @@ impl Dialect { DataType::Timestamp(TimeUnit::Millisecond, None), SqlDataType::Timestamp(None, TimezoneInfo::None), ), + (DataType::Date32, SqlDataType::Date), ] .into_iter() .collect(), @@ -550,6 +564,10 @@ impl Dialect { ), ("date_part_tz", DatePartTzClickhouseTransformer::new_dyn()), ("date_trunc_tz", DateTruncTzClickhouseTransformer::new_dyn()), + ( + "date_to_utc_timestamp", + DateToUtcTimestampClickhouseTransformer::new_dyn(), + ), ] .into_iter() .map(|(name, v)| (name.to_string(), v)) @@ -580,6 +598,7 @@ impl Dialect { DataType::Timestamp(TimeUnit::Millisecond, None), SqlDataType::Timestamp(None, TimezoneInfo::None), ), + (DataType::Date32, SqlDataType::Date), ] .into_iter() .collect(), @@ -691,6 +710,14 @@ impl Dialect { "utc_timestamp_to_str", UtcTimestampToStrDatabricksTransformer::new_dyn(), ), + ( + "date_to_utc_timestamp", + DateToUtcTimestampWithCastFunctionAtTransformer::new_dyn( + SqlDataType::Timestamp(None, TimezoneInfo::None), + "to_utc_timestamp", + false, + ), + ), ] .into_iter() .map(|(name, v)| (name.to_string(), v)) @@ -723,6 +750,7 @@ impl Dialect { DataType::Timestamp(TimeUnit::Millisecond, None), SqlDataType::Timestamp(None, TimezoneInfo::None), ), + (DataType::Date32, SqlDataType::Date), ] .into_iter() .collect(), @@ -887,6 +915,7 @@ impl Dialect { DataType::Timestamp(TimeUnit::Millisecond, None), SqlDataType::Timestamp(None, TimezoneInfo::None), ), + (DataType::Date32, SqlDataType::Date), ] .into_iter() .collect(), @@ -1042,6 +1071,7 @@ impl Dialect { DataType::Timestamp(TimeUnit::Millisecond, None), SqlDataType::Timestamp(None, TimezoneInfo::None), ), + (DataType::Date32, SqlDataType::Date), ] .into_iter() .collect(), @@ -1113,6 +1143,10 @@ impl Dialect { StrToUtcTimestampMySqlTransformer::new_dyn(), ), ("date_part_tz", DatePartTzMySqlTransformer::new_dyn()), + ( + "date_to_utc_timestamp", + DateToUtcTimestampMySqlTransformer::new_dyn(), + ), ] .into_iter() .map(|(name, v)| (name.to_string(), v)) @@ -1138,6 +1172,7 @@ impl Dialect { (DataType::Float32, SqlDataType::Float(None)), (DataType::Float64, SqlDataType::Double), (DataType::Utf8, SqlDataType::Char(None)), + (DataType::Date32, SqlDataType::Date), ] .into_iter() .collect(), @@ -1294,6 +1329,7 @@ impl Dialect { DataType::Timestamp(TimeUnit::Millisecond, None), SqlDataType::Timestamp(None, TimezoneInfo::None), ), + (DataType::Date32, SqlDataType::Date), ] .into_iter() .collect(), @@ -1396,6 +1432,10 @@ impl Dialect { "date_trunc_tz", DateTruncTzWithDateTruncAndAtTimezoneTransformer::new_dyn(true), ), + ( + "date_to_utc_timestamp", + DateToUtcTimestampWithCastAndAtTimeZoneTransformer::new_dyn(), + ), ] .into_iter() .map(|(name, v)| (name.to_string(), v)) @@ -1426,6 +1466,7 @@ impl Dialect { DataType::Timestamp(TimeUnit::Millisecond, None), SqlDataType::Timestamp(None, TimezoneInfo::None), ), + (DataType::Date32, SqlDataType::Date), ] .into_iter() .collect(), @@ -1539,6 +1580,10 @@ impl Dialect { "utc_timestamp_to_str", UtcTimestampToStrSnowflakeTransformer::new_dyn(), ), + ( + "date_to_utc_timestamp", + DateToUtcTimestampSnowflakeTransform::new_dyn(), + ), ] .into_iter() .map(|(name, v)| (name.to_string(), v)) @@ -1573,6 +1618,7 @@ impl Dialect { DataType::Timestamp(TimeUnit::Millisecond, None), SqlDataType::Timestamp(None, TimezoneInfo::None), ), + (DataType::Date32, SqlDataType::Date), ] .into_iter() .collect(), diff --git a/vegafusion-sql/src/dialect/transforms/date_to_utc_timestamp.rs b/vegafusion-sql/src/dialect/transforms/date_to_utc_timestamp.rs index f303b45f1..098d4f330 100644 --- a/vegafusion-sql/src/dialect/transforms/date_to_utc_timestamp.rs +++ b/vegafusion-sql/src/dialect/transforms/date_to_utc_timestamp.rs @@ -3,7 +3,9 @@ use crate::dialect::{Dialect, FunctionTransformer}; use datafusion_common::DFSchema; use datafusion_expr::Expr; use sqlparser::ast::{ - DataType as SqlDataType, Expr as SqlExpr, TimezoneInfo as SqlTimezoneInfo, Value as SqlValue, + DataType as SqlDataType, Expr as SqlExpr, Function as SqlFunction, + FunctionArg as SqlFunctionArg, FunctionArgExpr as SqlFunctionArgExpr, Ident as SqlIdent, + ObjectName as SqlObjectName, TimezoneInfo as SqlTimezoneInfo, Value as SqlValue, }; use std::sync::Arc; use vegafusion_common::error::{Result, VegaFusionError}; @@ -30,7 +32,7 @@ fn process_date_to_utc_timestamp_args( Ok((sql_arg0, time_zone)) } -/// Convert to_utc_timestamp(d, tz) -> +/// Convert date_to_utc_timestamp(d, tz) -> /// CAST(d as TIMESTAMP) AT TIME ZONE tz AT TIME ZONE 'UTC' /// or if tz = 'UTC' /// CAST(d as TIMESTAMP) @@ -68,3 +70,269 @@ impl FunctionTransformer for DateToUtcTimestampWithCastAndAtTimeZoneTransformer Ok(utc_timestamp) } } + +/// Convert date_to_utc_timestamp(d, tz) -> +/// CONVERT_TIMEZONE(tz, 'UTC', CAST(d as TIMESTAMP_NTZ)) +/// or if tz = 'UTC' +/// CAST(d as TIMESTAMP_NTZ) +#[derive(Clone, Debug)] +pub struct DateToUtcTimestampSnowflakeTransform; + +impl DateToUtcTimestampSnowflakeTransform { + pub fn new_dyn() -> Arc { + Arc::new(Self) + } +} + +impl FunctionTransformer for DateToUtcTimestampSnowflakeTransform { + fn transform(&self, args: &[Expr], dialect: &Dialect, schema: &DFSchema) -> Result { + let (date_arg, time_zone) = process_date_to_utc_timestamp_args(args, dialect, schema)?; + + let cast_timestamp_ntz_expr = SqlExpr::Cast { + expr: Box::new(date_arg), + data_type: SqlDataType::Custom( + SqlObjectName(vec![SqlIdent { + value: "timestamp_ntz".to_string(), + quote_style: None, + }]), + Vec::new(), + ), + }; + + if time_zone == "UTC" { + // No conversion needed + Ok(cast_timestamp_ntz_expr) + } else { + let convert_tz_expr = SqlExpr::Function(SqlFunction { + name: SqlObjectName(vec![SqlIdent { + value: "convert_timezone".to_string(), + quote_style: None, + }]), + args: vec![ + SqlFunctionArg::Unnamed(SqlFunctionArgExpr::Expr(SqlExpr::Value( + SqlValue::SingleQuotedString(time_zone), + ))), + SqlFunctionArg::Unnamed(SqlFunctionArgExpr::Expr(SqlExpr::Value( + SqlValue::SingleQuotedString("UTC".to_string()), + ))), + SqlFunctionArg::Unnamed(SqlFunctionArgExpr::Expr(cast_timestamp_ntz_expr)), + ], + over: None, + distinct: false, + special: false, + order_by: Default::default(), + }); + + Ok(convert_tz_expr) + } + } +} + +/// Convert date_to_utc_timestamp(ts, tz) -> +/// function_name(ts, tz) +#[derive(Clone, Debug)] +pub struct DateToUtcTimestampWithFunctionTransformer { + function_name: String, +} + +impl DateToUtcTimestampWithFunctionTransformer { + pub fn new_dyn(function_name: &str) -> Arc { + Arc::new(Self { + function_name: function_name.to_string(), + }) + } +} + +impl FunctionTransformer for DateToUtcTimestampWithFunctionTransformer { + fn transform(&self, args: &[Expr], dialect: &Dialect, schema: &DFSchema) -> Result { + let (date_arg, time_zone) = process_date_to_utc_timestamp_args(args, dialect, schema)?; + + Ok(SqlExpr::Function(SqlFunction { + name: SqlObjectName(vec![SqlIdent { + value: self.function_name.clone(), + quote_style: None, + }]), + args: vec![ + SqlFunctionArg::Unnamed(SqlFunctionArgExpr::Expr(date_arg)), + SqlFunctionArg::Unnamed(SqlFunctionArgExpr::Expr(SqlExpr::Value( + SqlValue::SingleQuotedString(time_zone), + ))), + ], + over: None, + distinct: false, + special: false, + order_by: Default::default(), + })) + } +} + +/// Convert date_to_utc_timestamp(ts, tz) -> +/// function_name(CAST(ts as TIMESTAMP), tz) +/// or +/// function_name(CAST(ts as TIMESTAMP), tz) AT TIME ZONE 'UTC' +#[derive(Clone, Debug)] +pub struct DateToUtcTimestampWithCastFunctionAtTransformer { + timestamp_type: SqlDataType, + function_name: String, + at_timezone_utc: bool, +} + +impl DateToUtcTimestampWithCastFunctionAtTransformer { + pub fn new_dyn( + timestamp_type: SqlDataType, + function_name: &str, + at_timezone_utc: bool, + ) -> Arc { + Arc::new(Self { + timestamp_type, + function_name: function_name.to_string(), + at_timezone_utc, + }) + } +} + +impl FunctionTransformer for DateToUtcTimestampWithCastFunctionAtTransformer { + fn transform(&self, args: &[Expr], dialect: &Dialect, schema: &DFSchema) -> Result { + let (date_arg, time_zone) = process_date_to_utc_timestamp_args(args, dialect, schema)?; + + let cast_expr = SqlExpr::Cast { + expr: Box::new(date_arg), + data_type: self.timestamp_type.clone(), + }; + let fn_expr = SqlExpr::Function(SqlFunction { + name: SqlObjectName(vec![SqlIdent { + value: self.function_name.clone(), + quote_style: None, + }]), + args: vec![ + SqlFunctionArg::Unnamed(SqlFunctionArgExpr::Expr(cast_expr)), + SqlFunctionArg::Unnamed(SqlFunctionArgExpr::Expr(SqlExpr::Value( + SqlValue::SingleQuotedString(time_zone), + ))), + ], + over: None, + distinct: false, + special: false, + order_by: Default::default(), + }); + if self.at_timezone_utc { + Ok(SqlExpr::AtTimeZone { + timestamp: Box::new(fn_expr), + time_zone: "UTC".to_string(), + }) + } else { + Ok(fn_expr) + } + } +} + +/// Convert date_to_utc_timestamp(ts, tz) -> +/// toTimeZone(toDateTime(ts, tz), 'UTC') +#[derive(Clone, Debug)] +pub struct DateToUtcTimestampClickhouseTransformer; + +impl DateToUtcTimestampClickhouseTransformer { + pub fn new_dyn() -> Arc { + Arc::new(Self) + } +} + +impl FunctionTransformer for DateToUtcTimestampClickhouseTransformer { + fn transform(&self, args: &[Expr], dialect: &Dialect, schema: &DFSchema) -> Result { + let (date_arg, time_zone) = process_date_to_utc_timestamp_args(args, dialect, schema)?; + + let to_date_time_expr = SqlExpr::Function(SqlFunction { + name: SqlObjectName(vec![SqlIdent { + value: "toDateTime".to_string(), + quote_style: None, + }]), + args: vec![ + SqlFunctionArg::Unnamed(SqlFunctionArgExpr::Expr(date_arg)), + SqlFunctionArg::Unnamed(SqlFunctionArgExpr::Expr(SqlExpr::Value( + SqlValue::SingleQuotedString(time_zone), + ))), + ], + over: None, + distinct: false, + special: false, + order_by: Default::default(), + }); + + let to_time_zone_expr = SqlExpr::Function(SqlFunction { + name: SqlObjectName(vec![SqlIdent { + value: "toTimeZone".to_string(), + quote_style: None, + }]), + args: vec![ + SqlFunctionArg::Unnamed(SqlFunctionArgExpr::Expr(to_date_time_expr)), + SqlFunctionArg::Unnamed(SqlFunctionArgExpr::Expr(SqlExpr::Value( + SqlValue::SingleQuotedString("UTC".to_string()), + ))), + ], + over: None, + distinct: false, + special: false, + order_by: Default::default(), + }); + + Ok(to_time_zone_expr) + } +} + +/// Convert date_to_utc_timestamp(ts, tz) -> +/// convert_timezone(timestamp(ts), tz, 'UTC') +/// or if tz = 'UTC' +/// timestamp(ts) +#[derive(Clone, Debug)] +pub struct DateToUtcTimestampMySqlTransformer; + +impl DateToUtcTimestampMySqlTransformer { + pub fn new_dyn() -> Arc { + Arc::new(Self) + } +} + +impl FunctionTransformer for DateToUtcTimestampMySqlTransformer { + fn transform(&self, args: &[Expr], dialect: &Dialect, schema: &DFSchema) -> Result { + let (date_arg, time_zone) = process_date_to_utc_timestamp_args(args, dialect, schema)?; + + let timestamp_expr = SqlExpr::Function(SqlFunction { + name: SqlObjectName(vec![SqlIdent { + value: "timestamp".to_string(), + quote_style: None, + }]), + args: vec![SqlFunctionArg::Unnamed(SqlFunctionArgExpr::Expr(date_arg))], + over: None, + distinct: false, + special: false, + order_by: Default::default(), + }); + + if time_zone == "UTC" { + // No conversion needed + Ok(timestamp_expr) + } else { + let convert_tz_expr = SqlExpr::Function(SqlFunction { + name: SqlObjectName(vec![SqlIdent { + value: "convert_tz".to_string(), + quote_style: None, + }]), + args: vec![ + SqlFunctionArg::Unnamed(SqlFunctionArgExpr::Expr(timestamp_expr)), + SqlFunctionArg::Unnamed(SqlFunctionArgExpr::Expr(SqlExpr::Value( + SqlValue::SingleQuotedString(time_zone), + ))), + SqlFunctionArg::Unnamed(SqlFunctionArgExpr::Expr(SqlExpr::Value( + SqlValue::SingleQuotedString("UTC".to_string()), + ))), + ], + over: None, + distinct: false, + special: false, + order_by: Default::default(), + }); + + Ok(convert_tz_expr) + } + } +} diff --git a/vegafusion-sql/tests/expected/select.toml b/vegafusion-sql/tests/expected/select.toml index 0924d1373..cea3292d2 100644 --- a/vegafusion-sql/tests/expected/select.toml +++ b/vegafusion-sql/tests/expected/select.toml @@ -636,6 +636,48 @@ result = ''' +---+-------------------------+-------------------------+-------------------------+-------------------------+ ''' +[date_to_utc_timestamp] +athena = """ +WITH values0 AS (SELECT * FROM (VALUES (0, CAST('1998-12-20' AS DATE)), (1, CAST('2000-01-24' AS DATE)), (2, CAST('2000-02-13' AS DATE)), (3, CAST('2002-11-09' AS DATE))) AS "_values" ("a", "b")), values1 AS (SELECT "a", "b", CAST("b" AS TIMESTAMP) AT TIME ZONE 'America/New_York' AT TIME ZONE 'UTC' AS "b_utc" FROM values0) SELECT * FROM values1 ORDER BY "a" ASC NULLS FIRST +""" +bigquery = """ +WITH values0 AS (SELECT 0 AS `a`, CAST('1998-12-20' AS DATE) AS `b` UNION ALL SELECT 1 AS `a`, CAST('2000-01-24' AS DATE) AS `b` UNION ALL SELECT 2 AS `a`, CAST('2000-02-13' AS DATE) AS `b` UNION ALL SELECT 3 AS `a`, CAST('2002-11-09' AS DATE) AS `b`), values1 AS (SELECT `a`, `b`, timestamp(`b`, 'America/New_York') AS `b_utc` FROM values0) SELECT * FROM values1 ORDER BY `a` ASC NULLS FIRST +""" +clickhouse = """ +WITH values0 AS (SELECT 0 AS "a", CAST('1998-12-20' AS DATE) AS "b" UNION ALL SELECT 1 AS "a", CAST('2000-01-24' AS DATE) AS "b" UNION ALL SELECT 2 AS "a", CAST('2000-02-13' AS DATE) AS "b" UNION ALL SELECT 3 AS "a", CAST('2002-11-09' AS DATE) AS "b"), values1 AS (SELECT "a", "b", toTimeZone(toDateTime("b", 'America/New_York'), 'UTC') AS "b_utc" FROM values0) SELECT * FROM values1 ORDER BY "a" ASC NULLS FIRST +""" +databricks = """ +WITH values0 AS (SELECT * FROM (VALUES (0, CAST('1998-12-20' AS DATE)), (1, CAST('2000-01-24' AS DATE)), (2, CAST('2000-02-13' AS DATE)), (3, CAST('2002-11-09' AS DATE))) AS `_values` (`a`, `b`)), values1 AS (SELECT `a`, `b`, to_utc_timestamp(CAST(`b` AS TIMESTAMP), 'America/New_York') AS `b_utc` FROM values0) SELECT * FROM values1 ORDER BY `a` ASC NULLS FIRST +""" +datafusion = """ +WITH values0 AS (SELECT * FROM (VALUES (0, CAST('1998-12-20' AS DATE)), (1, CAST('2000-01-24' AS DATE)), (2, CAST('2000-02-13' AS DATE)), (3, CAST('2002-11-09' AS DATE))) AS "_values" ("a", "b")), values1 AS (SELECT "a", "b", date_to_utc_timestamp("b", 'America/New_York') AS "b_utc" FROM values0) SELECT * FROM values1 ORDER BY "a" ASC NULLS FIRST +""" +duckdb = """ +WITH values0 AS (SELECT * FROM (VALUES (0, CAST('1998-12-20' AS DATE)), (1, CAST('2000-01-24' AS DATE)), (2, CAST('2000-02-13' AS DATE)), (3, CAST('2002-11-09' AS DATE))) AS "_values" ("a", "b")), values1 AS (SELECT "a", "b", CAST("b" AS TIMESTAMP) AT TIME ZONE 'America/New_York' AT TIME ZONE 'UTC' AS "b_utc" FROM values0) SELECT * FROM values1 ORDER BY "a" ASC NULLS FIRST +""" +mysql = """ +WITH values0 AS (SELECT * FROM (VALUES ROW(0, CAST('1998-12-20' AS DATE)), ROW(1, CAST('2000-01-24' AS DATE)), ROW(2, CAST('2000-02-13' AS DATE)), ROW(3, CAST('2002-11-09' AS DATE))) AS `_values` (`a`, `b`)), values1 AS (SELECT `a`, `b`, convert_tz(timestamp(`b`), 'America/New_York', 'UTC') AS `b_utc` FROM values0) SELECT * FROM values1 ORDER BY `a` ASC +""" +postgres = """ +WITH values0 AS (SELECT * FROM (VALUES (0, CAST('1998-12-20' AS DATE)), (1, CAST('2000-01-24' AS DATE)), (2, CAST('2000-02-13' AS DATE)), (3, CAST('2002-11-09' AS DATE))) AS "_values" ("a", "b")), values1 AS (SELECT "a", "b", CAST("b" AS TIMESTAMP) AT TIME ZONE 'America/New_York' AT TIME ZONE 'UTC' AS "b_utc" FROM values0) SELECT * FROM values1 ORDER BY "a" ASC NULLS FIRST +""" +redshift = """ +WITH values0 AS (SELECT 0 AS "a", CAST('1998-12-20' AS DATE) AS "b" UNION ALL SELECT 1 AS "a", CAST('2000-01-24' AS DATE) AS "b" UNION ALL SELECT 2 AS "a", CAST('2000-02-13' AS DATE) AS "b" UNION ALL SELECT 3 AS "a", CAST('2002-11-09' AS DATE) AS "b"), values1 AS (SELECT "a", "b", CAST("b" AS TIMESTAMP) AT TIME ZONE 'America/New_York' AT TIME ZONE 'UTC' AS "b_utc" FROM values0) SELECT * FROM values1 ORDER BY "a" ASC NULLS FIRST +""" +snowflake = """ +WITH values0 AS (SELECT "COLUMN1" AS "a", "COLUMN2" AS "b" FROM (VALUES (0, CAST('1998-12-20' AS DATE)), (1, CAST('2000-01-24' AS DATE)), (2, CAST('2000-02-13' AS DATE)), (3, CAST('2002-11-09' AS DATE)))), values1 AS (SELECT "a", "b", convert_timezone('America/New_York', 'UTC', CAST("b" AS timestamp_ntz)) AS "b_utc" FROM values0) SELECT * FROM values1 ORDER BY "a" ASC NULLS FIRST +""" +result = ''' ++---+------------+---------------------+ +| a | b | b_utc | ++---+------------+---------------------+ +| 0 | 1998-12-20 | 1998-12-20T05:00:00 | +| 1 | 2000-01-24 | 2000-01-24T05:00:00 | +| 2 | 2000-02-13 | 2000-02-13T05:00:00 | +| 3 | 2002-11-09 | 2002-11-09T05:00:00 | ++---+------------+---------------------+ +''' + [test_string_ops] athena = """ WITH values0 AS (SELECT * FROM (VALUES (0, '1234', 'efGH'), (1, 'abCD', '5678'), (3, NULL, NULL)) AS "_values" ("a", "b", "c")), values1 AS (SELECT "a", "b", "c", substr("b", 2, 2) AS "b_substr", concat("b", ' ', "c") AS "bc_concat", upper("b") AS "b_upper", lower("b") AS "b_lower" FROM values0) SELECT * FROM values1 ORDER BY "a" ASC NULLS FIRST diff --git a/vegafusion-sql/tests/test_select.rs b/vegafusion-sql/tests/test_select.rs index 433aa0e3a..fc52b6c6f 100644 --- a/vegafusion-sql/tests/test_select.rs +++ b/vegafusion-sql/tests/test_select.rs @@ -1177,6 +1177,79 @@ mod test_utc_timestamp_to_str { fn test_marker() {} // Help IDE detect test module } +#[cfg(test)] +mod test_date_to_utc_timestamp { + use crate::*; + use arrow::array::{ArrayRef, Date32Array, Int32Array}; + use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; + use arrow::record_batch::RecordBatch; + use datafusion_expr::{expr, lit, Expr}; + use std::sync::Arc; + use vegafusion_common::column::flat_col; + use vegafusion_datafusion_udfs::udfs::datetime::date_to_utc_timestamp::DATE_TO_UTC_TIMESTAMP_UDF; + + #[apply(dialect_names)] + async fn test(dialect_name: &str) { + println!("{dialect_name}"); + let (conn, evaluable) = TOKIO_RUNTIME.block_on(make_connection(dialect_name)); + + let schema_ref: SchemaRef = Arc::new(Schema::new(vec![ + Field::new("a", DataType::Int32, true), + Field::new("b", DataType::Date32, true), + ])); + let columns = vec![ + Arc::new(Int32Array::from(vec![0, 1, 2, 3])) as ArrayRef, + Arc::new(Date32Array::from(vec![ + 10580, // 1998-12-20 + 10980, // 2000-01-24 + 11000, // 2000-02-13 + 12000, // 2002-11-09 + ])) as ArrayRef, + ]; + + let batch = RecordBatch::try_new(schema_ref.clone(), columns).unwrap(); + let table = VegaFusionTable::try_new(schema_ref, vec![batch]).unwrap(); + let df = SqlDataFrame::from_values(&table, conn, Default::default()).unwrap(); + + let df_result = df + .select(vec![ + flat_col("a"), + flat_col("b"), + Expr::ScalarUDF(expr::ScalarUDF { + fun: Arc::new(DATE_TO_UTC_TIMESTAMP_UDF.clone()), + args: vec![flat_col("b"), lit("America/New_York")], + }) + .alias("b_utc"), + ]) + .await; + + let df_result = if let Ok(df) = df_result { + df.sort( + vec![Expr::Sort(expr::Sort { + expr: Box::new(flat_col("a")), + asc: true, + nulls_first: true, + })], + None, + ) + .await + } else { + df_result + }; + + check_dataframe_query( + df_result, + "select", + "date_to_utc_timestamp", + dialect_name, + evaluable, + ); + } + + #[test] + fn test_marker() {} // Help IDE detect test module +} + #[cfg(test)] mod test_string_ops { use crate::*; diff --git a/vegafusion-sql/tests/utils/mod.rs b/vegafusion-sql/tests/utils/mod.rs index 0a3cb8b4c..3d165262f 100644 --- a/vegafusion-sql/tests/utils/mod.rs +++ b/vegafusion-sql/tests/utils/mod.rs @@ -88,7 +88,7 @@ pub fn check_dataframe_query( println!("Unsupported"); return; } else { - panic!("Expected sort result to be an error") + panic!("Expected query result to be an error") } } let df = df_result.unwrap();