This repository has been archived by the owner on Sep 18, 2023. It is now read-only.

[NSE-145] Support decimal in columnar window #151

Merged · 6 commits · Mar 11, 2021
3 changes: 2 additions & 1 deletion core/src/main/scala/com/intel/oap/ColumnarGuardRule.scala
@@ -160,11 +160,12 @@ case class ColumnarGuardRule(conf: SparkConf) extends Rule[SparkPlan] {
           plan.isSkewJoin)
       case plan: WindowExec =>
         if (!enableColumnarWindow) return false
-        new ColumnarWindowExec(
+        val window = ColumnarWindowExec.create(
           plan.windowExpression,
           plan.partitionSpec,
           plan.orderSpec,
           plan.child)
+        window
       case p =>
         p
     }
19 changes: 12 additions & 7 deletions core/src/main/scala/com/intel/oap/ColumnarPlugin.scala
@@ -67,13 +67,17 @@ case class ColumnarPreOverrides(conf: SparkConf) extends Rule[SparkPlan] {
           new ColumnarBatchScanExec(plan.output, plan.scan)
         }
       case plan: ProjectExec =>
-        val columnarPlan = replaceWithColumnarPlan(plan.child)
+        val columnarChild = replaceWithColumnarPlan(plan.child)
         logDebug(s"Columnar Processing for ${plan.getClass} is currently supported.")
-        if (columnarPlan.isInstanceOf[ColumnarConditionProjectExec]) {
-          val cur_plan = columnarPlan.asInstanceOf[ColumnarConditionProjectExec]
-          ColumnarConditionProjectExec(cur_plan.condition, plan.projectList, cur_plan.child)
-        } else {
-          ColumnarConditionProjectExec(null, plan.projectList, columnarPlan)
+        columnarChild match {
+          case ch: ColumnarConditionProjectExec =>
+            if (ch.projectList == null) {
+              ColumnarConditionProjectExec(ch.condition, plan.projectList, ch.child)
+            } else {
+              ColumnarConditionProjectExec(null, plan.projectList, columnarChild)
+            }
+          case _ =>
+            ColumnarConditionProjectExec(null, plan.projectList, columnarChild)
         }
       case plan: FilterExec =>
         val child = replaceWithColumnarPlan(plan.child)
@@ -234,11 +238,12 @@ case class ColumnarPreOverrides(conf: SparkConf) extends Rule[SparkPlan] {
         }
         logDebug(s"Columnar Processing for ${plan.getClass} is currently supported.")
         try {
-          return new ColumnarWindowExec(
+          val window = ColumnarWindowExec.create(
             plan.windowExpression,
             plan.partitionSpec,
             plan.orderSpec,
             coalesceBatchRemoved)
+          return window
         } catch {
           case _: Throwable =>
             logInfo("Columnar Window: Falling back to regular Window...")
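
The rule now builds the columnar operator through the factory and keeps the row-based WindowExec when construction throws. A minimal sketch of that probe-and-fall-back pattern, pulled out of the rule for clarity — tryColumnarWindow is a hypothetical helper name, not part of the PR:

import com.intel.oap.execution.ColumnarWindowExec
import org.apache.spark.sql.execution.SparkPlan
import org.apache.spark.sql.execution.window.WindowExec

// Hypothetical helper illustrating the fallback logic above; not part of the PR.
def tryColumnarWindow(plan: WindowExec): SparkPlan =
  try {
    // create() may throw for unsupported window expressions or data types
    ColumnarWindowExec.create(
      plan.windowExpression, plan.partitionSpec, plan.orderSpec, plan.child)
  } catch {
    case _: Throwable =>
      plan // keep the row-based WindowExec
  }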
110 changes: 107 additions & 3 deletions core/src/main/scala/com/intel/oap/execution/ColumnarWindowExec.scala
@@ -27,20 +27,21 @@ import org.apache.arrow.gandiva.expression.TreeBuilder
 import org.apache.arrow.vector.types.pojo.{ArrowType, Field, FieldType, Schema}
 import org.apache.arrow.vector.types.pojo.ArrowType.ArrowTypeID
 import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.catalyst.expressions.{Alias, Ascending, Attribute, AttributeReference, Cast, Descending, Expression, NamedExpression, Rank, SortOrder, WindowExpression, WindowFunction}
-import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, Average, Sum}
+import org.apache.spark.sql.catalyst.expressions.{Alias, Ascending, Attribute, AttributeReference, Cast, Descending, Expression, MakeDecimal, NamedExpression, Rank, SortOrder, UnscaledValue, WindowExpression, WindowFunction}
+import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, AggregateFunction, Average, Sum}
 import org.apache.spark.sql.execution.window.WindowExec
 import org.apache.spark.sql.execution.SparkPlan
 import org.apache.spark.sql.execution.datasources.v2.arrow.SparkMemoryUtils
 import org.apache.spark.sql.execution.metric.SQLMetrics
 import org.apache.spark.sql.internal.SQLConf
-import org.apache.spark.sql.types.ArrayType
+import org.apache.spark.sql.types.{ArrayType, DataType, DecimalType, DoubleType, LongType}
 import org.apache.spark.sql.util.ArrowUtils
 import org.apache.spark.sql.vectorized.ColumnarBatch
 import org.apache.spark.util.ExecutorManager
 
 import scala.collection.JavaConverters._
 import scala.collection.mutable.ListBuffer
+import scala.util.Random
 
 class ColumnarWindowExec(windowExpression: Seq[NamedExpression],
     partitionSpec: Seq[Expression],
@@ -248,3 +249,106 @@ class ColumnarWindowExec(windowExpression: Seq[NamedExpression],
     override def isComplex: Boolean = false
   }
 }
+
+object ColumnarWindowExec {
+
+  def createWithProjection(
+      windowExpression: Seq[NamedExpression],
+      partitionSpec: Seq[Expression],
+      orderSpec: Seq[SortOrder],
+      child: SparkPlan): SparkPlan = {
+
+    def makeInputProject(ex: Expression, inputProjects: ListBuffer[NamedExpression]): Expression = {
+      ex match {
+        case ae: AggregateExpression => ae.withNewChildren(ae.children.map(makeInputProject(_, inputProjects)))
+        case ae: WindowExpression => ae.withNewChildren(ae.children.map(makeInputProject(_, inputProjects)))
+        case func @ (_: AggregateFunction | _: WindowFunction) =>
+          val params = func.children
+          // rewrite
+          val rewritten = func match {
+            case _: Average =>
+              // rewrite params for AVG
+              params.map { param =>
+                param.dataType match {
+                  case _: LongType | _: DecimalType =>
+                    Cast(param, DoubleType)
+                  case _ => param
+                }
+              }
+            case _ => params
+          }
+
+          // alias
+          func.withNewChildren(rewritten.map {
+            case param @ (_: Cast | _: UnscaledValue) =>
+              val aliasName = "__alias_%d__".format(Random.nextLong())
+              val alias = Alias(param, aliasName)()
+              inputProjects.append(alias)
+              alias.toAttribute
+            case other => other
+          })
+        case other => other
+      }
+    }
+
+    def sameType(from: DataType, to: DataType): Boolean = {
+      if (from == null || to == null) {
+        throw new IllegalArgumentException("null type found during type enforcement")
+      }
+      if (from == to) {
+        return true
+      }
+      DataType.equalsStructurally(from, to)
+    }
+
+    def makeOutputProject(ex: Expression, windows: ListBuffer[NamedExpression], inputProjects: ListBuffer[NamedExpression]): Expression = {
+      val out = ex match {
+        case we: WindowExpression =>
+          val aliasName = "__alias_%d__".format(Random.nextLong())
+          val alias = Alias(makeInputProject(we, inputProjects), aliasName)()
+          windows.append(alias)
+          alias.toAttribute
+        case _ =>
+          ex.withNewChildren(ex.children.map(makeOutputProject(_, windows, inputProjects)))
+      }
+      // forcibly cast to original type against possible rewriting
+      val casted = try {
+        if (sameType(out.dataType, ex.dataType)) {
+          out
+        } else {
+          Cast(out, ex.dataType)
+        }
+      } catch {
+        case t: Throwable =>
+          System.err.println("Warning: " + t.getMessage)
+          Cast(out, ex.dataType)
+      }
+      casted
+    }
+
+    val windows = ListBuffer[NamedExpression]()
+    val inProjectExpressions = ListBuffer[NamedExpression]()
+    val outProjectExpressions = windowExpression.map(e => e.asInstanceOf[Alias])
+      .map { a =>
+        a.withNewChildren(List(makeOutputProject(a.child, windows, inProjectExpressions)))
+          .asInstanceOf[NamedExpression]
+      }
+
+    val inputProject = ColumnarConditionProjectExec(null, child.output ++ inProjectExpressions, child)
+
+    val window = new ColumnarWindowExec(windows, partitionSpec, orderSpec, inputProject)
+
+    val outputProject = ColumnarConditionProjectExec(null, child.output ++ outProjectExpressions, window)
+
+    outputProject
+  }
+
+  def create(
+      windowExpression: Seq[NamedExpression],
+      partitionSpec: Seq[Expression],
+      orderSpec: Seq[SortOrder],
+      child: SparkPlan): SparkPlan = {
+    createWithProjection(windowExpression, partitionSpec, orderSpec, child)
+  }
+}
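
In short, createWithProjection sandwiches the window between two projections: the input projection pre-computes rewritten aggregate arguments (AVG over LongType or DecimalType inputs is evaluated over a DOUBLE cast), and the output projection forcibly casts each window result back to the type Spark originally resolved. A rough sketch, using Spark 3.0-era Catalyst signatures (an assumption), of the shapes this produces for AVG over a DECIMAL(12, 2) column:

import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate.Average
import org.apache.spark.sql.types._

// Hypothetical column standing in for e.g. i_current_price:
val price = AttributeReference("price", DecimalType(12, 2))()

// The input projection adds an alias over the double-cast argument:
val in = Alias(Cast(price, DoubleType), "__alias_in__")()

// The window then evaluates AVG in double space over that attribute ...
val avg = Average(in.toAttribute)

// ... and the output projection restores the analyzer's original result type,
// e.g. Cast(windowResult, DecimalType(16, 6)) — assuming Spark's usual
// avg promotion rule decimal(p + 4, s + 4) for a DECIMAL(12, 2) input.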
(TPC-DS query file — name not shown)
@@ -3,13 +3,13 @@ SELECT
   ca_country,
   ca_state,
   ca_county,
-  avg(cast(cs_quantity AS DOUBLE)) agg1,
-  avg(cast(cs_list_price AS DOUBLE)) agg2,
-  avg(cast(cs_coupon_amt AS DOUBLE)) agg3,
-  avg(cast(cs_sales_price AS DOUBLE)) agg4,
-  avg(cast(cs_net_profit AS DOUBLE)) agg5,
-  avg(cast(c_birth_year AS DOUBLE)) agg6,
-  avg(cast(cd1.cd_dep_count AS DOUBLE)) agg7
+  avg(cast(cs_quantity AS DECIMAL(12, 2))) agg1,
+  avg(cast(cs_list_price AS DECIMAL(12, 2))) agg2,
+  avg(cast(cs_coupon_amt AS DECIMAL(12, 2))) agg3,
+  avg(cast(cs_sales_price AS DECIMAL(12, 2))) agg4,
+  avg(cast(cs_net_profit AS DECIMAL(12, 2))) agg5,
+  avg(cast(c_birth_year AS DECIMAL(12, 2))) agg6,
+  avg(cast(cd1.cd_dep_count AS DECIMAL(12, 2))) agg7
 FROM catalog_sales, customer_demographics cd1,
      customer_demographics cd2, customer, customer_address, date_dim, item
 WHERE cs_sold_date_sk = d_date_sk AND
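
The benchmark queries drop the old cast-to-DOUBLE workaround in favor of DECIMAL casts, so the new columnar decimal path is actually exercised. Note that the aggregate output still widens under Spark's decimal promotion rules; a quick way to inspect the type Spark assigns (the decimal(16,6) expectation follows the avg rule decimal(p + 4, s + 4), stated here as an assumption about Spark's behavior):

// spark is an active SparkSession
spark.sql("SELECT avg(cast(1 AS DECIMAL(12, 2))) AS a").printSchema()
// expected: a: decimal(16,6)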
(TPC-DS query file — name not shown)
@@ -18,10 +18,10 @@ FROM (
   FROM
     (SELECT
       ws.ws_item_sk AS item,
-      (cast(sum(coalesce(wr.wr_return_quantity, 0)) AS DOUBLE) /
-        cast(sum(coalesce(ws.ws_quantity, 0)) AS DOUBLE)) AS return_ratio,
-      (cast(sum(coalesce(wr.wr_return_amt, 0)) AS DOUBLE) /
-        cast(sum(coalesce(ws.ws_net_paid, 0)) AS DOUBLE)) AS currency_ratio
+      (cast(sum(coalesce(wr.wr_return_quantity, 0)) AS DECIMAL(15, 4)) /
+        cast(sum(coalesce(ws.ws_quantity, 0)) AS DECIMAL(15, 4))) AS return_ratio,
+      (cast(sum(coalesce(wr.wr_return_amt, 0)) AS DECIMAL(15, 4)) /
+        cast(sum(coalesce(ws.ws_net_paid, 0)) AS DECIMAL(15, 4))) AS currency_ratio
     FROM
       web_sales ws LEFT OUTER JOIN web_returns wr
         ON (ws.ws_order_number = wr.wr_order_number AND
@@ -60,10 +60,10 @@ FROM (
   FROM
     (SELECT
       cs.cs_item_sk AS item,
-      (cast(sum(coalesce(cr.cr_return_quantity, 0)) AS DOUBLE) /
-        cast(sum(coalesce(cs.cs_quantity, 0)) AS DOUBLE)) AS return_ratio,
-      (cast(sum(coalesce(cr.cr_return_amount, 0)) AS DOUBLE) /
-        cast(sum(coalesce(cs.cs_net_paid, 0)) AS DOUBLE)) AS currency_ratio
+      (cast(sum(coalesce(cr.cr_return_quantity, 0)) AS DECIMAL(15, 4)) /
+        cast(sum(coalesce(cs.cs_quantity, 0)) AS DECIMAL(15, 4))) AS return_ratio,
+      (cast(sum(coalesce(cr.cr_return_amount, 0)) AS DECIMAL(15, 4)) /
+        cast(sum(coalesce(cs.cs_net_paid, 0)) AS DECIMAL(15, 4))) AS currency_ratio
     FROM
       catalog_sales cs LEFT OUTER JOIN catalog_returns cr
         ON (cs.cs_order_number = cr.cr_order_number AND
@@ -102,10 +102,10 @@ FROM (
   FROM
     (SELECT
       sts.ss_item_sk AS item,
-      (cast(sum(coalesce(sr.sr_return_quantity, 0)) AS DOUBLE) /
-        cast(sum(coalesce(sts.ss_quantity, 0)) AS DOUBLE)) AS return_ratio,
-      (cast(sum(coalesce(sr.sr_return_amt, 0)) AS DOUBLE) /
-        cast(sum(coalesce(sts.ss_net_paid, 0)) AS DOUBLE)) AS currency_ratio
+      (cast(sum(coalesce(sr.sr_return_quantity, 0)) AS DECIMAL(15, 4)) /
+        cast(sum(coalesce(sts.ss_quantity, 0)) AS DECIMAL(15, 4))) AS return_ratio,
+      (cast(sum(coalesce(sr.sr_return_amt, 0)) AS DECIMAL(15, 4)) /
+        cast(sum(coalesce(sts.ss_net_paid, 0)) AS DECIMAL(15, 4))) AS currency_ratio
     FROM
       store_sales sts LEFT OUTER JOIN store_returns sr
         ON (sts.ss_ticket_number = sr.sr_ticket_number AND sts.ss_item_sk = sr.sr_item_sk)
(TPC-DS query file — name not shown)
@@ -11,15 +11,15 @@ WITH ssr AS
     ss_sold_date_sk AS date_sk,
     ss_ext_sales_price AS sales_price,
     ss_net_profit AS profit,
-    cast(0 AS DOUBLE) AS return_amt,
-    cast(0 AS DOUBLE) AS net_loss
+    cast(0 AS DECIMAL(7, 2)) AS return_amt,
+    cast(0 AS DECIMAL(7, 2)) AS net_loss
   FROM store_sales
   UNION ALL
   SELECT
     sr_store_sk AS store_sk,
     sr_returned_date_sk AS date_sk,
-    cast(0 AS DOUBLE) AS sales_price,
-    cast(0 AS DOUBLE) AS profit,
+    cast(0 AS DECIMAL(7, 2)) AS sales_price,
+    cast(0 AS DECIMAL(7, 2)) AS profit,
     sr_return_amt AS return_amt,
     sr_net_loss AS net_loss
   FROM store_returns)
@@ -42,15 +42,15 @@ WITH ssr AS
     cs_sold_date_sk AS date_sk,
     cs_ext_sales_price AS sales_price,
     cs_net_profit AS profit,
-    cast(0 AS DOUBLE) AS return_amt,
-    cast(0 AS DOUBLE) AS net_loss
+    cast(0 AS DECIMAL(7, 2)) AS return_amt,
+    cast(0 AS DECIMAL(7, 2)) AS net_loss
   FROM catalog_sales
   UNION ALL
   SELECT
     cr_catalog_page_sk AS page_sk,
     cr_returned_date_sk AS date_sk,
-    cast(0 AS DOUBLE) AS sales_price,
-    cast(0 AS DOUBLE) AS profit,
+    cast(0 AS DECIMAL(7, 2)) AS sales_price,
+    cast(0 AS DECIMAL(7, 2)) AS profit,
     cr_return_amount AS return_amt,
     cr_net_loss AS net_loss
   FROM catalog_returns
@@ -74,15 +74,15 @@ WITH ssr AS
     ws_sold_date_sk AS date_sk,
     ws_ext_sales_price AS sales_price,
     ws_net_profit AS profit,
-    cast(0 AS DOUBLE) AS return_amt,
-    cast(0 AS DOUBLE) AS net_loss
+    cast(0 AS DECIMAL(7, 2)) AS return_amt,
+    cast(0 AS DECIMAL(7, 2)) AS net_loss
   FROM web_sales
   UNION ALL
   SELECT
     ws_web_site_sk AS wsr_web_site_sk,
     wr_returned_date_sk AS date_sk,
-    cast(0 AS DOUBLE) AS sales_price,
-    cast(0 AS DOUBLE) AS profit,
+    cast(0 AS DECIMAL(7, 2)) AS sales_price,
+    cast(0 AS DECIMAL(7, 2)) AS profit,
     wr_return_amt AS return_amt,
     wr_net_loss AS net_loss
   FROM web_returns
(TPC-DS query file — name not shown)
@@ -1,7 +1,7 @@
 SELECT
   promotions,
   total,
-  cast(promotions AS DOUBLE) / cast(total AS DOUBLE) * 100
+  cast(promotions AS DECIMAL(15, 4)) / cast(total AS DECIMAL(15, 4)) * 100
 FROM
   (SELECT sum(ss_ext_sales_price) promotions
    FROM store_sales, store, promotion, date_dim, customer, customer_address, item
(TPC-DS query file — name not shown)
@@ -71,6 +71,6 @@ WHERE curr_yr.i_brand_id = prev_yr.i_brand_id
   AND curr_yr.i_manufact_id = prev_yr.i_manufact_id
   AND curr_yr.d_year = 2002
   AND prev_yr.d_year = 2002 - 1
-  AND CAST(curr_yr.sales_cnt AS DOUBLE) / CAST(prev_yr.sales_cnt AS DOUBLE) < 0.9
+  AND CAST(curr_yr.sales_cnt AS DECIMAL(17, 2)) / CAST(prev_yr.sales_cnt AS DECIMAL(17, 2)) < 0.9
 ORDER BY sales_cnt_diff
 LIMIT 100
(TPC-DS query file — name not shown)
@@ -1,4 +1,4 @@
-SELECT cast(amc AS DOUBLE) / cast(pmc AS DOUBLE) am_pm_ratio
+SELECT cast(amc AS DECIMAL(15, 4)) / cast(pmc AS DECIMAL(15, 4)) am_pm_ratio
 FROM (SELECT count(*) amc
       FROM web_sales, household_demographics, time_dim, web_page
       WHERE ws_sold_time_sk = time_dim.t_time_sk
(TPC-H query file — name not shown)
@@ -23,7 +23,7 @@ where
   )
   and ps_availqty > (
     select
-      0.5 * sum(CAST(l_quantity AS DOUBLE))
+      0.5 * sum(l_quantity)
     from
       lineitem
     where
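
Here the explicit cast is dropped entirely: sum over a decimal column already yields a decimal (Spark is assumed to widen the result to decimal(min(38, p + 10), s)), so 0.5 * sum(l_quantity) stays in decimal arithmetic without help. A small sanity check in the same spirit as above:

// Sanity check: sum over DECIMAL stays decimal, so no cast is needed.
// The expected widening to decimal(p + 10, s) is an assumption about Spark's rules.
spark.sql("SELECT 0.5 * sum(cast(10 AS DECIMAL(12, 2))) AS half").printSchema()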
30 changes: 29 additions & 1 deletion core/src/test/scala/com/intel/oap/tpc/ds/TPCDSSuite.scala
@@ -26,7 +26,7 @@ import org.apache.spark.sql.test.SharedSparkSession
 class TPCDSSuite extends QueryTest with SharedSparkSession {
 
   private val MAX_DIRECT_MEMORY = "6g"
-  private val TPCDS_QUERIES_RESOURCE = "tpcds-queries-double"
+  private val TPCDS_QUERIES_RESOURCE = "tpcds-queries"
   private val TPCDS_WRITE_PATH = "/tmp/tpcds-generated"
 
   private var runner: TPCRunner = _
@@ -88,6 +88,34 @@ class TPCDSSuite extends QueryTest with SharedSparkSession {
   test("window query") {
     runner.runTPCQuery("q67", 1, true)
   }
+
+  test("window function with non-decimal input") {
+    val df = spark.sql("SELECT i_item_sk, i_class_id, SUM(i_category_id)" +
+        " OVER (PARTITION BY i_class_id) FROM item LIMIT 1000")
+    df.explain()
+    df.show()
+  }
+
+  test("window function with decimal input") {
+    val df = spark.sql("SELECT i_item_sk, i_class_id, SUM(i_current_price)" +
+        " OVER (PARTITION BY i_class_id) FROM item LIMIT 1000")
+    df.explain()
+    df.show()
+  }
+
+  test("window function with decimal input 2") {
+    val df = spark.sql("SELECT i_item_sk, i_class_id, RANK()" +
+        " OVER (PARTITION BY i_class_id ORDER BY i_current_price) FROM item LIMIT 1000")
+    df.explain()
+    df.show()
+  }
+
+  test("window function with decimal input 3") {
+    val df = spark.sql("SELECT i_item_sk, i_class_id, AVG(i_current_price)" +
+        " OVER (PARTITION BY i_class_id) FROM item LIMIT 1000")
+    df.explain()
+    df.show()
+  }
 }
 
 object TPCDSSuite {
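
The new tests only explain() and show() the result. A stricter variant could assert that the columnar operator actually made it into the executed plan rather than being re-planned back to the row-based window; a hedged sketch (not part of the PR):

import com.intel.oap.execution.ColumnarWindowExec

test("window function with decimal input stays columnar") {
  val df = spark.sql("SELECT i_item_sk, AVG(i_current_price)" +
      " OVER (PARTITION BY i_class_id) FROM item LIMIT 1000")
  val plan = df.queryExecution.executedPlan
  // TreeNode.find walks the physical plan; finding a ColumnarWindowExec node
  // means the window did not fall back to the row-based WindowExec
  assert(plan.find(_.isInstanceOf[ColumnarWindowExec]).isDefined)
}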