Make Literals foldable, ensure Parquet predicates pushdown (#721)
* #343 - unpack to Literals

* #343 - add struct test showing difference between extension and experimental rules

* #343 - toString test to stop the patch complaint

* #343 - sample docs

* #343 - package rename and adding logging that the extension is injected

* Apply suggestions from code review

Co-authored-by: Cédric Chantepie <cchantep@users.noreply.github.com>

* Refactor LitRule and LitRules tests by making them slightly more generic, adjust docs, add negative tests

* #343 - disable the rule, foldable and eval evals

* #343 - cleaned up

* #343 - true with link for 3.2 support

* #343 - bring back code gen with lazy to stop recompiles

* #343 - more compat and a foldable only backport of SPARK-39106 and SPARK-40380

* #343 - option 3 - let 3.2 fail as per OSS impl, separate tests

---------

Co-authored-by: Cédric Chantepie <cchantep@users.noreply.github.com>
Co-authored-by: Grigory Pomadchin <grigory.pomadchin@disneystreaming.com>
3 people authored Jun 10, 2023
1 parent e257c4c commit dec676b
Showing 10 changed files with 290 additions and 16 deletions.
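
For orientation, a minimal sketch (not part of the diff) of the behaviour this commit targets, written against the frameless API exercised by the new tests; the Event record and eventsSince helper are illustrative assumptions only:

import java.time.Instant
import frameless._

// Illustrative record type, not taken from the commit.
case class Event(a: Instant)

// The plain value `since` is lifted into frameless' Lit expression by the
// implicit TypedEncoder; because Lit now reports foldable = catalystExpr.foldable,
// Catalyst can treat it as a constant and push `a >= since` down to the
// Parquet scan as a data-source filter (on Spark 3.3+; 3.2 lacks the needed fixes).
def eventsSince(ds: TypedDataset[Event], since: Instant): TypedDataset[Event] =
  ds.filter(ds('a) >= since)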
8 changes: 7 additions & 1 deletion build.sbt
@@ -10,6 +10,7 @@ val shapeless = "2.3.10"
val scalacheck = "1.17.0"
val scalacheckEffect = "1.0.4"
val refinedVersion = "0.10.3"
val nakedFSVersion = "0.1.0"

val Scala212 = "2.12.17"
val Scala213 = "2.13.10"
@@ -66,6 +67,7 @@ lazy val `cats-spark32` = project
lazy val dataset = project
.settings(name := "frameless-dataset")
.settings(Compile / unmanagedSourceDirectories += baseDirectory.value / "src" / "main" / "spark-3.4+")
.settings(Test / unmanagedSourceDirectories += baseDirectory.value / "src" / "test" / "spark-3.3+")
.settings(datasetSettings)
.settings(sparkDependencies(sparkVersion))
.dependsOn(core % "test->test;compile->compile")
@@ -74,6 +76,7 @@ lazy val `dataset-spark33` = project
.settings(name := "frameless-dataset-spark33")
.settings(sourceDirectory := (dataset / sourceDirectory).value)
.settings(Compile / unmanagedSourceDirectories += (dataset / baseDirectory).value / "src" / "main" / "spark-3")
.settings(Test / unmanagedSourceDirectories += (dataset / baseDirectory).value / "src" / "test" / "spark-3.3+")
.settings(datasetSettings)
.settings(sparkDependencies(spark33Version))
.settings(spark33Settings)
@@ -83,6 +86,7 @@ lazy val `dataset-spark32` = project
.settings(name := "frameless-dataset-spark32")
.settings(sourceDirectory := (dataset / sourceDirectory).value)
.settings(Compile / unmanagedSourceDirectories += (dataset / baseDirectory).value / "src" / "main" / "spark-3")
.settings(Test / unmanagedSourceDirectories += (dataset / baseDirectory).value / "src" / "test" / "spark-3.2")
.settings(datasetSettings)
.settings(sparkDependencies(spark32Version))
.settings(spark32Settings)
@@ -192,7 +196,9 @@ lazy val datasetSettings = framelessSettings ++ framelessTypedDatasetREPL ++ Seq(
dmm("org.apache.spark.sql.FramelessInternals.column")
)
},
coverageExcludedPackages := "org.apache.spark.sql.reflection"
coverageExcludedPackages := "org.apache.spark.sql.reflection",

libraryDependencies += "com.globalmentor" % "hadoop-bare-naked-local-fs" % nakedFSVersion % Test exclude("org.apache.hadoop", "hadoop-commons")
)

lazy val refinedSettings = framelessSettings ++ framelessTypedDatasetREPL ++ Seq(
18 changes: 11 additions & 7 deletions dataset/src/main/scala/frameless/functions/Lit.scala
@@ -8,16 +8,17 @@ import org.apache.spark.sql.types.DataType
private[frameless] case class Lit[T <: AnyVal](
dataType: DataType,
nullable: Boolean,
toCatalyst: CodegenContext => ExprCode,
show: () => String
show: () => String,
catalystExpr: Expression // must be a generated Expression from a literal TypedEncoder's toCatalyst function
) extends Expression with NonSQLExpression {
override def toString: String = s"FramelessLit(${show()})"

def eval(input: InternalRow): Any = {
lazy val codegen = {
val ctx = new CodegenContext()
val eval = genCode(ctx)

val codeBody = s"""
val codeBody =
s"""
public scala.Function1<InternalRow, Object> generate(Object[] references) {
return new LiteralEvalImpl(references);
}
@@ -47,13 +48,16 @@ private[frameless] case class Lit[T <: AnyVal](
val (clazz, _) = CodeGenerator.compile(code)
val codegen =
clazz.generate(ctx.references.toArray).asInstanceOf[InternalRow => AnyRef]

codegen(input)
codegen
}

def eval(input: InternalRow): Any = codegen(input)

def children: Seq[Expression] = Nil

protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = toCatalyst(ctx)
protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = catalystExpr.genCode(ctx)

protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]): Expression = this

override val foldable: Boolean = catalystExpr.foldable
}
12 changes: 6 additions & 6 deletions dataset/src/main/scala/frameless/functions/package.scala
@@ -45,8 +45,8 @@ package object functions extends Udf with UnaryFunctions {
Lit(
dataType = encoder.catalystRepr,
nullable = encoder.nullable,
toCatalyst = encoder.toCatalyst(expr).genCode(_),
show = () => value.toString
show = () => value.toString,
catalystExpr = encoder.toCatalyst(expr)
)
)
}
@@ -84,8 +84,8 @@ package object functions extends Udf with UnaryFunctions {
Lit(
dataType = i7.catalystRepr,
nullable = i7.nullable,
toCatalyst = i7.toCatalyst(expr).genCode(_),
show = () => value.toString
show = () => value.toString,
i7.toCatalyst(expr)
)
)
}
@@ -127,8 +127,8 @@ package object functions extends Udf with UnaryFunctions {
Lit(
dataType = i7.catalystRepr,
nullable = true,
toCatalyst = i7.toCatalyst(expr).genCode(_),
show = () => value.toString
show = () => value.toString,
i7.toCatalyst(expr)
)
)
}
7 changes: 6 additions & 1 deletion dataset/src/test/scala/frameless/LitTests.scala
@@ -79,7 +79,12 @@ class LitTests extends TypedDatasetSuite with Matchers {

val someIpsum: Option[Name] = Some(new Name("Ipsum"))

ds.withColumnReplaced('alias, functions.litValue(someIpsum)).
val lit = functions.litValue(someIpsum)
val tds = ds.withColumnReplaced('alias, functions.litValue(someIpsum))

tds.queryExecution.toString() should include (lit.toString)

tds.
collect.run() shouldBe initial.map(_.copy(alias = someIpsum))

ds.withColumnReplaced('alias, functions.litValue(Option.empty[Name])).
22 changes: 21 additions & 1 deletion dataset/src/test/scala/frameless/TypedDatasetSuite.scala
@@ -1,20 +1,34 @@
package frameless

import com.globalmentor.apache.hadoop.fs.BareLocalFileSystem
import org.apache.hadoop.fs.local.StreamingFS
import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.sql.{SQLContext, SparkSession}
import org.scalactic.anyvals.PosZInt
import org.scalatest.BeforeAndAfterAll
import org.scalatestplus.scalacheck.Checkers
import org.scalacheck.Prop
import org.scalacheck.Prop._

import scala.util.{Properties, Try}
import org.scalatest.funsuite.AnyFunSuite

trait SparkTesting { self: BeforeAndAfterAll =>

val appID: String = new java.util.Date().toString + math.floor(math.random * 10E4).toLong.toString

val conf: SparkConf = new SparkConf()
/**
* Allows the bare-naked local filesystem to be used instead of winutils for testing / dev on Windows
*/
def registerFS(sparkConf: SparkConf): SparkConf = {
if (System.getProperty("os.name").startsWith("Windows"))
sparkConf.set("spark.hadoop.fs.file.impl", classOf[BareLocalFileSystem].getName).
set("spark.hadoop.fs.AbstractFileSystem.file.impl", classOf[StreamingFS].getName)
else
sparkConf
}

val conf: SparkConf = registerFS(new SparkConf())
.setMaster("local[*]")
.setAppName("test")
.set("spark.ui.enabled", "false")
@@ -26,9 +40,15 @@ trait SparkTesting { self: BeforeAndAfterAll =>
implicit def sc: SparkContext = session.sparkContext
implicit def sqlContext: SQLContext = session.sqlContext

def registerOptimizations(sqlContext: SQLContext): Unit = { }

def addSparkConfigProperties(config: SparkConf): Unit = { }

override def beforeAll(): Unit = {
assert(s == null)
addSparkConfigProperties(conf)
s = SparkSession.builder().config(conf).getOrCreate()
registerOptimizations(sqlContext)
}

override def afterAll(): Unit = {
20 changes: 20 additions & 0 deletions dataset/src/test/scala/frameless/sql/package.scala
@@ -0,0 +1,20 @@
package frameless

import org.apache.spark.sql.catalyst.expressions.Expression
import org.apache.spark.sql.catalyst.expressions.{And, Or}

package object sql {
implicit class ExpressionOps(val self: Expression) extends AnyVal {
def toList: List[Expression] = {
def rec(expr: Expression, acc: List[Expression]): List[Expression] = {
expr match {
case And(left, right) => rec(left, rec(right, acc))
case Or(left, right) => rec(left, rec(right, acc))
case e => e +: acc
}
}

rec(self, Nil)
}
}
}
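
As a hedged usage note (not part of the commit), the ExpressionOps.toList helper above flattens nested And/Or trees into their leaf predicates, left to right; the ToListExample object below is illustrative only:

import org.apache.spark.sql.catalyst.expressions.{And, EqualTo, Literal, Or}
import frameless.sql._

object ToListExample {
  // Flattens the And/Or tree into its leaves, left to right:
  // List(EqualTo(1, 2), Literal(true), Literal(false))
  val leaves = And(EqualTo(Literal(1), Literal(2)), Or(Literal(true), Literal(false))).toList
}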
74 changes: 74 additions & 0 deletions dataset/src/test/scala/frameless/sql/rules/SQLRulesSuite.scala
@@ -0,0 +1,74 @@
package frameless.sql.rules

import frameless._
import frameless.sql._
import org.apache.spark.sql.catalyst.expressions.Expression
import org.apache.spark.sql.sources.Filter
import org.apache.spark.sql.catalyst.plans.logical
import org.apache.spark.sql.execution.FileSourceScanExec
import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanExec
import org.scalatest.Assertion
import org.scalatest.matchers.should.Matchers

trait SQLRulesSuite extends TypedDatasetSuite with Matchers { self =>
protected lazy val path: String = {
val tmpDir = System.getProperty("java.io.tmpdir")
s"$tmpDir/${self.getClass.getName}"
}

def withDataset[A: TypedEncoder: CatalystOrdered](payload: A)(f: TypedDataset[A] => Assertion): Assertion = {
TypedDataset.create(Seq(payload)).write.mode("overwrite").parquet(path)
f(TypedDataset.createUnsafe[A](session.read.parquet(path)))
}

def predicatePushDownTest[A: TypedEncoder: CatalystOrdered](
expected: X1[A],
expectedPushDownFilters: List[Filter],
planShouldNotContain: PartialFunction[Expression, Expression],
op: TypedColumn[X1[A], A] => TypedColumn[X1[A], Boolean]
): Assertion = {
withDataset(expected) { dataset =>
val ds = dataset.filter(op(dataset('a)))
val actualPushDownFilters = pushDownFilters(ds)

val optimizedPlan = ds.queryExecution.optimizedPlan.collect { case logical.Filter(condition, _) => condition }.flatMap(_.toList)

// check the optimized plan
optimizedPlan.collectFirst(planShouldNotContain) should be (empty)

// compare filters
actualPushDownFilters shouldBe expectedPushDownFilters

val actual = ds.collect().run().toVector.headOption

// ensure serialization is not broken
actual should be(Some(expected))
}
}

protected def pushDownFilters[T](ds: TypedDataset[T]): List[Filter] = {
val sparkPlan = ds.queryExecution.executedPlan

val initialPlan =
if (sparkPlan.children.isEmpty) // assume it's AQE
sparkPlan match {
case aq: AdaptiveSparkPlanExec => aq.initialPlan
case _ => sparkPlan
}
else
sparkPlan

initialPlan.collect {
case fs: FileSourceScanExec =>
import scala.reflect.runtime.{universe => ru}

val runtimeMirror = ru.runtimeMirror(getClass.getClassLoader)
val instanceMirror = runtimeMirror.reflect(fs)
val getter = ru.typeOf[FileSourceScanExec].member(ru.TermName("pushedDownFilters")).asTerm.getter
val m = instanceMirror.reflectMethod(getter.asMethod)
val res = m.apply(fs).asInstanceOf[Seq[Filter]]

res
}.flatten.toList
}
}
@@ -0,0 +1,7 @@
package org.apache.hadoop.fs.local

import com.globalmentor.apache.hadoop.fs.BareLocalFileSystem
import org.apache.hadoop.fs.DelegateToFileSystem

class StreamingFS(uri: java.net.URI, conf: org.apache.hadoop.conf.Configuration) extends
DelegateToFileSystem(uri, new BareLocalFileSystem(), conf, "file", false) {}
@@ -0,0 +1,85 @@
package frameless.sql.rules

import frameless._
import frameless.sql._
import frameless.functions.Lit
import org.apache.spark.sql.catalyst.util.DateTimeUtils.{currentTimestamp, microsToInstant}
import org.apache.spark.sql.sources.{Filter, IsNotNull}
import org.apache.spark.sql.catalyst.expressions
import org.apache.spark.sql.catalyst.expressions.{Cast, Expression, GenericRowWithSchema}
import java.time.Instant

import org.apache.spark.sql.catalyst.plans.logical
import org.scalatest.Assertion

// Note: since InvokeLike and ConditionalExpression lack the SPARK-40380 and SPARK-39106 fixes, no predicate pushdown can happen on 3.2.4
class FramelessLitPushDownTests extends SQLRulesSuite {
private val now: Long = currentTimestamp()

test("java.sql.Timestamp push-down") {
val expected = java.sql.Timestamp.from(microsToInstant(now))
val expectedStructure = X1(SQLTimestamp(now))
val expectedPushDownFilters = List(IsNotNull("a"))

predicatePushDownTest[SQLTimestamp](
expectedStructure,
expectedPushDownFilters,
{ case e @ expressions.GreaterThanOrEqual(_, _: Lit[_]) => e },
_ >= expectedStructure.a
)
}

test("java.time.Instant push-down") {
val expected = java.sql.Timestamp.from(microsToInstant(now))
val expectedStructure = X1(microsToInstant(now))
val expectedPushDownFilters = List(IsNotNull("a"))

predicatePushDownTest[Instant](
expectedStructure,
expectedPushDownFilters,
{ case e @ expressions.GreaterThanOrEqual(_, _: Lit[_]) => e },
_ >= expectedStructure.a
)
}

test("struct push-down") {
type Payload = X4[Int, Int, Int, Int]
val expectedStructure = X1(X4(1, 2, 3, 4))
val expected = new GenericRowWithSchema(Array(1, 2, 3, 4), TypedExpressionEncoder[Payload].schema)
val expectedPushDownFilters = List(IsNotNull("a"))

predicatePushDownTest[Payload](
expectedStructure,
expectedPushDownFilters,
// Cast not Lit because of SPARK-40380
{ case e @ expressions.EqualTo(_, _: Cast) => e },
_ === expectedStructure.a
)
}

override def predicatePushDownTest[A: TypedEncoder: CatalystOrdered](
expected: X1[A],
expectedPushDownFilters: List[Filter],
planShouldContain: PartialFunction[Expression, Expression],
op: TypedColumn[X1[A], A] => TypedColumn[X1[A], Boolean]
): Assertion = {
withDataset(expected) { dataset =>
val ds = dataset.filter(op(dataset('a)))
val actualPushDownFilters = pushDownFilters(ds)

val optimizedPlan = ds.queryExecution.optimizedPlan.collect { case logical.Filter(condition, _) => condition }.flatMap(_.toList)

// check the optimized plan
optimizedPlan.collectFirst(planShouldContain) should not be (empty)

// compare filters
actualPushDownFilters shouldBe expectedPushDownFilters

val actual = ds.collect().run().toVector.headOption

// ensure serialization is not broken
actual should be(Some(expected))
}
}

}