Make Literals foldable, ensure Parquet predicates pushdown #721

Merged: 27 commits, Jun 10, 2023
Commits (27)
3bbdb9c  #343 - unpack to Literals (chris-twiner, Jun 5, 2023)
3df02ec  #343 - unpack to Literals - more test (chris-twiner, Jun 5, 2023)
c8ecea8  #343 - unpack to Literals - comment (chris-twiner, Jun 5, 2023)
b7c3132  #343 - per review - docs missing (chris-twiner, Jun 5, 2023)
81d9315  #343 - per review - docs missing - fix reflection for all versions (chris-twiner, Jun 5, 2023)
a3567c2  #343 - add struct test showing difference between extension and exper… (chris-twiner, Jun 6, 2023)
bee3cd0  #343 - toString test to stop the patch complaint (chris-twiner, Jun 6, 2023)
bba92cb  #343 - sample docs (chris-twiner, Jun 6, 2023)
28bde88  #343 - package rename and adding logging that the extension is injected (chris-twiner, Jun 6, 2023)
f4e99b5  #343 - doc fixes (chris-twiner, Jun 6, 2023)
0cbe684  #343 - doc fixes (chris-twiner, Jun 6, 2023)
c308241  #343 - can't run that code (chris-twiner, Jun 6, 2023)
381931c  #343 - didn't stop the new sparkSession (chris-twiner, Jun 6, 2023)
3df725f  Apply suggestions from code review (chris-twiner, Jun 6, 2023)
23c3eb7  #343 - more z's, debug removal, comment adjust and brackets around ex… (chris-twiner, Jun 6, 2023)
2a83510  Refactor LitRule and LitRules tests by making them slightly more gene… (pomadchin, Jun 7, 2023)
e7ba599  Fix mdoc compilation (pomadchin, Jun 7, 2023)
e9999c1  #343 - added the struct test back (chris-twiner, Jun 7, 2023)
4e7bee3  #343 - disable the rule, foldable and eval evals (chris-twiner, Jun 7, 2023)
27e7c25  #343 - cleaned up (chris-twiner, Jun 7, 2023)
18f2bc6  More code cleanup (pomadchin, Jun 7, 2023)
82bf013  #343 - true with link for 3.2 support (chris-twiner, Jun 7, 2023)
c6bbe2c  #343 - bring back code gen with lazy to stop recompiles (chris-twiner, Jun 7, 2023)
31a023f  #343 - disable tests on 3.2, document why and re-enable the proper fold… (chris-twiner, Jun 8, 2023)
d7db649  #343 - more compat and a foldable only backport of SPARK-39106 and SP… (chris-twiner, Jun 8, 2023)
0e6c561  #343 - option 3 - let 3.2 fail as per oss impl, separate tests (chris-twiner, Jun 8, 2023)
411871b  #343 - option 3 - better dir names (chris-twiner, Jun 8, 2023)
4 changes: 2 additions & 2 deletions build.sbt
@@ -74,7 +74,7 @@ lazy val dataset = project
lazy val `dataset-spark33` = project
.settings(name := "frameless-dataset-spark33")
.settings(sourceDirectory := (dataset / sourceDirectory).value)
-  .settings(Compile / unmanagedSourceDirectories += (dataset / baseDirectory).value / "src" / "main" / "spark-3")
+  .settings(Compile / unmanagedSourceDirectories += (dataset / baseDirectory).value / "src" / "main" / "spark-33")
.settings(datasetSettings)
.settings(sparkDependencies(spark33Version))
.settings(spark33Settings)
@@ -83,7 +83,7 @@ lazy val `dataset-spark33` = project
lazy val `dataset-spark32` = project
.settings(name := "frameless-dataset-spark32")
.settings(sourceDirectory := (dataset / sourceDirectory).value)
-  .settings(Compile / unmanagedSourceDirectories += (dataset / baseDirectory).value / "src" / "main" / "spark-3")
+  .settings(Compile / unmanagedSourceDirectories += (dataset / baseDirectory).value / "src" / "main" / "spark-32")
.settings(datasetSettings)
.settings(sparkDependencies(spark32Version))
.settings(spark32Settings)
6 changes: 4 additions & 2 deletions dataset/src/main/scala/frameless/functions/Lit.scala
@@ -1,5 +1,6 @@
package frameless.functions

+import org.apache.spark.sql.FramelessInternals
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.catalyst.expressions.{Expression, NonSQLExpression}
@@ -59,7 +60,8 @@ private[frameless] case class Lit[T <: AnyVal](

protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]): Expression = this

-  // see https://github.com/typelevel/frameless/pull/721#issuecomment-1581137730 (InvokeLike <3.3.1 SPARK-40380)
-  override val foldable: Boolean = catalystExpr.foldable
+  // see https://github.com/typelevel/frameless/pull/721#issuecomment-1581137730 (InvokeLike <3.3.1 SPARK-40380, ConditionalExpression SPARK-39106)
+  // for why this does not push down on 3.2; 3.3.1 and higher _do_ push down
+  // TODO remove the compat layer once 3.2 is no longer supported
+  override val foldable: Boolean = FramelessInternals.foldableCompat(catalystExpr)
}
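
For orientation between files: what the foldable fix buys is visible in `explain()` output. A minimal sketch of mine using plain Spark (local session and a temp path chosen for the example, not code from this PR): once the right-hand side of a comparison folds to a literal, the scan reports it under PushedFilters.

```scala
import org.apache.spark.sql.SparkSession

object PushdownDemo extends App {
  val spark = SparkSession.builder().master("local[1]").appName("pushdown-demo").getOrCreate()
  import spark.implicits._

  val path = "/tmp/pushdown-demo.parquet"
  Seq(1L, 2L, 3L).toDF("a").write.mode("overwrite").parquet(path)

  // With a foldable predicate the physical plan prints something like:
  //   PushedFilters: [IsNotNull(a), GreaterThanOrEqual(a,2)]
  spark.read.parquet(path).filter($"a" >= 2L).explain()

  spark.stop()
}
```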
dataset/src/main/scala/org/apache/spark/sql/FramelessInternals.scala
@@ -1,5 +1,6 @@
package org.apache.spark.sql

+import frameless.FoldableImpl
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.catalyst.expressions.{Alias, CreateStruct}
@@ -10,6 +11,7 @@ import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Project}
import org.apache.spark.sql.execution.QueryExecution
import org.apache.spark.sql.types._
import org.apache.spark.sql.types.ObjectType

import scala.reflect.ClassTag

object FramelessInternals {
@@ -70,4 +72,11 @@ object FramelessInternals {
protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = tagged.genCode(ctx)
protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]): Expression = copy(newChildren.head)
}

+  /**
+   * On 3.2 this is a backport of SPARK-40380; on higher versions it is simply expression.foldable.
+   * @param expression the expression to test for constant folding
+   * @return true if the expression is foldable
+   */
+  def foldableCompat(expression: Expression): Boolean = FoldableImpl.foldableCompat(expression)
}
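
The indirection above resolves at compile time: each Spark-version module compiles its own `frameless.FoldableImpl` from its own source directory (see the build.sbt change above), so no runtime version check is needed. A tiny hedged sanity check, valid on any supported version:

```scala
import org.apache.spark.sql.FramelessInternals
import org.apache.spark.sql.catalyst.expressions.Literal

object FoldableCompatCheck extends App {
  // A bare Literal is foldable everywhere: 3.3+ returns Literal.foldable
  // directly, and the 3.2 shim reaches the same answer.
  assert(FramelessInternals.foldableCompat(Literal(42)))
}
```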
7 changes: 7 additions & 0 deletions dataset/src/main/spark-3.4+/frameless/FoldableImpl.scala
@@ -0,0 +1,7 @@
package frameless

import org.apache.spark.sql.catalyst.expressions.Expression

object FoldableImpl {
def foldableCompat(expression: Expression): Boolean = expression.foldable
}
54 changes: 54 additions & 0 deletions dataset/src/main/spark-32/frameless/FoldableImpl.scala
@@ -0,0 +1,54 @@
package frameless

import org.apache.spark.sql.FramelessSpark32Internals
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
import org.apache.spark.sql.catalyst.expressions.{CaseWhen, Coalesce, Expression, If, NaNvl, UnaryExpression}
import org.apache.spark.sql.catalyst.expressions.objects.InvokeLike
import org.apache.spark.sql.types.{DataType, ObjectType}

object FoldableImpl {

trait ExpressionProxy {
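// note: after `replaced` rewrites the tree, these proxies are only ever
// consulted for `foldable`; they are never evaluated, codegen'd or copied,
// so the ??? stubs below cannot be reached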

def child: Expression

protected def withNewChildInternal(newChild: Expression): Expression = ???

protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = ???

def dataType: DataType = child.dataType

}

// backported from SPARK-40380
case class InvokeLikeImpl(child: InvokeLike) extends UnaryExpression with ExpressionProxy {
// Returns true if we can trust all values of the given DataType can be serialized.
def trustedSerializable(dt: DataType): Boolean = {
// Right now we conservatively block all ObjectType (Java objects) regardless of
// serializability, because the type-level info with java.io.Serializable and
// java.io.Externalizable marker interfaces are not strong guarantees.
// This restriction can be relaxed in the future to expose more optimizations.
!FramelessSpark32Internals.existsRecursively(dt)(_.isInstanceOf[ObjectType])
}

override def foldable =
child.children.forall(_.foldable) && deterministic && trustedSerializable(dataType)
}

// foldable not implemented in 3.2, is in 3.3 (SPARK-39106)
case class ConditionalExpressionImpl(child: Expression) extends UnaryExpression with ExpressionProxy {
override def foldable =
child.children.forall(_.foldable)
}

// needed as we cannot test foldable on any parent expression if they have Invoke
// but similarly we cannot assume the parent is foldable - so we replace InvokeLike
def replaced(expression: Expression): Expression = expression transformUp {
case il: InvokeLike => InvokeLikeImpl(il)
case e@( _: If | _: CaseWhen | _: Coalesce | _: NaNvl ) =>
ConditionalExpressionImpl(e)
}

def foldableCompat(expression: Expression): Boolean =
replaced(expression).foldable
}
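
A usage sketch of the 3.2 shim (assumes a Spark 3.2 classpath with this `FoldableImpl` compiled in; the `Invoke` is an arbitrary example chosen here, not from the PR):

```scala
import org.apache.spark.sql.catalyst.expressions.Literal
import org.apache.spark.sql.catalyst.expressions.objects.Invoke
import org.apache.spark.sql.types.IntegerType

object FoldableShimCheck extends App {
  // On stock 3.2 an Invoke never reports foldable, even over a literal target.
  val invoke = Invoke(Literal(1, IntegerType), "hashCode", IntegerType)
  assert(!invoke.foldable)

  // The shim applies the backported SPARK-40380 rules instead: all children
  // foldable, deterministic, and no ObjectType anywhere in the result type.
  assert(frameless.FoldableImpl.foldableCompat(invoke))
}
```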
dataset/src/main/spark-32/org/apache/spark/sql/FramelessSpark32Internals.scala
@@ -0,0 +1,11 @@
package org.apache.spark.sql

import org.apache.spark.sql.types.DataType

object FramelessSpark32Internals {

/**
* Returns true if any `DataType` of this DataType tree satisfies the given function `f`.
*/
def existsRecursively(dt: DataType)(f: (DataType) => Boolean): Boolean = dt.existsRecursively(f)
}
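
A small illustration of what the recursive check guards against (types chosen here purely for the example):

```scala
import org.apache.spark.sql.FramelessSpark32Internals
import org.apache.spark.sql.types._

object ExistsRecursivelyCheck extends App {
  // Plain SQL types contain no ObjectType, so such results can be trusted.
  val plain = StructType(Seq(StructField("a", ArrayType(IntegerType))))
  assert(!FramelessSpark32Internals.existsRecursively(plain)(_.isInstanceOf[ObjectType]))

  // A Java-object-backed type conservatively blocks the foldable optimization.
  val opaque = ObjectType(classOf[java.util.UUID])
  assert(FramelessSpark32Internals.existsRecursively(opaque)(_.isInstanceOf[ObjectType]))
}
```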
7 changes: 7 additions & 0 deletions dataset/src/main/spark-33/frameless/FoldableImpl.scala
@@ -0,0 +1,7 @@
package frameless

import org.apache.spark.sql.catalyst.expressions.Expression

object FoldableImpl {
def foldableCompat(expression: Expression): Boolean = expression.foldable
}
14 changes: 14 additions & 0 deletions dataset/src/main/spark-33/frameless/MapGroups.scala
@@ -0,0 +1,14 @@
package frameless

import org.apache.spark.sql.Encoder
import org.apache.spark.sql.catalyst.expressions.Attribute
import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, MapGroups => SMapGroups}

object MapGroups {
def apply[K: Encoder, T: Encoder, U: Encoder](
func: (K, Iterator[T]) => TraversableOnce[U],
groupingAttributes: Seq[Attribute],
dataAttributes: Seq[Attribute],
child: LogicalPlan
): LogicalPlan = SMapGroups(func, groupingAttributes, dataAttributes, child)
}
dataset/src/test/scala/frameless/sql/rules/FramelessLitPushDownTests.scala
@@ -3,60 +3,51 @@ package frameless.sql.rules
import frameless._
import frameless.functions.Lit
import org.apache.spark.sql.catalyst.util.DateTimeUtils.{currentTimestamp, microsToInstant}
-import org.apache.spark.sql.sources.{GreaterThanOrEqual, IsNotNull, EqualTo}
+import org.apache.spark.sql.sources.{EqualTo, GreaterThanOrEqual, IsNotNull}
import org.apache.spark.sql.catalyst.expressions
import org.apache.spark.sql.catalyst.expressions.GenericRowWithSchema

import java.time.Instant

class FramelessLitPushDownTests extends SQLRulesSuite {
private val now: Long = currentTimestamp()

-  val invokeShouldFoldOnHigherThan3_2 = not3_2 _
-
-  test("java.sql.Timestamp push-down") (
-    invokeShouldFoldOnHigherThan3_2 {
-      val expected = java.sql.Timestamp.from(microsToInstant(now))
-      val expectedStructure = X1(SQLTimestamp(now))
-      val expectedPushDownFilters = List(IsNotNull("a"), GreaterThanOrEqual("a", expected))
-
-      predicatePushDownTest[SQLTimestamp](
-        expectedStructure,
-        expectedPushDownFilters,
-        { case e @ expressions.GreaterThanOrEqual(_, _: Lit[_]) => e },
-        _ >= expectedStructure.a
-      )
-    }
-  )
-
-  test("java.time.Instant push-down") (
-    invokeShouldFoldOnHigherThan3_2 {
-      val expected = java.sql.Timestamp.from(microsToInstant(now))
-      val expectedStructure = X1(microsToInstant(now))
-      val expectedPushDownFilters = List(IsNotNull("a"), GreaterThanOrEqual("a", expected))
-
-      predicatePushDownTest[Instant](
-        expectedStructure,
-        expectedPushDownFilters,
-        { case e @ expressions.GreaterThanOrEqual(_, _: Lit[_]) => e },
-        _ >= expectedStructure.a
-      )
-    }
-  )
-
-  test("struct push-down") (
-    invokeShouldFoldOnHigherThan3_2 {
-      type Payload = X4[Int, Int, Int, Int]
-      val expectedStructure = X1(X4(1, 2, 3, 4))
-      val expected = new GenericRowWithSchema(Array(1, 2, 3, 4), TypedExpressionEncoder[Payload].schema)
-      val expectedPushDownFilters = List(IsNotNull("a"), EqualTo("a", expected))
-
-      predicatePushDownTest[Payload](
-        expectedStructure,
-        expectedPushDownFilters,
-        { case e @ expressions.EqualTo(_, _: Lit[_]) => e },
-        _ === expectedStructure.a
-      )
-    }
-  )
test("java.sql.Timestamp push-down") {
val expected = java.sql.Timestamp.from(microsToInstant(now))
val expectedStructure = X1(SQLTimestamp(now))
val expectedPushDownFilters = List(IsNotNull("a"), GreaterThanOrEqual("a", expected))

predicatePushDownTest[SQLTimestamp](
expectedStructure,
expectedPushDownFilters,
{ case e @ expressions.GreaterThanOrEqual(_, _: Lit[_]) => e },
_ >= expectedStructure.a
)
}

test("java.time.Instant push-down") {
val expected = java.sql.Timestamp.from(microsToInstant(now))
val expectedStructure = X1(microsToInstant(now))
val expectedPushDownFilters = List(IsNotNull("a"), GreaterThanOrEqual("a", expected))

predicatePushDownTest[Instant](
expectedStructure,
expectedPushDownFilters,
{ case e @ expressions.GreaterThanOrEqual(_, _: Lit[_]) => e },
_ >= expectedStructure.a
)
}

test("struct push-down") {
type Payload = X4[Int, Int, Int, Int]
val expectedStructure = X1(X4(1, 2, 3, 4))
val expected = new GenericRowWithSchema(Array(1, 2, 3, 4), TypedExpressionEncoder[Payload].schema)
val expectedPushDownFilters = List(IsNotNull("a"), EqualTo("a", expected))

predicatePushDownTest[Payload](
expectedStructure,
expectedPushDownFilters,
{ case e @ expressions.EqualTo(_, _: Lit[_]) => e },
_ === expectedStructure.a
)
}
}
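
For reference, one hedged way to observe what these tests assert, reading pushed filters straight off the physical plan (a sketch, not the suite's actual `predicatePushDownTest` helper, whose body is outside this diff):

```scala
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.execution.FileSourceScanExec

object PushdownInspection {
  // Collect the "PushedFilters" metadata entry from every file scan; uses
  // sparkPlan (pre-AQE) so the scan is visible without executing the query.
  def pushedFilters(df: DataFrame): Seq[String] =
    df.queryExecution.sparkPlan.collect {
      case scan: FileSourceScanExec => scan.metadata.get("PushedFilters")
    }.flatten
}
```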
22 changes: 0 additions & 22 deletions dataset/src/test/scala/frameless/sql/rules/SQLRulesSuite.scala
@@ -2,7 +2,6 @@ package frameless.sql.rules

import frameless._
import frameless.sql._
-import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.expressions.Expression
import org.apache.spark.sql.sources.Filter
import org.apache.spark.sql.catalyst.plans.logical
@@ -17,27 +16,6 @@ trait SQLRulesSuite extends TypedDatasetSuite with Matchers { self =>
s"$tmpDir/${self.getClass.getName}"
}

-  lazy val sparkFullVersion = {
-    val pos = classOf[Expression].getPackage.getSpecificationVersion
-    if (pos eq null) // DBR is always null
-      SparkSession.active.version // taking a running spark version string, hence lazy
-    else
-      pos
-  }
-
-  lazy val sparkVersion = {
-    sparkFullVersion.split('.').take(2).mkString(".")
-  }
-
-  /**
-   * Don't run this test on 3.2
-   */
-  def not3_2[T](thunk: => T): Any =
-    if (sparkVersion != "3.2")
-      thunk
-    else
-      ()
-
def withDataset[A: TypedEncoder: CatalystOrdered](payload: A)(f: TypedDataset[A] => Assertion): Assertion = {
TypedDataset.create(Seq(payload)).write.mode("overwrite").parquet(path)
f(TypedDataset.createUnsafe[A](session.read.parquet(path)))