diff --git a/build.sbt b/build.sbt
index e5ae59e5..39fd5283 100644
--- a/build.sbt
+++ b/build.sbt
@@ -10,6 +10,7 @@ val shapeless = "2.3.10"
 val scalacheck = "1.17.0"
 val scalacheckEffect = "1.0.4"
 val refinedVersion = "0.10.3"
+val nakedFSVersion = "0.1.0"
 
 val Scala212 = "2.12.17"
 val Scala213 = "2.13.10"
@@ -66,6 +67,7 @@ lazy val `cats-spark32` = project
 lazy val dataset = project
   .settings(name := "frameless-dataset")
   .settings(Compile / unmanagedSourceDirectories += baseDirectory.value / "src" / "main" / "spark-3.4+")
+  .settings(Test / unmanagedSourceDirectories += baseDirectory.value / "src" / "test" / "spark-3.3+")
   .settings(datasetSettings)
   .settings(sparkDependencies(sparkVersion))
   .dependsOn(core % "test->test;compile->compile")
@@ -74,6 +76,7 @@ lazy val `dataset-spark33` = project
   .settings(name := "frameless-dataset-spark33")
   .settings(sourceDirectory := (dataset / sourceDirectory).value)
   .settings(Compile / unmanagedSourceDirectories += (dataset / baseDirectory).value / "src" / "main" / "spark-3")
+  .settings(Test / unmanagedSourceDirectories += (dataset / baseDirectory).value / "src" / "test" / "spark-3.3+")
   .settings(datasetSettings)
   .settings(sparkDependencies(spark33Version))
   .settings(spark33Settings)
@@ -83,6 +86,7 @@ lazy val `dataset-spark32` = project
   .settings(name := "frameless-dataset-spark32")
   .settings(sourceDirectory := (dataset / sourceDirectory).value)
   .settings(Compile / unmanagedSourceDirectories += (dataset / baseDirectory).value / "src" / "main" / "spark-3")
+  .settings(Test / unmanagedSourceDirectories += (dataset / baseDirectory).value / "src" / "test" / "spark-3.2")
   .settings(datasetSettings)
   .settings(sparkDependencies(spark32Version))
   .settings(spark32Settings)
@@ -192,7 +196,9 @@ lazy val datasetSettings = framelessSettings ++ framelessTypedDatasetREPL ++ Seq
       dmm("org.apache.spark.sql.FramelessInternals.column")
     )
   },
-  coverageExcludedPackages := "org.apache.spark.sql.reflection"
+  coverageExcludedPackages := "org.apache.spark.sql.reflection",
+
+  libraryDependencies += "com.globalmentor" % "hadoop-bare-naked-local-fs" % nakedFSVersion % Test exclude("org.apache.hadoop", "hadoop-commons")
 )
 
 lazy val refinedSettings = framelessSettings ++ framelessTypedDatasetREPL ++ Seq(
diff --git a/dataset/src/main/scala/frameless/functions/Lit.scala b/dataset/src/main/scala/frameless/functions/Lit.scala
index 3a816ff9..d01467b1 100644
--- a/dataset/src/main/scala/frameless/functions/Lit.scala
+++ b/dataset/src/main/scala/frameless/functions/Lit.scala
@@ -8,16 +8,17 @@ import org.apache.spark.sql.types.DataType
 private[frameless] case class Lit[T <: AnyVal](
   dataType: DataType,
   nullable: Boolean,
-  toCatalyst: CodegenContext => ExprCode,
-  show: () => String
+  show: () => String,
+  catalystExpr: Expression // must be a generated Expression from a literal TypedEncoder's toCatalyst function
 ) extends Expression with NonSQLExpression {
   override def toString: String = s"FramelessLit(${show()})"
 
-  def eval(input: InternalRow): Any = {
+  lazy val codegen = {
     val ctx = new CodegenContext()
     val eval = genCode(ctx)
 
-    val codeBody = s"""
+    val codeBody =
+      s"""
       public scala.Function1 generate(Object[] references) {
        return new LiteralEvalImpl(references);
       }
@@ -47,13 +48,16 @@ private[frameless] case class Lit[T <: AnyVal](
     val (clazz, _) = CodeGenerator.compile(code)
 
     val codegen = clazz.generate(ctx.references.toArray).asInstanceOf[InternalRow => AnyRef]
-
-    codegen(input)
+    codegen
   }
 
+  def eval(input: InternalRow): Any = codegen(input)
+
   def children: Seq[Expression] = Nil
 
-  protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = toCatalyst(ctx)
+  protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = catalystExpr.genCode(ctx)
 
   protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]): Expression = this
+
+  override val foldable: Boolean = catalystExpr.foldable
 }
diff --git a/dataset/src/main/scala/frameless/functions/package.scala b/dataset/src/main/scala/frameless/functions/package.scala
index 4e5fafc9..21b5bc99 100644
--- a/dataset/src/main/scala/frameless/functions/package.scala
+++ b/dataset/src/main/scala/frameless/functions/package.scala
@@ -45,8 +45,8 @@ package object functions extends Udf with UnaryFunctions {
       Lit(
         dataType = encoder.catalystRepr,
         nullable = encoder.nullable,
-        toCatalyst = encoder.toCatalyst(expr).genCode(_),
-        show = () => value.toString
+        show = () => value.toString,
+        catalystExpr = encoder.toCatalyst(expr)
       )
     )
   }
@@ -84,8 +84,8 @@ package object functions extends Udf with UnaryFunctions {
       Lit(
         dataType = i7.catalystRepr,
         nullable = i7.nullable,
-        toCatalyst = i7.toCatalyst(expr).genCode(_),
-        show = () => value.toString
+        show = () => value.toString,
+        i7.toCatalyst(expr)
       )
     )
   }
@@ -127,8 +127,8 @@ package object functions extends Udf with UnaryFunctions {
       Lit(
         dataType = i7.catalystRepr,
         nullable = true,
-        toCatalyst = i7.toCatalyst(expr).genCode(_),
-        show = () => value.toString
+        show = () => value.toString,
+        i7.toCatalyst(expr)
       )
     )
   }
diff --git a/dataset/src/test/scala/frameless/LitTests.scala b/dataset/src/test/scala/frameless/LitTests.scala
index 5dddd50c..50df4522 100644
--- a/dataset/src/test/scala/frameless/LitTests.scala
+++ b/dataset/src/test/scala/frameless/LitTests.scala
@@ -79,7 +79,12 @@ class LitTests extends TypedDatasetSuite with Matchers {
 
     val someIpsum: Option[Name] = Some(new Name("Ipsum"))
 
-    ds.withColumnReplaced('alias, functions.litValue(someIpsum)).
+    val lit = functions.litValue(someIpsum)
+    val tds = ds.withColumnReplaced('alias, functions.litValue(someIpsum))
+
+    tds.queryExecution.toString() should include (lit.toString)
+
+    tds.
       collect.run() shouldBe initial.map(_.copy(alias = someIpsum))
 
     ds.withColumnReplaced('alias, functions.litValue(Option.empty[Name])).
diff --git a/dataset/src/test/scala/frameless/TypedDatasetSuite.scala b/dataset/src/test/scala/frameless/TypedDatasetSuite.scala
index 36739c67..8a469783 100644
--- a/dataset/src/test/scala/frameless/TypedDatasetSuite.scala
+++ b/dataset/src/test/scala/frameless/TypedDatasetSuite.scala
@@ -1,5 +1,7 @@
 package frameless
 
+import com.globalmentor.apache.hadoop.fs.BareLocalFileSystem
+import org.apache.hadoop.fs.local.StreamingFS
 import org.apache.spark.{SparkConf, SparkContext}
 import org.apache.spark.sql.{SQLContext, SparkSession}
 import org.scalactic.anyvals.PosZInt
@@ -7,6 +9,7 @@ import org.scalatest.BeforeAndAfterAll
 import org.scalatestplus.scalacheck.Checkers
 import org.scalacheck.Prop
 import org.scalacheck.Prop._
+
 import scala.util.{Properties, Try}
 import org.scalatest.funsuite.AnyFunSuite
 
@@ -14,7 +17,18 @@ trait SparkTesting { self: BeforeAndAfterAll =>
 
   val appID: String = new java.util.Date().toString + math.floor(math.random * 10E4).toLong.toString
 
-  val conf: SparkConf = new SparkConf()
+  /**
+   * Allows the bare-naked local file system to be used instead of winutils for testing / dev
+   */
+  def registerFS(sparkConf: SparkConf): SparkConf = {
+    if (System.getProperty("os.name").startsWith("Windows"))
+      sparkConf.set("spark.hadoop.fs.file.impl", classOf[BareLocalFileSystem].getName).
+        set("spark.hadoop.fs.AbstractFileSystem.file.impl", classOf[StreamingFS].getName)
+    else
+      sparkConf
+  }
+
+  val conf: SparkConf = registerFS(new SparkConf())
     .setMaster("local[*]")
     .setAppName("test")
     .set("spark.ui.enabled", "false")
@@ -26,9 +40,15 @@ trait SparkTesting { self: BeforeAndAfterAll =>
   implicit def sc: SparkContext = session.sparkContext
   implicit def sqlContext: SQLContext = session.sqlContext
 
+  def registerOptimizations(sqlContext: SQLContext): Unit = { }
+
+  def addSparkConfigProperties(config: SparkConf): Unit = { }
+
   override def beforeAll(): Unit = {
     assert(s == null)
+    addSparkConfigProperties(conf)
     s = SparkSession.builder().config(conf).getOrCreate()
+    registerOptimizations(sqlContext)
   }
 
   override def afterAll(): Unit = {
diff --git a/dataset/src/test/scala/frameless/sql/package.scala b/dataset/src/test/scala/frameless/sql/package.scala
new file mode 100644
index 00000000..fcb45b03
--- /dev/null
+++ b/dataset/src/test/scala/frameless/sql/package.scala
@@ -0,0 +1,20 @@
+package frameless
+
+import org.apache.spark.sql.catalyst.expressions.Expression
+import org.apache.spark.sql.catalyst.expressions.{And, Or}
+
+package object sql {
+  implicit class ExpressionOps(val self: Expression) extends AnyVal {
+    def toList: List[Expression] = {
+      def rec(expr: Expression, acc: List[Expression]): List[Expression] = {
+        expr match {
+          case And(left, right) => rec(left, rec(right, acc))
+          case Or(left, right) => rec(left, rec(right, acc))
+          case e => e +: acc
+        }
+      }
+
+      rec(self, Nil)
+    }
+  }
+}
diff --git a/dataset/src/test/scala/frameless/sql/rules/SQLRulesSuite.scala b/dataset/src/test/scala/frameless/sql/rules/SQLRulesSuite.scala
new file mode 100644
index 00000000..8555d180
--- /dev/null
+++ b/dataset/src/test/scala/frameless/sql/rules/SQLRulesSuite.scala
@@ -0,0 +1,74 @@
+package frameless.sql.rules
+
+import frameless._
+import frameless.sql._
+import org.apache.spark.sql.catalyst.expressions.Expression
+import org.apache.spark.sql.sources.Filter
+import org.apache.spark.sql.catalyst.plans.logical
+import org.apache.spark.sql.execution.FileSourceScanExec
+import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanExec
+import org.scalatest.Assertion
+import org.scalatest.matchers.should.Matchers
+
+trait SQLRulesSuite extends TypedDatasetSuite with Matchers { self =>
+  protected lazy val path: String = {
+    val tmpDir = System.getProperty("java.io.tmpdir")
+    s"$tmpDir/${self.getClass.getName}"
+  }
+
+  def withDataset[A: TypedEncoder: CatalystOrdered](payload: A)(f: TypedDataset[A] => Assertion): Assertion = {
+    TypedDataset.create(Seq(payload)).write.mode("overwrite").parquet(path)
+    f(TypedDataset.createUnsafe[A](session.read.parquet(path)))
+  }
+
+  def predicatePushDownTest[A: TypedEncoder: CatalystOrdered](
+    expected: X1[A],
+    expectedPushDownFilters: List[Filter],
+    planShouldNotContain: PartialFunction[Expression, Expression],
+    op: TypedColumn[X1[A], A] => TypedColumn[X1[A], Boolean]
+  ): Assertion = {
+    withDataset(expected) { dataset =>
+      val ds = dataset.filter(op(dataset('a)))
+      val actualPushDownFilters = pushDownFilters(ds)
+
+      val optimizedPlan = ds.queryExecution.optimizedPlan.collect { case logical.Filter(condition, _) => condition }.flatMap(_.toList)
+
+      // check the optimized plan
+      optimizedPlan.collectFirst(planShouldNotContain) should be (empty)
+
+      // compare filters
+      actualPushDownFilters shouldBe expectedPushDownFilters
+
+      val actual = ds.collect().run().toVector.headOption
+
+      // ensure serialization is not broken
+      actual should be(Some(expected))
+    }
+  }
+
+  protected def pushDownFilters[T](ds: TypedDataset[T]): List[Filter] = {
+    val sparkPlan = ds.queryExecution.executedPlan
+
+    val initialPlan =
+      if (sparkPlan.children.isEmpty) // assume it's AQE
+        sparkPlan match {
+          case aq: AdaptiveSparkPlanExec => aq.initialPlan
+          case _ => sparkPlan
+        }
+      else
+        sparkPlan
+
+    initialPlan.collect {
+      case fs: FileSourceScanExec =>
+        import scala.reflect.runtime.{universe => ru}
+
+        val runtimeMirror = ru.runtimeMirror(getClass.getClassLoader)
+        val instanceMirror = runtimeMirror.reflect(fs)
+        val getter = ru.typeOf[FileSourceScanExec].member(ru.TermName("pushedDownFilters")).asTerm.getter
+        val m = instanceMirror.reflectMethod(getter.asMethod)
+        val res = m.apply(fs).asInstanceOf[Seq[Filter]]
+
+        res
+    }.flatten.toList
+  }
+}
diff --git a/dataset/src/test/scala/org/apache/hadoop/fs/local/StreamingFS.scala b/dataset/src/test/scala/org/apache/hadoop/fs/local/StreamingFS.scala
new file mode 100644
index 00000000..a28ad082
--- /dev/null
+++ b/dataset/src/test/scala/org/apache/hadoop/fs/local/StreamingFS.scala
@@ -0,0 +1,7 @@
+package org.apache.hadoop.fs.local
+
+import com.globalmentor.apache.hadoop.fs.BareLocalFileSystem
+import org.apache.hadoop.fs.DelegateToFileSystem
+
+class StreamingFS(uri: java.net.URI, conf: org.apache.hadoop.conf.Configuration) extends
+  DelegateToFileSystem(uri, new BareLocalFileSystem(), conf, "file", false) {}
diff --git a/dataset/src/test/spark-3.2/frameless/sql/rules/FramelessLitPushDownTests.scala b/dataset/src/test/spark-3.2/frameless/sql/rules/FramelessLitPushDownTests.scala
new file mode 100644
index 00000000..c44ac4d0
--- /dev/null
+++ b/dataset/src/test/spark-3.2/frameless/sql/rules/FramelessLitPushDownTests.scala
@@ -0,0 +1,85 @@
+package frameless.sql.rules
+
+import frameless._
+import frameless.sql._
+import frameless.functions.Lit
+import org.apache.spark.sql.catalyst.util.DateTimeUtils.{currentTimestamp, microsToInstant}
+import org.apache.spark.sql.sources.{Filter, IsNotNull}
+import org.apache.spark.sql.catalyst.expressions
+import org.apache.spark.sql.catalyst.expressions.{Cast, Expression, GenericRowWithSchema}
+import java.time.Instant
+
+import org.apache.spark.sql.catalyst.plans.logical
+import org.scalatest.Assertion
+
+// Note: as InvokeLike and "ConditionalExpression" don't have SPARK-40380 and SPARK-39106, no predicate push-downs can happen in 3.2.4
+class FramelessLitPushDownTests extends SQLRulesSuite {
+  private val now: Long = currentTimestamp()
+
+  test("java.sql.Timestamp push-down") {
+    val expected = java.sql.Timestamp.from(microsToInstant(now))
+    val expectedStructure = X1(SQLTimestamp(now))
+    val expectedPushDownFilters = List(IsNotNull("a"))
+
+    predicatePushDownTest[SQLTimestamp](
+      expectedStructure,
+      expectedPushDownFilters,
+      { case e @ expressions.GreaterThanOrEqual(_, _: Lit[_]) => e },
+      _ >= expectedStructure.a
+    )
+  }
+
+  test("java.time.Instant push-down") {
+    val expected = java.sql.Timestamp.from(microsToInstant(now))
+    val expectedStructure = X1(microsToInstant(now))
+    val expectedPushDownFilters = List(IsNotNull("a"))
+
+    predicatePushDownTest[Instant](
+      expectedStructure,
+      expectedPushDownFilters,
+      { case e @ expressions.GreaterThanOrEqual(_, _: Lit[_]) => e },
+      _ >= expectedStructure.a
+    )
+  }
+
+  test("struct push-down") {
+    type Payload = X4[Int, Int, Int, Int]
+    val expectedStructure = X1(X4(1, 2, 3, 4))
+    val expected = new GenericRowWithSchema(Array(1, 2, 3, 4), TypedExpressionEncoder[Payload].schema)
+    val expectedPushDownFilters = List(IsNotNull("a"))
+
+    predicatePushDownTest[Payload](
+      expectedStructure,
+      expectedPushDownFilters,
+      // Cast not Lit because of SPARK-40380
+      { case e @ expressions.EqualTo(_, _: Cast) => e },
+      _ === expectedStructure.a
+    )
+  }
+
+  override def predicatePushDownTest[A: TypedEncoder: CatalystOrdered](
+    expected: X1[A],
+    expectedPushDownFilters: List[Filter],
+    planShouldContain: PartialFunction[Expression, Expression],
+    op: TypedColumn[X1[A], A] => TypedColumn[X1[A], Boolean]
+  ): Assertion = {
+    withDataset(expected) { dataset =>
+      val ds = dataset.filter(op(dataset('a)))
+      val actualPushDownFilters = pushDownFilters(ds)
+
+      val optimizedPlan = ds.queryExecution.optimizedPlan.collect { case logical.Filter(condition, _) => condition }.flatMap(_.toList)
+
+      // check the optimized plan
+      optimizedPlan.collectFirst(planShouldContain) should not be (empty)
+
+      // compare filters
+      actualPushDownFilters shouldBe expectedPushDownFilters
+
+      val actual = ds.collect().run().toVector.headOption
+
+      // ensure serialization is not broken
+      actual should be(Some(expected))
+    }
+  }
+
+}
diff --git a/dataset/src/test/spark-3.3+/frameless/sql/rules/FramelessLitPushDownTests.scala b/dataset/src/test/spark-3.3+/frameless/sql/rules/FramelessLitPushDownTests.scala
new file mode 100644
index 00000000..36a443fb
--- /dev/null
+++ b/dataset/src/test/spark-3.3+/frameless/sql/rules/FramelessLitPushDownTests.scala
@@ -0,0 +1,53 @@
+package frameless.sql.rules
+
+import frameless._
+import frameless.functions.Lit
+import org.apache.spark.sql.catalyst.util.DateTimeUtils.{currentTimestamp, microsToInstant}
+import org.apache.spark.sql.sources.{EqualTo, GreaterThanOrEqual, IsNotNull}
+import org.apache.spark.sql.catalyst.expressions
+import org.apache.spark.sql.catalyst.expressions.GenericRowWithSchema
+import java.time.Instant
+
+class FramelessLitPushDownTests extends SQLRulesSuite {
+  private val now: Long = currentTimestamp()
+
+  test("java.sql.Timestamp push-down") {
+    val expected = java.sql.Timestamp.from(microsToInstant(now))
+    val expectedStructure = X1(SQLTimestamp(now))
+    val expectedPushDownFilters = List(IsNotNull("a"), GreaterThanOrEqual("a", expected))
+
+    predicatePushDownTest[SQLTimestamp](
+      expectedStructure,
+      expectedPushDownFilters,
+      { case e @ expressions.GreaterThanOrEqual(_, _: Lit[_]) => e },
+      _ >= expectedStructure.a
+    )
+  }
+
+  test("java.time.Instant push-down") {
+    val expected = java.sql.Timestamp.from(microsToInstant(now))
+    val expectedStructure = X1(microsToInstant(now))
+    val expectedPushDownFilters = List(IsNotNull("a"), GreaterThanOrEqual("a", expected))
+
+    predicatePushDownTest[Instant](
+      expectedStructure,
+      expectedPushDownFilters,
+      { case e @ expressions.GreaterThanOrEqual(_, _: Lit[_]) => e },
+      _ >= expectedStructure.a
+    )
+  }
+
+  test("struct push-down") {
+    type Payload = X4[Int, Int, Int, Int]
+    val expectedStructure = X1(X4(1, 2, 3, 4))
+    val expected = new GenericRowWithSchema(Array(1, 2, 3, 4), TypedExpressionEncoder[Payload].schema)
+    val expectedPushDownFilters = List(IsNotNull("a"), EqualTo("a", expected))
+
+    predicatePushDownTest[Payload](
+      expectedStructure,
+      expectedPushDownFilters,
+      { case e @ expressions.EqualTo(_, _: Lit[_]) => e },
+      _ === expectedStructure.a
+    )
+  }
+}