Commit fa3c0a5

Doric now always carries the schema in the columns. No more hidden selects.
alfonsorr committed Jun 15, 2022
1 parent f3339fe commit fa3c0a5
Showing 13 changed files with 271 additions and 171 deletions.
12 changes: 11 additions & 1 deletion core/src/main/scala/doric/doric.scala
@@ -1,6 +1,7 @@
 import cats.data.{EitherNec, Kleisli, ValidatedNec}
 import cats.implicits._
 import cats.Parallel
+import cats.arrow.FunctionK
 import doric.sem.DoricSingleError
 import java.sql.{Date, Timestamp}
 import java.time.{Instant, LocalDate}
@@ -31,7 +32,16 @@ package object doric extends syntax.All with sem.All {
     Doric(org.apache.spark.sql.functions.col(colName.value))
   }
 
-  private type DoricEither[A] = EitherNec[DoricSingleError, A]
+  private[doric] type DoricEither[A] = EitherNec[DoricSingleError, A]
+
+  private[doric] val toValidated = new FunctionK[DoricEither, DoricValidated] {
+    override def apply[A](fa: DoricEither[A]): DoricValidated[A] =
+      fa.toValidated
+  }
+  private[doric] val toEither = new FunctionK[DoricValidated, DoricEither] {
+    override def apply[A](fa: DoricValidated[A]): DoricEither[A] = fa.toEither
+  }
+
   private type SequenceDoric[F] = Kleisli[DoricEither, Dataset[_], F]
 
   implicit private[doric] class SeqPar[A](a: Doric[A])(implicit
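A note on the new FunctionK values: they are natural transformations between doric's two error representations, letting a Kleisli computation hop from the accumulating DoricValidated into the monadic DoricEither (which has flatMap) and back via mapK. A minimal standalone sketch of the same shape in plain cats (all names here are illustrative, not doric's API):

```scala
import cats.arrow.FunctionK
import cats.data.{EitherNec, ValidatedNec}
import cats.syntax.either._

object FunctionKDemo {
  type EitherNecS[A]    = EitherNec[String, A]
  type ValidatedNecS[A] = ValidatedNec[String, A]

  // Same shape as toValidated above: a natural transformation from the
  // fail-fast Either representation to the error-accumulating Validated one,
  // with the element type left polymorphic.
  val toV: FunctionK[EitherNecS, ValidatedNecS] =
    new FunctionK[EitherNecS, ValidatedNecS] {
      def apply[A](fa: EitherNecS[A]): ValidatedNecS[A] = fa.toValidated
    }

  val bad: EitherNecS[Int]    = Either.leftNec("missing column")
  val out: ValidatedNecS[Int] = toV(bad) // Invalid(Chain("missing column"))
}
```

The new getIndex in ArrayColumns.scala below uses exactly this round trip: mapK(toEither) to gain flatMap, then mapK(toValidated) to restore error accumulation.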
11 changes: 0 additions & 11 deletions core/src/main/scala/doric/sem/Errors.scala
@@ -120,17 +120,6 @@ case class ColumnTypeError(
     s"The column with name '$columnName' was expected to be $expectedType but is of type $foundType"
 }
 
-case class ChildColumnNotFound(
-    columnName: String,
-    validColumns: Seq[String]
-)(implicit
-    val location: Location
-) extends DoricSingleError(None) {
-  override def message: String =
-    s"No such struct field $columnName among nested columns ${validColumns
-      .mkString("(", ", ", ")")}"
-}
-
 case class SparkErrorWrapper(sparkCause: Throwable)(implicit
     val location: Location
 ) extends DoricSingleError(Some(sparkCause)) {
127 changes: 91 additions & 36 deletions core/src/main/scala/doric/syntax/ArrayColumns.scala
@@ -1,10 +1,11 @@
 package doric
 package syntax
 
+import cats.data.Kleisli
 import cats.implicits._
 import doric.types.CollectionType
 
-import org.apache.spark.sql.{Column, functions => f}
+import org.apache.spark.sql.{Column, Dataset, functions => f}
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.LambdaFunction.identity
 
@@ -67,7 +68,22 @@ private[syntax] trait ArrayColumns {
     * the DoricColumn with the selected element.
     */
   def getIndex(n: Int): DoricColumn[T] =
-    col.elem.map(_.apply(n)).toDC
+    (col.elem, n.lit.elem)
+      .mapN((a, b) => (a, b))
+      .mapK(toEither)
+      .flatMap { case (a, b) =>
+        Kleisli[DoricEither, Dataset[_], Column]((df: Dataset[_]) => {
+          new Column(
+            ExtractValue(
+              a.expr,
+              b.expr,
+              df.sparkSession.sessionState.analyzer.resolver
+            )
+          ).asRight
+        })
+      }
+      .mapK(toValidated)
+      .toDC
 
   /**
     * Transform each element with the provided function.
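For reference, a usage sketch of the schema-aware getIndex above. The SparkSession (`spark`), the sample data, and the colArray constructor are illustrative assumptions, not part of this diff:

```scala
import doric._
import spark.implicits._ // assumes an in-scope SparkSession named `spark`

val df = List(List(10, 20, 30), List(4, 5, 6)).toDF("numbers")

// ExtractValue is built with the session's analyzer resolver, so the access
// is validated against the schema the column now carries; no hidden select
// is run against the dataset.
df.select(colArray[Int]("numbers").getIndex(1)).show() // 20 and 5
```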
@@ -84,10 +100,14 @@
     */
   def transform[A](
       fun: DoricColumn[T] => DoricColumn[A]
-  ): DoricColumn[F[A]] =
-    (col.elem, fun(x).elem, fun(col.getIndex(0)).elem)
-      .mapN((a, f, _) => new Column(ArrayTransform(a.expr, lam1(f.expr))))
+  ): DoricColumn[F[A]] = {
+    val xv = x(col.getIndex(0))
+    (col.elem, fun(xv).elem, xv.elem)
+      .mapN((a, f, x) =>
+        new Column(ArrayTransform(a.expr, lam1(f.expr, x.expr)))
+      )
       .toDC
+  }
 
   /**
     * Transform each element of the array with the provided function that
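This is the pattern repeated through the rest of the file: x(...) now seeds the lambda variable with a schema-carrying column (here the array's first element), and its expression is handed to lam1 alongside the function body, instead of probing the dataset with an extra evaluation of the function. A usage sketch of transform, reusing df and the assumptions from the previous example:

```scala
// Square every element of the array column.
df.select(colArray[Int]("numbers").transform(el => el * el)).show()
// [100, 400, 900] and [16, 25, 36]
```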
@@ -106,11 +126,18 @@
     */
   def transformWithIndex[A](
       fun: (DoricColumn[T], IntegerColumn) => DoricColumn[A]
-  ): DoricColumn[F[A]] =
-    (col.elem, fun(x, y).elem, fun(col.getIndex(0), 1.lit).elem).mapN {
-      (a, f, _) =>
-        new Column(ArrayTransform(a.expr, lam2(f.expr)))
+  ): DoricColumn[F[A]] = {
+    val xv = x(col.getIndex(0))
+    val yv = y(1.lit)
+    (
+      col.elem,
+      fun(xv, yv).elem,
+      xv.elem,
+      yv.elem
+    ).mapN { (a, f, x, y) =>
+      new Column(ArrayTransform(a.expr, lam2(f.expr, x.expr, y.expr)))
     }.toDC
+  }
 
   /**
     * Aggregates (reduce) the array with the provided functions, similar to
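A usage sketch of transformWithIndex under the same assumptions; the index the lambda receives is 0-based, as in Spark's transform:

```scala
// Add each element's position within the array to its value.
df.select(
  colArray[Int]("numbers").transformWithIndex((el, idx) => el + idx)
).show()
// [10, 21, 32] and [4, 6, 8]
```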
@@ -135,17 +162,27 @@
   def aggregateWT[A, B](zero: DoricColumn[A])(
       merge: (DoricColumn[A], DoricColumn[T]) => DoricColumn[A],
       finish: DoricColumn[A] => DoricColumn[B]
-  ): DoricColumn[B] =
+  ): DoricColumn[B] = {
+    val xv = x(zero)
+    val yv = y(col.getIndex(0))
     (
       col.elem,
       zero.elem,
-      merge(x, y).elem,
-      finish(x).elem,
-      merge(zero, col.getIndex(0)).elem,
-      finish(zero).elem
-    ).mapN { (a, z, m, f, _, _) =>
-      new Column(ArrayAggregate(a.expr, z.expr, lam2(m.expr), lam1(f.expr)))
+      merge(xv, yv).elem,
+      finish(xv).elem,
+      xv.elem,
+      yv.elem
+    ).mapN { (a, z, m, f, x, y) =>
+      new Column(
+        ArrayAggregate(
+          a.expr,
+          z.expr,
+          lam2(m.expr, x.expr, y.expr),
+          lam1(f.expr, x.expr)
+        )
+      )
     }.toDC
+  }
 
   /**
     * Aggregates (reduce) the array with the provided functions, similar to
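A usage sketch of aggregateWT, again with the illustrative df; the cast syntax is an assumption based on doric's casting rules:

```scala
// Sum the elements, then finish by presenting the total as a string.
df.select(
  colArray[Int]("numbers").aggregateWT(0.lit)(
    (acc, el) => acc + el,
    acc => acc.cast[String]
  )
).show()
// "60" and "15"
```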
@@ -167,15 +204,21 @@
       zero: DoricColumn[A]
   )(
       merge: (DoricColumn[A], DoricColumn[T]) => DoricColumn[A]
-  ): DoricColumn[A] =
+  ): DoricColumn[A] = {
+    val xv = x(zero)
+    val yv = y(col.getIndex(0))
     (
       col.elem,
       zero.elem,
-      merge(x, y).elem,
-      merge(zero, col.getIndex(0)).elem
-    ).mapN { (a, z, m, _) =>
-      new Column(ArrayAggregate(a.expr, z.expr, lam2(m.expr), identity))
+      merge(xv, yv).elem,
+      xv.elem,
+      yv.elem
+    ).mapN { (a, z, m, x, y) =>
+      new Column(
+        ArrayAggregate(a.expr, z.expr, lam2(m.expr, x.expr, y.expr), identity)
+      )
     }.toDC
+  }
 
   /**
     * Filters the array elements using the provided condition.
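The simpler aggregate is the same fold without the finishing function; a sketch under the same assumptions:

```scala
// Sum the elements, keeping the accumulator type.
df.select(
  colArray[Int]("numbers").aggregate(0.lit)((acc, el) => acc + el)
).show()
// 60 and 15
```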
@@ -188,10 +231,14 @@
     * @see org.apache.spark.sql.functions.filter
     * @todo scaladoc link (issue #135)
     */
-  def filter(p: DoricColumn[T] => BooleanColumn): DoricColumn[F[T]] =
-    (col.elem, p(x).elem, p(col.getIndex(0)).elem)
-      .mapN((a, f, _) => new Column(ArrayFilter(a.expr, lam1(f.expr))))
+  def filter(p: DoricColumn[T] => BooleanColumn): DoricColumn[F[T]] = {
+    val xv = x(col.getIndex(0))
+    (col.elem, p(xv).elem, xv.elem)
+      .mapN((a, f, x) =>
+        new Column(ArrayFilter(a.expr, lam1(f.expr, x.expr)))
+      )
       .toDC
+  }
 
   /**
     * Returns an array of elements for which a predicate holds in a given array.
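A filter sketch with an illustrative predicate, reusing the earlier df:

```scala
// Keep only the elements greater than 5.
df.select(colArray[Int]("numbers").filter(_ > 5.lit)).show()
// [10, 20, 30] and [6]
```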
@@ -209,12 +256,15 @@
   def filterWIndex(
       function: (DoricColumn[T], IntegerColumn) => BooleanColumn
   ): ArrayColumn[T] = {
+    val xv = x(col.getIndex(0))
+    val yv = y(1.lit)
     (
       col.elem,
-      function(x, y).elem,
-      function(col.getIndex(0), 1.lit).elem
-    ).mapN { (a, f, _) =>
-      new Column(ArrayFilter(a.expr, lam2(f.expr)))
+      function(xv, yv).elem,
+      xv.elem,
+      yv.elem
+    ).mapN { (a, f, x, y) =>
+      new Column(ArrayFilter(a.expr, lam2(f.expr, x.expr, y.expr)))
     }.toDC
   }
 
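filterWIndex gives the predicate the element's 0-based position as well, mirroring Spark's filter with index; an illustrative sketch:

```scala
// Drop the first element of each array.
df.select(
  colArray[Int]("numbers").filterWIndex((_, idx) => idx > 0.lit)
).show()
// [20, 30] and [5, 6]
```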
@@ -411,12 +461,14 @@
     * @group Array Type
     * @see [[org.apache.spark.sql.functions.exists]]
     */
-  def exists(fun: DoricColumn[T] => BooleanColumn): BooleanColumn =
-    (col.elem, fun(x).elem, fun(col.getIndex(0)).elem)
-      .mapN((c, f, _) => {
-        new Column(ArrayExists(c.expr, lam1(f.expr)))
+  def exists(fun: DoricColumn[T] => BooleanColumn): BooleanColumn = {
+    val xv = x(col.getIndex(0))
+    (col.elem, fun(xv).elem, xv.elem)
+      .mapN((c, f, x) => {
+        new Column(ArrayExists(c.expr, lam1(f.expr, x.expr)))
       })
       .toDC
+  }
 
   /**
     * Creates a new row for each element in the given array column.
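An exists sketch under the same assumptions:

```scala
// True when at least one element exceeds 25.
df.select(colArray[Int]("numbers").exists(_ > 25.lit)).show()
// true and false
```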
@@ -495,13 +547,16 @@
   )(
       col2: ArrayColumn[T2]
   ): ArrayColumn[O] = {
+    val xv = x(col.getIndex(0))
+    val yv = y(col2.getIndex(0))
     (
       col.elem,
       col2.elem,
-      function(x, y).elem,
-      function(col.getIndex(0), col2.getIndex(0)).elem
-    ).mapN { (a, b, f, _) =>
-      new Column(ZipWith(a.expr, b.expr, lam2(f.expr)))
+      function(xv, yv).elem,
+      xv.elem,
+      yv.elem
+    ).mapN { (a, b, f, x, y) =>
+      new Column(ZipWith(a.expr, b.expr, lam2(f.expr, x.expr, y.expr)))
     }.toDC
   }
 }
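Finally, a zipWith sketch. Per the signature visible in this hunk, the function comes in the first parameter list and the second array column in its own list; the sample data is illustrative:

```scala
val pairs = List((List(1, 2, 3), List(10, 20, 30))).toDF("a", "b")

// Element-wise sum of the two array columns.
pairs.select(
  colArray[Int]("a").zipWith((l, r) => l + r)(colArray[Int]("b"))
).show()
// [11, 22, 33]
```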
[diffs for the remaining 10 changed files omitted]
