Feature/scala 2.13 for spark 3.3.x and 3.2.x #238

Merged · 6 commits · Jul 13, 2022
4 changes: 2 additions & 2 deletions .github/actions/check_build_and_doc/action.yml
@@ -30,14 +30,14 @@ runs:
# TESTING & COVERAGE
- name: Test & coverage 📋
if: contains(env.SCOPE, 'test')
run: sbt -DsparkVersion=${{ env.SPARK_VERSION }} coverage core/test coverageReport coverageAggregate
run: sbt -DsparkVersion=${{ env.SPARK_VERSION }} coverage +core/test coverageReport coverageAggregate
shell: bash

- name: Publish coverage to codecov 📊
if: contains(env.SCOPE, 'uploadReport')
uses: codecov/codecov-action@v3
with:
files: ./target/scala-2.11/scoverage-report/scoverage.xml,./target/scala-2.12/scoverage-report/scoverage.xml,./target/scala-2.13/scoverage-report/scoverage.xml
files: ./target/scala-2.11/scoverage-report/scoverage.xml,./target/scala-2.12/scoverage-report/scoverage.xml
fail_ci_if_error: true
verbose: false
flags: 'spark-${{ env.SPARK_VERSION }}.x'
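Note: the added `+` prefix in `+core/test` tells sbt to run the task once for every Scala version listed in `crossScalaVersions`, so the same CI job now exercises both the 2.12 and 2.13 builds for Spark 3.2.x/3.3.x. A minimal sketch of the wiring this assumes in build.sbt (the `crossScalaVersions` assignment itself is not part of this diff):

// Assumed setting: cross-building driven by the requested Spark version, so
// `sbt -DsparkVersion=3.3.0 "+core/test"` compiles and tests once per Scala version.
core / crossScalaVersions := scalaVersionSelect(sparkVersion.value)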
66 changes: 44 additions & 22 deletions build.sbt
@@ -35,8 +35,8 @@ val scalaVersionSelect: String => List[String] = {
case versionRegex("2", _, _) => List(scala211)
case versionRegex("3", "0", _) => List(scala212)
case versionRegex("3", "1", _) => List(scala212)
case versionRegex("3", "2", _) => List(scala212)
case versionRegex("3", "3", _) => List(scala212)
case versionRegex("3", "2", _) => List(scala212, scala213)
case versionRegex("3", "3", _) => List(scala212, scala213)

}
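Note: Spark 3.2.x and 3.3.x now resolve to two Scala versions instead of one, while the older Spark lines keep a single version. Roughly (the concrete patch versions come from the scala211/scala212/scala213 vals in build.sbt):

scalaVersionSelect("2.4.8") // List(scala211)
scalaVersionSelect("3.1.3") // List(scala212)
scalaVersionSelect("3.3.0") // List(scala212, scala213)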

@@ -102,6 +102,18 @@ val configSpark = Seq(
)
)

val scalaOptionsCommon = Seq(
"-encoding",
"utf8", // Option and arguments on same line
"-Xfatal-warnings", // New lines for each options
"-deprecation",
"-unchecked",
"-language:implicitConversions",
"-language:higherKinds",
"-language:existentials",
"-language:postfixOps",
"-Ywarn-numeric-widen"
)
lazy val core = project
.in(file("core"))
.settings(
@@ -129,15 +141,31 @@ lazy val core = project
"org.apache.spark"
),
Compile / unmanagedSourceDirectories ++= {
sparkVersion.value match {
(sparkVersion.value match {
case versionRegex(mayor, minor, _) =>
(Compile / sourceDirectory).value ** s"spark_*$mayor.$minor*" / "scala" get
}
(Compile / sourceDirectory).value ** s"*spark_*$mayor.$minor*" / "scala" get
}) ++
(scalaVersion.value match {
case versionRegex(mayor, minor, _) =>
(Compile / sourceDirectory).value ** s"*scala_*$mayor.$minor*" / "scala" get
})
},
Test / unmanagedSourceDirectories ++= {
sparkVersion.value match {
(sparkVersion.value match {
case versionRegex(mayor, minor, _) =>
(Test / sourceDirectory).value ** s"spark_*$mayor.$minor*" / "scala" get
(Test / sourceDirectory).value ** s"*spark_*$mayor.$minor*" / "scala" get
}) ++
(scalaVersion.value match {
case versionRegex(mayor, minor, _) =>
(Test / sourceDirectory).value ** s"*scala_*$mayor.$minor*" / "scala" get
})
},
scalacOptions ++= {
scalaOptionsCommon ++ {
if (scalaVersion.value.startsWith("2.13"))
Seq.empty
else
Seq("-Ypartial-unification")
}
}
)
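Note: two things change in the core project. Version-specific sources are now selected by directory name in both dimensions: alongside the existing `spark_<major>.<minor>` folders, any folder matching `scala_<major>.<minor>` for the current Scala version is added to the compile and test source paths (this is how the two `DoricArray.scala` variants further down are picked up). And `-Ypartial-unification` is only passed to the 2.11/2.12 compilers, because the flag no longer exists in Scala 2.13, where that behaviour is always on. A small illustration of what the flag enables, not taken from the PR:

// Unifying `Int => Int` against `F[A]` as F = Int => *, A = Int.
trait Functor[F[_]] { def map[A, B](fa: F[A])(f: A => B): F[B] }

object Functor {
  implicit def function1Functor[X]: Functor[({ type L[A] = X => A })#L] =
    new Functor[({ type L[A] = X => A })#L] {
      def map[A, B](fa: X => A)(f: A => B): X => B = fa andThen f
    }
}

def lift[F[_], A, B](fa: F[A])(f: A => B)(implicit F: Functor[F]): F[B] = F.map(fa)(f)

// Requires -Ypartial-unification on Scala 2.11/2.12; compiles out of the box on 2.13.
val inc: Int => Int     = _ + 1
val show: Int => String = lift(inc)(_.toString)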
@@ -170,24 +198,18 @@ lazy val docs = project
),
mdocExtraArguments := Seq(
"--clean-target"
)
),
scalacOptions ++= {
scalaOptionsCommon ++ {
if (scalaVersion.value.startsWith("2.13"))
Seq.empty
else
Seq("-Ypartial-unification")
}
}
)
.enablePlugins(plugins: _*)

Global / scalacOptions ++= Seq(
"-encoding",
"utf8", // Option and arguments on same line
"-Xfatal-warnings", // New lines for each options
"-deprecation",
"-unchecked",
"-language:implicitConversions",
"-language:higherKinds",
"-language:existentials",
"-language:postfixOps",
"-Ypartial-unification",
"-Ywarn-numeric-widen"
)

// Scoverage settings
Global / coverageEnabled := false
Global / coverageFailOnMinimum := false
6 changes: 3 additions & 3 deletions core/src/main/scala/doric/sem/TransformOps.scala
@@ -43,7 +43,7 @@ private[sem] trait TransformOps {
def withColumns(
namesAndCols: (String, DoricColumn[_])*
): DataFrame = {
if (namesAndCols.isEmpty) df.toDF
if (namesAndCols.isEmpty) df.toDF()
else
namesAndCols.toList
.traverse(_._2.elem)
@@ -64,7 +64,7 @@
def withColumns(
namesAndCols: Map[String, DoricColumn[_]]
): DataFrame = {
if (namesAndCols.isEmpty) df.toDF
if (namesAndCols.isEmpty) df.toDF()
else
withColumns(namesAndCols.toList: _*)
}
@@ -79,7 +79,7 @@
def withNamedColumns(
namedColumns: NamedDoricColumn[_]*
): DataFrame = {
if (namedColumns.isEmpty) df.toDF
if (namedColumns.isEmpty) df.toDF()
else
withColumns(
namedColumns.iterator.map(x => (x.columnName, x)).toList: _*
2 changes: 1 addition & 1 deletion core/src/main/scala/doric/syntax/NumericColumns.scala
@@ -71,7 +71,7 @@ private[syntax] trait NumericColumns {
* @group Numeric Type
* @see [[org.apache.spark.sql.functions.spark_partition_id]]
*/
def sparkPartitionId(): IntegerColumn = DoricColumn(f.spark_partition_id)
def sparkPartitionId(): IntegerColumn = DoricColumn(f.spark_partition_id())

/**
* A column expression that generates monotonically increasing 64-bit integers.
2 changes: 1 addition & 1 deletion core/src/main/scala/doric/syntax/StringColumns.scala
@@ -82,7 +82,7 @@ private[syntax] trait StringColumns {
* @group String Type
* @see [[org.apache.spark.sql.functions.input_file_name]]
*/
def inputFileName(): StringColumn = DoricColumn(f.input_file_name)
def inputFileName(): StringColumn = DoricColumn(f.input_file_name())

/**
* Creates a string column for the file name of the current Spark task.
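Note: the `toDF()`, `spark_partition_id()`, and `input_file_name()` changes in this and the surrounding files are all the same fix: Scala 2.13 deprecates "auto-application" (calling a method declared with an empty parameter list without writing the parentheses), and with `-Xfatal-warnings` enabled that deprecation fails the build. A trivial sketch of the behaviour, not from the PR:

def answer(): Int = 42

val explicit = answer() // fine on 2.12 and 2.13
val auto     = answer   // deprecated on 2.13 -> a compile error under -Xfatal-warnings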
4 changes: 3 additions & 1 deletion core/src/main/scala/doric/types/LiteralSparkType.scala
@@ -112,7 +112,9 @@ object LiteralSparkType {
override type OriginalSparkType = Array[lst.OriginalSparkType]

override val literalTo: Array[A] => OriginalSparkType = {
_.map(lst.literalTo).toArray(lst.classTag)
implicit val a = lst.classTag
_.map(lst.literalTo)
.toArray(lst.classTag)
}

override val classTag: ClassTag[Array[lst.OriginalSparkType]] =
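Note: the extra `implicit val` is needed because on Scala 2.13 `Array#map` itself requires an implicit `ClassTag` for the result element type (on 2.12 a lower-priority fallback builds an intermediate sequence instead). Bringing `lst.classTag` into implicit scope satisfies the map call, and the explicit `.toArray(lst.classTag)` stays for the final conversion. A reduced sketch of the pattern with illustrative names:

import scala.reflect.ClassTag

// Mapping over an Array generically: on 2.13 `values.map(f)` will not compile
// unless a ClassTag[B] is available implicitly, because ArrayOps.map needs it
// to allocate the resulting Array[B].
def convert[A, B](values: Array[A])(f: A => B)(implicit ct: ClassTag[B]): Array[B] =
  values.map(f)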
5 changes: 2 additions & 3 deletions core/src/main/scala/doric/types/SparkType.scala
@@ -2,7 +2,6 @@ package doric
package types

import scala.annotation.implicitNotFound
import scala.collection.mutable
import scala.reflect.ClassTag

import cats.data.{Kleisli, Validated}
@@ -222,7 +221,7 @@ object SparkType {
.toArray

override val rowFieldTransform: Any => Array[O] =
_.asInstanceOf[mutable.WrappedArray[O]].toArray
_.asInstanceOf[DoricArray.Collection[O]].toArray

}

@@ -245,7 +244,7 @@
_.map(st.transform)

override val rowFieldTransform: Any => OriginalSparkType =
_.asInstanceOf[mutable.WrappedArray[st.OriginalSparkType]]
_.asInstanceOf[DoricArray.Collection[st.OriginalSparkType]]
.map(st.rowFieldTransform)
.toList
}
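Note: the hard-coded cast to `mutable.WrappedArray` is what broke here: Scala 2.13 replaced `WrappedArray` with `mutable.ArraySeq`, and the wrapper Spark hands back for array columns differs accordingly between the two Scala versions. The cast now goes through the `DoricArray.Collection` alias defined in the version-specific files below. A simplified sketch of the idea (not the actual doric signature):

import doric.types.DoricArray

// Row.get returns the array column as the version-appropriate wrapper;
// the alias lets the same cast compile against 2.11/2.12 and 2.13.
def rowArrayToList[T](raw: Any): List[T] =
  raw.asInstanceOf[DoricArray.Collection[T]].toList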
7 changes: 7 additions & 0 deletions …/doric/types/DoricArray.scala (Scala 2.11/2.12 source set)
@@ -0,0 +1,7 @@
package doric.types

import scala.collection.mutable

object DoricArray {
type Collection[T] = mutable.WrappedArray[T]
}
7 changes: 7 additions & 0 deletions core/src/main/scala_2.13/scala/doric/types/DoricArray.scala
@@ -0,0 +1,7 @@
package doric.types

import scala.collection.mutable

object DoricArray {
type Collection[T] = mutable.ArraySeq[T]
}
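Note: only one of these two files ends up on the compile path, selected by the `scala_*<version>*` glob added to `unmanagedSourceDirectories` in build.sbt, so the rest of the codebase can depend on `DoricArray.Collection` without knowing which concrete class it resolves to. A small cross-version usage sketch (assumed, not part of the PR):

import doric.types.DoricArray

// Compiles against either alias target: WrappedArray (2.11/2.12) and ArraySeq (2.13)
// are both mutable.IndexedSeq wrappers over an underlying Array.
def firstOrZero(xs: DoricArray.Collection[Int]): Int =
  if (xs.isEmpty) 0 else xs.head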
11 changes: 11 additions & 0 deletions core/src/test/scala/doric/sem/TransformOpsSpec.scala
@@ -125,5 +125,16 @@
error.getMessage should include("`a`")
error.getMessage should include("`b`")
}

it("should work with 'withNamedColumns' as with 'namedColumns'") {
val df = spark
.range(2)
.withNamedColumns(1.lit.as("hi"), col[Long]("id") + 10 as "id")
df.columns shouldBe Array("id", "hi")
df.collectCols(col[Long]("id"), col[Int]("hi")) shouldBe List(
(10, 1),
(11, 1)
)
}
}
}
7 changes: 4 additions & 3 deletions core/src/test/scala/doric/syntax/DynamicSpec.scala
@@ -21,15 +21,16 @@ class DynamicSpec extends DoricTestElements with EitherValues with Matchers {
}

it("can get values from sub-sub-columns") {
List(((("1", 2.0), 2), true)).toDF
List(((("1", 2.0), 2), true))
.toDF()
.validateColumnType(colStruct("_1")._1[Row]._1[String])
}

it("can get values from the top-level row") {
df.validateColumnType(row.user[Row])
df.validateColumnType(row.user[Row].age[Int])
List(("1", 2, true)).toDF.validateColumnType(row._1[String])
List((("1", 2), true)).toDF.validateColumnType(row._1[Row]._2[Int])
List(("1", 2, true)).toDF().validateColumnType(row._1[String])
List((("1", 2), true)).toDF().validateColumnType(row._1[Row]._2[Int])
}

if (minorScalaVersion >= 12)