Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

[SPARK-33427][SQL] Add subexpression elimination for interpreted expression evaluation #30341

Closed
wants to merge 10 commits into from
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql.catalyst.expressions

import com.google.common.cache.{CacheBuilder, CacheLoader, LoadingCache}

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.internal.SQLConf

/**
 * Helps subexpression elimination for interpreted expression evaluation in
 * `InterpretedUnsafeProjection`. It maintains an evaluation cache keyed by
 * `ExpressionProxy`: this class wraps `ExpressionProxy` around common
 * subexpressions, and the proxies intercept evaluation and consult the cache
 * before computing their child expression.
 */
class EvaluationRunTime {

  // Cache from proxy to its evaluated result, bounded by
  // `spark.sql.subexpressionElimination.cache.maxEntries`. Entries are loaded
  // lazily by evaluating the proxied expression against `currentInput`.
  val cache: LoadingCache[ExpressionProxy, ResultProxy] = CacheBuilder.newBuilder()
    .maximumSize(SQLConf.get.subexpressionEliminationCacheMaxEntries)
    .build(
      new CacheLoader[ExpressionProxy, ResultProxy]() {
        override def load(expr: ExpressionProxy): ResultProxy = {
          ResultProxy(expr.proxyEval(currentInput))
        }
      })

  // The row currently being evaluated; cached results are only valid for it.
  private var currentInput: InternalRow = null

  /**
   * Sets given input row as current row for evaluating expressions. This cleans up the cache
   * too as new input comes.
   */
  def setInput(input: InternalRow = null): Unit = {
    currentInput = input
    cache.invalidateAll()
  }

  /**
   * Finds common subexpressions and wraps them with `ExpressionProxy`. Every
   * occurrence of the same subexpression maps to one shared proxy so that all
   * occurrences hit the same cache entry. Returns the input sequence untouched
   * when no common subexpression exists.
   */
  def proxyExpressions(expressions: Seq[Expression]): Seq[Expression] = {
    val equivalentExpressions: EquivalentExpressions = new EquivalentExpressions
    expressions.foreach(equivalentExpressions.addExprTree(_))

    // Build the occurrence -> shared proxy mapping for all expressions that
    // appear more than once.
    val proxyMap: Map[Expression, ExpressionProxy] =
      equivalentExpressions.getAllEquivalentExprs
        .filter(_.size > 1)
        .flatMap { exprs =>
          val proxy = ExpressionProxy(exprs.head, this)
          exprs.map(_ -> proxy)
        }.toMap

    // Only adding proxy if we find subexpressions.
    if (proxyMap.nonEmpty) {
      expressions.map { expr =>
        // `transform` will cause stackoverflow because it keeps transforming into
        // `ExpressionProxy`. But we cannot use `transformUp` because we want to use
        // subexpressions at higher level. So we `transformDown` until finding first
        // subexpression.
        var transformed = false
        expr.transform {
          case e if !transformed && proxyMap.contains(e) =>
            transformed = true
            proxyMap(e)
        }
      }
    } else {
      expressions
    }
  }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql.catalyst.expressions

import com.google.common.util.concurrent.{ExecutionError, UncheckedExecutionException}

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
import org.apache.spark.sql.types.DataType

/**
 * A proxy for a Catalyst `Expression`. Given a runtime object `EvaluationRunTime`, when this
 * is asked to evaluate, it will load the evaluation cache in the runtime first, so a
 * common subexpression shared by several occurrences is computed only once per input row.
 */
case class ExpressionProxy(child: Expression, runtime: EvaluationRunTime) extends Expression {

  final override def dataType: DataType = child.dataType
  final override def nullable: Boolean = child.nullable
  final override def children: Seq[Expression] = child :: Nil

  // `ExpressionProxy` is for interpreted expression evaluation only. So cannot `doGenCode`.
  final override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode =
    throw new UnsupportedOperationException(s"Cannot generate code for expression: $this")

  /**
   * Evaluates the underlying expression directly, bypassing the cache. Called
   * by the runtime's cache loader on a cache miss.
   */
  def proxyEval(input: InternalRow = null): Any = {
    child.eval(input)
  }

  override def eval(input: InternalRow = null): Any = try {
    // Look up (or compute and cache) the result in the runtime's cache. Note
    // that `input` is ignored here: the cache loader evaluates against the
    // runtime's current input row.
    runtime.cache.get(this).result
  } catch {
    // Cache.get() may wrap the original exception. See the following URL
    // http://google.github.io/guava/releases/14.0/api/docs/com/google/common/cache/
    // Cache.html#get(K,%20java.util.concurrent.Callable)
    case e @ (_: UncheckedExecutionException | _: ExecutionError) =>
      throw e.getCause
  }
}

/**
 * A simple wrapper for holding an `Any` evaluation result in the cache of
 * `EvaluationRunTime`.
 */
case class ResultProxy(result: Any)
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ import org.apache.spark.SparkException
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.codegen.{UnsafeArrayWriter, UnsafeRowWriter, UnsafeWriter}
import org.apache.spark.sql.catalyst.util.ArrayData
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types.{UserDefinedType, _}
import org.apache.spark.unsafe.Platform

Expand All @@ -33,6 +34,14 @@ import org.apache.spark.unsafe.Platform
class InterpretedUnsafeProjection(expressions: Array[Expression]) extends UnsafeProjection {
import InterpretedUnsafeProjection._

// Whether interpreted subexpression elimination is enabled for this projection.
private[this] val subExprElimination = SQLConf.get.subexpressionEliminationEnabled

// Runtime holding the per-row evaluation cache; lazy so it is only created
// when subexpression elimination is actually enabled.
private[this] lazy val runtime = new EvaluationRunTime()

// Expressions actually evaluated by `apply`: common subexpressions are wrapped
// in `ExpressionProxy` when elimination is enabled, otherwise the originals.
private[this] val proxyExpressions = if (subExprElimination) {
  runtime.proxyExpressions(expressions)
} else {
  expressions.toSeq
}

/** Number of (top level) fields in the resulting row. */
private[this] val numFields = expressions.length

Expand Down Expand Up @@ -63,17 +72,21 @@ class InterpretedUnsafeProjection(expressions: Array[Expression]) extends Unsafe
}

override def initialize(partitionIndex: Int): Unit = {
  // Initialize nondeterministic expressions, including any wrapped inside an
  // `ExpressionProxy` (hence iterating over `proxyExpressions`).
  proxyExpressions.foreach(_.foreach {
    case n: Nondeterministic => n.initialize(partitionIndex)
    case _ =>
  })
}

override def apply(row: InternalRow): UnsafeRow = {
if (subExprElimination) {
runtime.setInput(row)
}

// Put the expression results in the intermediate row.
var i = 0
while (i < numFields) {
values(i) = expressions(i).eval(row)
values(i) = proxyExpressions(i).eval(row)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Any reason not to apply this to InterpretedMutableProjection, too?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It could be applied to InterpretedMutableProjection too. We can add it gradually like filter.

i += 1
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -539,6 +539,15 @@ object SQLConf {
.booleanConf
.createWithDefault(true)

// Bound on the interpreted subexpression-elimination cache held by
// `EvaluationRunTime`; 0 effectively disables caching.
val SUBEXPRESSION_ELIMINATION_CACHE_MAX_ENTRIES =
  buildConf("spark.sql.subexpressionElimination.cache.maxEntries")
    .internal()
    .doc("The maximum entries of the cache used for interpreted subexpression elimination.")
    .version("3.1.0")
    .intConf
    .checkValue(maxEntries => maxEntries >= 0, "The maximum must not be negative")
    .createWithDefault(100)

val CASE_SENSITIVE = buildConf("spark.sql.caseSensitive")
.internal()
.doc("Whether the query analyzer should be case sensitive or not. " +
Expand Down Expand Up @@ -3214,6 +3223,9 @@ class SQLConf extends Serializable with Logging {
// Whether common subexpression elimination is enabled.
def subexpressionEliminationEnabled: Boolean =
  getConf(SUBEXPRESSION_ELIMINATION_ENABLED)

// Maximum entries of the cache used for interpreted subexpression elimination.
def subexpressionEliminationCacheMaxEntries: Int =
  getConf(SUBEXPRESSION_ELIMINATION_CACHE_MAX_ENTRIES)

def autoBroadcastJoinThreshold: Long = getConf(AUTO_BROADCASTJOIN_THRESHOLD)

def limitScaleUpFactor: Int = getConf(LIMIT_SCALE_UP_FACTOR)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql.catalyst.expressions

import org.apache.spark.SparkFunSuite

class EvaluationRunTimeSuite extends SparkFunSuite {

  test("Evaluate ExpressionProxy should create cached result") {
    val runTime = new EvaluationRunTime()
    val proxy = ExpressionProxy(Literal(1), runTime)

    // The cache starts empty and is populated by evaluating the proxy.
    assert(runTime.cache.size() == 0)
    proxy.eval()
    assert(runTime.cache.size() == 1)
    assert(runTime.cache.get(proxy) == ResultProxy(1))
  }

  test("setInput should empty cached result") {
    val runTime = new EvaluationRunTime()

    val firstProxy = ExpressionProxy(Literal(1), runTime)
    assert(runTime.cache.size() == 0)
    firstProxy.eval()
    assert(runTime.cache.size() == 1)
    assert(runTime.cache.get(firstProxy) == ResultProxy(1))

    val secondProxy = ExpressionProxy(Literal(2), runTime)
    secondProxy.eval()
    assert(runTime.cache.size() == 2)
    assert(runTime.cache.get(secondProxy) == ResultProxy(2))

    // Setting a new input row must invalidate everything cached so far.
    runTime.setInput()
    assert(runTime.cache.size() == 0)
  }

  test("Wrap ExpressionProxy on subexpressions") {
    val runTime = new EvaluationRunTime()

    val one = Literal(1)
    val two = Literal(2)
    val mul = Multiply(one, two)
    val mul2 = Multiply(mul, mul)
    val sqrt = Sqrt(mul2)
    val sum = Add(mul2, sqrt)

    // ( (one * two) * (one * two) ) + sqrt( (one * two) * (one * two) )
    val proxied = runTime.proxyExpressions(Seq(sum))
    val proxies = proxied.flatMap(_.collect { case p: ExpressionProxy => p })

    // Only the top-most common subexpression, ( (one * two) * (one * two) ),
    // gets wrapped.
    assert(proxies.size == 1)
    assert(proxies.head == ExpressionProxy(mul2, runTime))
  }

  test("ExpressionProxy won't be on non deterministic") {
    val runTime = new EvaluationRunTime()

    // Nondeterministic expressions must never be shared via a proxy.
    val sum = Add(Rand(0), Rand(0))
    val proxies = runTime.proxyExpressions(Seq(sum, sum)).flatMap(_.collect {
      case p: ExpressionProxy => p
    })
    assert(proxies.isEmpty)
  }
}