Skip to content

Commit

Permalink
Add SnowparkDataset and date_to_utc_timestamp support across dialects (
Browse files Browse the repository at this point in the history
…#374)

* Update SqlDataFrame to handle source tables with quotes and dots like "DEMO_DATA"."DEMOS"."MOVIES"

* Add date_to_utc_timestamp support across dialects

* Add SnowparkDataset
  • Loading branch information
jonmmease authored Aug 11, 2023
1 parent 4518344 commit c01850d
Show file tree
Hide file tree
Showing 10 changed files with 572 additions and 9 deletions.
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

101 changes: 101 additions & 0 deletions python/vegafusion/vegafusion/dataset/snowpark.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
import logging
import pyarrow as pa
from .sql import SqlDataset
from snowflake.snowpark import Table as SnowparkTable
from snowflake.snowpark.types import DataType as SnowparkDataType
from typing import Dict

from ..transformer import to_arrow_table

SNOWPARK_TO_PYARROW_TYPES: Dict[SnowparkDataType, pa.DataType] = {}


def get_snowpark_to_pyarrow_types():
if not SNOWPARK_TO_PYARROW_TYPES:
import snowflake.snowpark.types as sp_types

SNOWPARK_TO_PYARROW_TYPES.update(
{
sp_types.LongType: pa.int64(),
sp_types.BinaryType: pa.binary(),
sp_types.BooleanType: pa.bool_(),
sp_types.ByteType: pa.int8(),
sp_types.StringType: pa.string(),
sp_types.DateType: pa.date32(),
sp_types.DoubleType: pa.float64(),
sp_types.FloatType: pa.float32(),
sp_types.IntegerType: pa.int32(),
sp_types.ShortType: pa.int16(),
sp_types.TimestampType: pa.timestamp("ms"),
}
)
return SNOWPARK_TO_PYARROW_TYPES


def snowflake_field_to_pyarrow_type(provided_type: SnowparkDataType) -> pa.DataType:
"""
Converts Snowflake types to PyArrow equivalent types, raising a ValueError if they aren't comparable.
See https://docs.snowflake.com/en/sql-reference/intro-summary-data-types
"""
from snowflake.snowpark.types import DecimalType as SnowparkDecimalType

type_map = get_snowpark_to_pyarrow_types()
if provided_type.__class__ in type_map:
return type_map[provided_type.__class__]

if isinstance(provided_type, SnowparkDecimalType):
return pa.decimal128(provided_type.precision, provided_type.scale)
else:
raise ValueError(f"Unsupported Snowpark type: {provided_type}")


def snowpark_table_to_pyarrow_schema(table: SnowparkTable) -> pa.Schema:
schema_fields = {}
for name, field in zip(table.schema.names, table.schema.fields):
normalised_name = name.strip('"')
schema_fields[normalised_name] = snowflake_field_to_pyarrow_type(field.datatype)
return pa.schema(schema_fields)


class SnowparkDataset(SqlDataset):
def dialect(self) -> str:
return "snowflake"

def __init__(
self, table: SnowparkTable, fallback: bool = True, verbose: bool = False
):
if not isinstance(table, SnowparkTable):
raise ValueError(
f"SnowparkDataset accepts a snowpark Table. Received: {type(table)}"
)
self._table = table
self._session = table._session

self._fallback = fallback
self._verbose = verbose
self._table_name = table.table_name
self._table_schema = snowpark_table_to_pyarrow_schema(self._table)

self.logger = logging.getLogger("SnowparkDataset")

def table_name(self) -> str:
return self._table_name

def table_schema(self) -> pa.Schema:
return self._table_schema

def fetch_query(self, query: str, schema: pa.Schema) -> pa.Table:
self.logger.info(f"Snowflake Query:\n{query}\n")
if self._verbose:
print(f"Snowflake Query:\n{query}\n")

sp_df = self._session.sql(query)
batches = []
for pd_batch in sp_df.to_pandas_batches():
pa_tbl = to_arrow_table(pd_batch).cast(schema)
batches.extend(pa_tbl.to_batches())

return pa.Table.from_batches(batches, schema)

def fallback(self) -> bool:
return self._fallback
1 change: 1 addition & 0 deletions vegafusion-sql/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ datafusion-conn = [ "datafusion", "tempfile", "reqwest", "reqwest-retry", "reqwe
async-trait = "0.1.53"
deterministic-hash = "1.0.1"
log = "0.4.17"
chrono = "0.4.23"

[dev-dependencies]
rstest = "0.17.0"
Expand Down
24 changes: 21 additions & 3 deletions vegafusion-sql/src/compile/scalar.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
use crate::compile::data_type::ToSqlDataType;
use crate::dialect::Dialect;
use arrow::datatypes::DataType;
use datafusion_common::scalar::ScalarValue;
use sqlparser::ast::{
Expr as SqlExpr, Function as SqlFunction, FunctionArg as SqlFunctionArg, FunctionArgExpr,
Ident, ObjectName as SqlObjectName, Value as SqlValue,
};
use std::ops::Add;
use vegafusion_common::error::{Result, VegaFusionError};

pub trait ToSqlScalar {
Expand Down Expand Up @@ -159,9 +161,7 @@ impl ToSqlScalar for ScalarValue {
order_by: Default::default(),
}))
}
ScalarValue::Date32(_) => Err(VegaFusionError::internal(
"Date32 cannot be converted to SQL",
)),
ScalarValue::Date32(v) => date32_to_date(v, dialect),
ScalarValue::Date64(_) => Err(VegaFusionError::internal(
"Date64 cannot be converted to SQL",
)),
Expand Down Expand Up @@ -246,3 +246,21 @@ fn ms_to_timestamp(v: i64) -> SqlExpr {
order_by: Default::default(),
})
}

fn date32_to_date(days: &Option<i32>, dialect: &Dialect) -> Result<SqlExpr> {
let epoch = chrono::NaiveDate::from_ymd_opt(1970, 1, 1).unwrap();
match days {
None => Ok(SqlExpr::Cast {
expr: Box::new(ScalarValue::Utf8(None).to_sql(dialect)?),
data_type: DataType::Date32.to_sql(dialect)?,
}),
Some(days) => {
let date = epoch.add(chrono::Duration::days(*days as i64));
let date_str = date.format("%F").to_string();
Ok(SqlExpr::Cast {
expr: Box::new(ScalarValue::from(date_str.as_str()).to_sql(dialect)?),
data_type: DataType::Date32.to_sql(dialect)?,
})
}
}
}
17 changes: 15 additions & 2 deletions vegafusion-sql/src/dataframe/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -206,15 +206,28 @@ impl SqlDataFrame {
.collect();
let select_items = columns.join(", ");

let table_ident = Ident::with_quote(conn.dialect().quote_style, table).to_string();
// Replace special characters with underscores
let mut clean_table = table.to_string();
for c in &['"', '\'', '.', '-'] {
clean_table = clean_table.replace(*c, "_");
}

let quote_style = conn.dialect().quote_style;
let table_ident = if !table.starts_with(quote_style) {
// Quote table
Ident::with_quote(conn.dialect().quote_style, table).to_string()
} else {
// If table name starts with the quote character, assume already quoted
table.to_string()
};

let query = parse_sql_query(
&format!("select {select_items} from {table_ident}"),
conn.dialect(),
)?;

Ok(Self {
prefix: format!("{table}_"),
prefix: format!("{clean_table}_"),
ctes: vec![query],
schema: Arc::new(schema),
conn,
Expand Down
48 changes: 47 additions & 1 deletion vegafusion-sql/src/dialect/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,11 @@ use crate::dialect::transforms::date_part_tz::{
DatePartTzWithDatePartAndAtTimezoneTransformer, DatePartTzWithExtractAndAtTimezoneTransformer,
DatePartTzWithFromUtcAndDatePartTransformer,
};
use crate::dialect::transforms::date_to_utc_timestamp::DateToUtcTimestampWithCastAndAtTimeZoneTransformer;
use crate::dialect::transforms::date_to_utc_timestamp::{
DateToUtcTimestampClickhouseTransformer, DateToUtcTimestampMySqlTransformer,
DateToUtcTimestampSnowflakeTransform, DateToUtcTimestampWithCastAndAtTimeZoneTransformer,
DateToUtcTimestampWithCastFunctionAtTransformer, DateToUtcTimestampWithFunctionTransformer,
};
use crate::dialect::transforms::date_trunc_tz::{
DateTruncTzClickhouseTransformer, DateTruncTzSnowflakeTransformer,
DateTruncTzWithDateTruncAndAtTimezoneTransformer,
Expand Down Expand Up @@ -321,6 +325,10 @@ impl Dialect {
"date_trunc_tz",
DateTruncTzWithDateTruncAndAtTimezoneTransformer::new_dyn(false),
),
(
"date_to_utc_timestamp",
DateToUtcTimestampWithCastAndAtTimeZoneTransformer::new_dyn(),
),
]
.into_iter()
.map(|(name, v)| (name.to_string(), v))
Expand Down Expand Up @@ -353,6 +361,7 @@ impl Dialect {
DataType::Timestamp(TimeUnit::Millisecond, None),
SqlDataType::Timestamp(None, TimezoneInfo::None),
),
(DataType::Date32, SqlDataType::Date),
]
.into_iter()
.collect(),
Expand Down Expand Up @@ -450,6 +459,10 @@ impl Dialect {
"utc_timestamp_to_str",
UtcTimestampToStrBigQueryTransformer::new_dyn(),
),
(
"date_to_utc_timestamp",
DateToUtcTimestampWithFunctionTransformer::new_dyn("timestamp"),
),
]
.into_iter()
.map(|(name, v)| (name.to_string(), v))
Expand Down Expand Up @@ -480,6 +493,7 @@ impl Dialect {
DataType::Timestamp(TimeUnit::Millisecond, None),
SqlDataType::Timestamp(None, TimezoneInfo::None),
),
(DataType::Date32, SqlDataType::Date),
]
.into_iter()
.collect(),
Expand Down Expand Up @@ -550,6 +564,10 @@ impl Dialect {
),
("date_part_tz", DatePartTzClickhouseTransformer::new_dyn()),
("date_trunc_tz", DateTruncTzClickhouseTransformer::new_dyn()),
(
"date_to_utc_timestamp",
DateToUtcTimestampClickhouseTransformer::new_dyn(),
),
]
.into_iter()
.map(|(name, v)| (name.to_string(), v))
Expand Down Expand Up @@ -580,6 +598,7 @@ impl Dialect {
DataType::Timestamp(TimeUnit::Millisecond, None),
SqlDataType::Timestamp(None, TimezoneInfo::None),
),
(DataType::Date32, SqlDataType::Date),
]
.into_iter()
.collect(),
Expand Down Expand Up @@ -691,6 +710,14 @@ impl Dialect {
"utc_timestamp_to_str",
UtcTimestampToStrDatabricksTransformer::new_dyn(),
),
(
"date_to_utc_timestamp",
DateToUtcTimestampWithCastFunctionAtTransformer::new_dyn(
SqlDataType::Timestamp(None, TimezoneInfo::None),
"to_utc_timestamp",
false,
),
),
]
.into_iter()
.map(|(name, v)| (name.to_string(), v))
Expand Down Expand Up @@ -723,6 +750,7 @@ impl Dialect {
DataType::Timestamp(TimeUnit::Millisecond, None),
SqlDataType::Timestamp(None, TimezoneInfo::None),
),
(DataType::Date32, SqlDataType::Date),
]
.into_iter()
.collect(),
Expand Down Expand Up @@ -887,6 +915,7 @@ impl Dialect {
DataType::Timestamp(TimeUnit::Millisecond, None),
SqlDataType::Timestamp(None, TimezoneInfo::None),
),
(DataType::Date32, SqlDataType::Date),
]
.into_iter()
.collect(),
Expand Down Expand Up @@ -1042,6 +1071,7 @@ impl Dialect {
DataType::Timestamp(TimeUnit::Millisecond, None),
SqlDataType::Timestamp(None, TimezoneInfo::None),
),
(DataType::Date32, SqlDataType::Date),
]
.into_iter()
.collect(),
Expand Down Expand Up @@ -1113,6 +1143,10 @@ impl Dialect {
StrToUtcTimestampMySqlTransformer::new_dyn(),
),
("date_part_tz", DatePartTzMySqlTransformer::new_dyn()),
(
"date_to_utc_timestamp",
DateToUtcTimestampMySqlTransformer::new_dyn(),
),
]
.into_iter()
.map(|(name, v)| (name.to_string(), v))
Expand All @@ -1138,6 +1172,7 @@ impl Dialect {
(DataType::Float32, SqlDataType::Float(None)),
(DataType::Float64, SqlDataType::Double),
(DataType::Utf8, SqlDataType::Char(None)),
(DataType::Date32, SqlDataType::Date),
]
.into_iter()
.collect(),
Expand Down Expand Up @@ -1294,6 +1329,7 @@ impl Dialect {
DataType::Timestamp(TimeUnit::Millisecond, None),
SqlDataType::Timestamp(None, TimezoneInfo::None),
),
(DataType::Date32, SqlDataType::Date),
]
.into_iter()
.collect(),
Expand Down Expand Up @@ -1396,6 +1432,10 @@ impl Dialect {
"date_trunc_tz",
DateTruncTzWithDateTruncAndAtTimezoneTransformer::new_dyn(true),
),
(
"date_to_utc_timestamp",
DateToUtcTimestampWithCastAndAtTimeZoneTransformer::new_dyn(),
),
]
.into_iter()
.map(|(name, v)| (name.to_string(), v))
Expand Down Expand Up @@ -1426,6 +1466,7 @@ impl Dialect {
DataType::Timestamp(TimeUnit::Millisecond, None),
SqlDataType::Timestamp(None, TimezoneInfo::None),
),
(DataType::Date32, SqlDataType::Date),
]
.into_iter()
.collect(),
Expand Down Expand Up @@ -1539,6 +1580,10 @@ impl Dialect {
"utc_timestamp_to_str",
UtcTimestampToStrSnowflakeTransformer::new_dyn(),
),
(
"date_to_utc_timestamp",
DateToUtcTimestampSnowflakeTransform::new_dyn(),
),
]
.into_iter()
.map(|(name, v)| (name.to_string(), v))
Expand Down Expand Up @@ -1573,6 +1618,7 @@ impl Dialect {
DataType::Timestamp(TimeUnit::Millisecond, None),
SqlDataType::Timestamp(None, TimezoneInfo::None),
),
(DataType::Date32, SqlDataType::Date),
]
.into_iter()
.collect(),
Expand Down
Loading

0 comments on commit c01850d

Please sign in to comment.