Commit d1b305d

Capture all the errors using HOF with lists

alfonsorr committed Jun 9, 2022
1 parent 227732f commit d1b305d
Showing 6 changed files with 167 additions and 38 deletions.
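A note on the recurring pattern below: every higher-order-function (HOF) column builder now applies the user-supplied lambda one extra time to a concrete element — col.getIndex(0) for arrays, map.keys.getIndex(0) / map.values.getIndex(0) for maps, and a literal such as 1.lit for index arguments — purely so that the lambda's own validation errors are captured eagerly, and then discards that probe column with _ in mapN. A minimal sketch of the user-visible effect, assuming a SparkSession with doric's implicits plus ScalaTest's intercept in scope (the DataFrame and column names are illustrative):

    import doric._
    import doric.sem.DoricMultiError
    import org.apache.spark.sql.Row

    val df = List((List((1, "a"), (2, "b")), 7)).toDF("col", "something")

    // "_3" is not a field of the struct elements. Previously a failure inside
    // the lambda could escape doric's error accumulation; with this commit it
    // is collected into the DoricMultiError alongside any other column errors.
    val errors = intercept[DoricMultiError] {
      df.select(colArray[Row]("col").transform(_.getChild[Int]("_3")))
    }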
60 changes: 43 additions & 17 deletions core/src/main/scala/doric/syntax/ArrayColumns.scala
@@ -85,8 +85,8 @@ private[syntax] trait ArrayColumns {
def transform[A](
fun: DoricColumn[T] => DoricColumn[A]
): DoricColumn[F[A]] =
- (col.elem, fun(x).elem)
- .mapN((a, f) => new Column(ArrayTransform(a.expr, lam1(f.expr))))
+ (col.elem, fun(x).elem, fun(col.getIndex(0)).elem)
+ .mapN((a, f, _) => new Column(ArrayTransform(a.expr, lam1(f.expr))))
.toDC

/**
@@ -107,8 +107,9 @@ private[syntax] trait ArrayColumns {
def transformWithIndex[A](
fun: (DoricColumn[T], IntegerColumn) => DoricColumn[A]
): DoricColumn[F[A]] =
- (col.elem, fun(x, y).elem).mapN { (a, f) =>
- new Column(ArrayTransform(a.expr, lam2(f.expr)))
+ (col.elem, fun(x, y).elem, fun(col.getIndex(0), 1.lit).elem).mapN {
+ (a, f, _) =>
+ new Column(ArrayTransform(a.expr, lam2(f.expr)))
}.toDC

/**
@@ -135,9 +136,15 @@ private[syntax] trait ArrayColumns {
merge: (DoricColumn[A], DoricColumn[T]) => DoricColumn[A],
finish: DoricColumn[A] => DoricColumn[B]
): DoricColumn[B] =
- (col.elem, zero.elem, merge(x, y).elem, finish(x).elem).mapN {
- (a, z, m, f) =>
- new Column(ArrayAggregate(a.expr, z.expr, lam2(m.expr), lam1(f.expr)))
+ (
+ col.elem,
+ zero.elem,
+ merge(x, y).elem,
+ finish(x).elem,
+ merge(zero, col.getIndex(0)).elem,
+ finish(zero).elem
+ ).mapN { (a, z, m, f, _, _) =>
+ new Column(ArrayAggregate(a.expr, z.expr, lam2(m.expr), lam1(f.expr)))
}.toDC

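The aggregate change probes both lambdas: merge with (zero, col.getIndex(0)) and finish with zero, discarding both results in mapN. A hedged usage sketch — the method name aggregateWT, its currying, and the colArrayInt helper are assumptions here, since the signature line is collapsed out of the excerpt above:

    import doric._

    // Sums an array of ints, then renders the total as a string. If merge or
    // finish produced an invalid column, both errors would now be reported
    // when the column is built, accumulated with any others in the select.
    val total: StringColumn =
      colArrayInt("numbers").aggregateWT(0.lit)(
        (acc, x) => acc + x, // merge: also probed with (zero, first element)
        acc => acc.cast      // finish: also probed with zero
      )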
/**
@@ -182,8 +189,8 @@ private[syntax] trait ArrayColumns {
* @todo scaladoc link (issue #135)
*/
def filter(p: DoricColumn[T] => BooleanColumn): DoricColumn[F[T]] =
- (col.elem, p(x).elem)
- .mapN((a, f) => new Column(ArrayFilter(a.expr, lam1(f.expr))))
+ (col.elem, p(x).elem, p(col.getIndex(0)).elem)
+ .mapN((a, f, _) => new Column(ArrayFilter(a.expr, lam1(f.expr))))
.toDC

/**
@@ -202,7 +209,11 @@ private[syntax] trait ArrayColumns {
def filterWIndex(
function: (DoricColumn[T], IntegerColumn) => BooleanColumn
): ArrayColumn[T] = {
- (col.elem, function(x, y).elem).mapN { (a, f) =>
+ (
+ col.elem,
+ function(x, y).elem,
+ function(col.getIndex(0), 1.lit).elem
+ ).mapN { (a, f, _) =>
new Column(ArrayFilter(a.expr, lam2(f.expr)))
}.toDC
}
@@ -401,8 +412,8 @@ private[syntax] trait ArrayColumns {
* @see [[org.apache.spark.sql.functions.exists]]
*/
def exists(fun: DoricColumn[T] => BooleanColumn): BooleanColumn =
- (col.elem, fun(x).elem)
- .mapN((c, f) => {
+ (col.elem, fun(x).elem, fun(col.getIndex(0)).elem)
+ .mapN((c, f, _) => {
new Column(ArrayExists(c.expr, lam1(f.expr)))
})
.toDC
@@ -479,13 +490,28 @@ private[syntax] trait ArrayColumns {
* @group Array Type
* @see [[org.apache.spark.sql.functions.zip_with]]
*/
- def zipWith(
- col2: ArrayColumn[T],
- function: (DoricColumn[T], DoricColumn[T]) => DoricColumn[T]
- ): ArrayColumn[T] = {
- (col.elem, col2.elem, function(x, y).elem).mapN { (a, b, f) =>
+ def zipWith[T2, O](
+ function: (DoricColumn[T], DoricColumn[T2]) => DoricColumn[O]
+ )(
+ col2: ArrayColumn[T2]
+ ): ArrayColumn[O] = {
+ (
+ col.elem,
+ col2.elem,
+ function(x, y).elem,
+ function(col.getIndex(0), col2.getIndex(0)).elem
+ ).mapN { (a, b, f, _) =>
new Column(ZipWith(a.expr, b.expr, lam2(f.expr)))
}.toDC
}
}
+
+ implicit class ArrayArrayColumnSyntax[T, F[_]: CollectionType, G[
+ _
+ ]: CollectionType](
+ private val col: DoricColumn[F[G[T]]]
+ ) {
+ def flatten: DoricColumn[F[T]] =
+ col.elem.map(f.flatten).toDC
+ }
}
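Two API changes in this file deserve a callout. zipWith is now curried, takes the function first, and gains type parameters, so the zipped arrays may differ in element type (T vs T2) and the result type O is free; and nested array columns gain flatten via the new ArrayArrayColumnSyntax. A usage sketch modeled on the updated tests further down (column names illustrative):

    import doric._

    // Old shape: colArrayString("c1").zipWith(col("c2"), concat(_, _))
    // New shape: the lambda first, the second array in its own argument list.
    val zipped: ArrayColumn[String] =
      colArrayString("c1").zipWith[String, String](concat(_, _))(colArrayString("c2"))

    // flatten for a DoricColumn[F[G[T]]]:
    val flat: ArrayColumn[Int] = col[Array[Array[Int]]]("nested").flatten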
18 changes: 9 additions & 9 deletions core/src/main/scala/doric/syntax/DStructs.scala
@@ -11,8 +11,7 @@ import doric.sem.{ChildColumnNotFound, ColumnTypeError, DoricSingleError, Locati
import doric.types.SparkType

import org.apache.spark.sql.{Column, Dataset, Row}
- import org.apache.spark.sql.catalyst.analysis.UnresolvedExtractValue
- import org.apache.spark.sql.catalyst.expressions.UnresolvedNamedLambdaVariable
+ import org.apache.spark.sql.catalyst.expressions.{Expression, UnresolvedNamedLambdaVariable}
import org.apache.spark.sql.functions.{struct => sparkStruct}
import org.apache.spark.sql.types.StructType

@@ -39,6 +38,13 @@ private[syntax] trait DStructs {
def struct(cols: DoricColumn[_]*): RowColumn =
cols.map(_.elem).toList.sequence.map(c => sparkStruct(c: _*)).toDC

+ private def isLambda(expression: Expression): Boolean = {
+ expression match {
+ case _: UnresolvedNamedLambdaVariable => true
+ case _ => expression.children.exists(isLambda)
+ }
+ }
+
implicit class DStructOps(private val col: RowColumn) {

/**
@@ -59,13 +65,7 @@
col.elem
.mapK(toEither)
.flatMap(vcolumn =>
- if (
- vcolumn.expr match {
- case _: UnresolvedNamedLambdaVariable => true
- case _: UnresolvedExtractValue => true
- case _ => false
- }
- )
+ if (isLambda(vcolumn.expr))
Kleisli[DoricEither, Dataset[_], Column]((_: Dataset[_]) => {
Right(vcolumn(subColumnName))
})
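The old inline match only recognized a lambda variable or an UnresolvedExtractValue at the root of the expression; a lambda variable wrapped in any other node (for example the GetArrayItem produced by getIndex, exercised by the new mixed Row/Array test below) was missed. The recursive isLambda instead searches the whole tree. A small illustration with the catalyst classes involved (the name parts are illustrative):

    import org.apache.spark.sql.catalyst.expressions.{GetArrayItem, Literal, UnresolvedNamedLambdaVariable}

    // Roughly what x.getIndex(0) builds inside a HOF lambda: the lambda
    // variable sits below a GetArrayItem node, so a root-only match misses it.
    val expr = GetArrayItem(UnresolvedNamedLambdaVariable(Seq("x")), Literal(0))
    // isLambda(expr) == true: the recursion finds the lambda variable in
    // expr.children, so getChild skips dataset-based validation here.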
@@ -30,7 +30,11 @@ trait MapColumns3x {
def filter(
function: (DoricColumn[K], DoricColumn[V]) => BooleanColumn
): MapColumn[K, V] = {
- (map.elem, function(x, y).elem).mapN { (a, f) =>
+ (
+ map.elem,
+ function(x, y).elem,
+ function(map.keys.getIndex(0), map.values.getIndex(0)).elem
+ ).mapN { (a, f, _) =>
new Column(MapFilter(a.expr, lam2(f.expr)))
}.toDC
}
@@ -53,7 +57,16 @@
DoricColumn[V2]
) => DoricColumn[R]
): MapColumn[K, R] = {
- (map.elem, map2.elem, function(x, y, z).elem).mapN { (a, b, f) =>
+ (
+ map.elem,
+ map2.elem,
+ function(x, y, z).elem,
+ function(
+ map.keys.getIndex(0),
+ map.values.getIndex(0),
+ map2.values.getIndex(0)
+ ).elem
+ ).mapN { (a, b, f, _) =>
new Column(MapZipWith(a.expr, b.expr, lam3(f.expr)))
}.toDC
}
@@ -72,7 +85,11 @@
def transformKeys[K2](
function: (DoricColumn[K], DoricColumn[V]) => DoricColumn[K2]
): MapColumn[K2, V] = {
- (map.elem, function(x, y).elem).mapN { (a, f) =>
+ (
+ map.elem,
+ function(x, y).elem,
+ function(map.keys.getIndex(0), map.values.getIndex(0)).elem
+ ).mapN { (a, f, _) =>
new Column(TransformKeys(a.expr, lam2(f.expr)))
}.toDC
}
@@ -91,7 +108,11 @@
def transformValues[V2](
function: (DoricColumn[K], DoricColumn[V]) => DoricColumn[V2]
): MapColumn[K, V2] = {
- (map.elem, function(x, y).elem).mapN { (a, f) =>
+ (
+ map.elem,
+ function(x, y).elem,
+ function(map.keys.getIndex(0), map.values.getIndex(0)).elem
+ ).mapN { (a, f, _) =>
new Column(TransformValues(a.expr, lam2(f.expr)))
}.toDC
}
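The map variants reuse the array trick with the first key and first value (map.keys.getIndex(0), map.values.getIndex(0)) as probe arguments; they live in MapColumns3x because the underlying Spark expressions require Spark 3.x. A short hedged usage sketch (column name illustrative):

    import doric._

    // Keep entries with positive values; if the predicate itself failed to
    // validate, that error would now be accumulated with any others.
    val positives: MapColumn[String, Int] =
      col[Map[String, Int]]("m").filter((_, v) => v > 0.lit)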
70 changes: 64 additions & 6 deletions core/src/test/scala/doric/syntax/ArrayColumnsSpec.scala
@@ -1,10 +1,10 @@
package doric
package syntax

- import doric.sem.{ColumnTypeError, DoricMultiError, SparkErrorWrapper}
+ import doric.sem.{ChildColumnNotFound, ColumnTypeError, DoricMultiError, SparkErrorWrapper}

import org.apache.spark.sql.{Row, functions => f}
- import org.apache.spark.sql.types.{IntegerType, StringType}
+ import org.apache.spark.sql.types.{IntegerType, LongType, StringType}

class ArrayColumnsSpec extends DoricTestElements {

@@ -58,21 +58,79 @@ class ArrayColumnsSpec extends DoricTestElements {
val df2 = List((List((1, "a"), (2, "b"), (3, "c")), 7))
.toDF(testColumn, "something")

- intercept[DoricMultiError] {
+ df2
+ .select(col[Array[Row]](testColumn).getIndex(0).getChild[Int]("_1"))
+ .show()
+
+ val errors = intercept[DoricMultiError] {
df2.select(
colArray[Row](testColumn)
.transform(_.getChild[Int]("_3") + colInt("something2")),
colArray[Row](testColumn)
- .transform(_.getChild[Long]("_1") + colInt("something2").cast)
+ .transform(_.getChild[Long]("_1") + colInt("something").cast)
)
- } should containAllErrors(
+ }
+ errors.errors.toNonEmptyList.toList.foreach(println)
+ errors should containAllErrors(
SparkErrorWrapper(
new Exception(
"Cannot resolve column name \"something2\" among (col, something)"
)
),
ColumnTypeError("something", StringType, IntegerType)
ColumnTypeError("col[0]._1", LongType, IntegerType),
ChildColumnNotFound("_3", List("_1", "_2"))
)
}

+ it(
+ "should work with complex types that mix Row and Array and return errors if needed"
+ ) {
+
+ val df3 = List((List(List((1, "a"), (2, "b"), (3, "c"))), 7))
+ .toDF(testColumn, "something")
+
+ df3.printSchema()
+ df3.select(
+ col[Array[Array[Row]]](testColumn)
+ .transform(_.transform(_.getChild[Int]("_1")))
+ )
+
+ df3.select(
+ col[Array[Array[Row]]](testColumn)
+ .transform(_.getIndex(0).getChild[Int]("_1"))
+ )
+
+ intercept[DoricMultiError] {
+ df3.select(
+ col[Array[Array[Row]]](testColumn)
+ .transform(_.transform(_.getChild[Int]("_3"))),
+ col[Array[Array[Row]]](testColumn)
+ .transform(_.transform(_.getChild[Long]("_1")))
+ )
+ } should containAllErrors(
+ ChildColumnNotFound("_3", List("_1", "_2")),
+ ColumnTypeError("col[0][0]._1", LongType, IntegerType)
+ )
+
+ val value: List[(List[(Int, String)], Long)] = List((List((1, "a")), 10L))
+ val df4 = List((value, 7))
+ .toDF(testColumn, "something")
+
+ val colTransform = col[Array[Row]](testColumn)
+ .transform(
+ _.getChild[Array[Row]]("_1").transform(_.getChild[Int]("_1"))
+ )
+ .flatten as "l"
+ val colTransform2 = col[Array[Row]](testColumn)
+ .transform(
+ _.getChild[Array[Row]]("_1")
+ )
+ .flatten as "l"
+ df4
+ .select(
+ colTransform.zipWith[Row, Row]((a, b) => struct(a, b))(colTransform2)
+ )
+ .transform(df => { df.printSchema(); df.show(false); df })
+ }
+
it(
@@ -77,7 +77,13 @@ class ArrayColumns3xSpec
).toDF("col1", "col2")

df.testColumns2("col1", "col2")(
- (c1, c2) => colArrayString(c1).zipWith(col(c2), concat(_, _)),
+ (
+ c1,
+ c2
+ ) =>
+ colArrayString(c1).zipWith[String, String](concat(_, _))(
+ colArrayString(c2)
+ ),
(c1, c2) => f.zip_with(f.col(c1), f.col(c2), f.concat(_, _)),
List(Some(Array("ab", "ba", "ce", null)), None, None, None)
)
@@ -2,8 +2,10 @@ package doric
package syntax

import doric.implicitConversions._
+ import doric.sem.{ChildColumnNotFound, ColumnTypeError, DoricMultiError}

- import org.apache.spark.sql.{functions => f}
+ import org.apache.spark.sql.{Row, functions => f}
+ import org.apache.spark.sql.types.{IntegerType, StringType}

class MapColumns3xSpec extends DoricTestElements with MapColumns {

@@ -101,6 +103,22 @@ class MapColumns3xSpec extends DoricTestElements with MapColumns {
null
).map(Option(_))
)
+
+ intercept[DoricMultiError] {
+ List(
+ Map("k1" -> ("v1", 1), "k2" -> ("v2", 2)),
+ Map.empty[String, (String, Int)],
+ null
+ ).toDF("col1")
+ .select(
+ col[Map[String, Row]]("col1").transformValues((a, b) =>
+ a + b.getChild("_3") + b.getChild[Int]("_1").cast
+ )
+ )
+ } should containAllErrors(
+ ChildColumnNotFound("_3", List("_1", "_2")),
+ ColumnTypeError("map_values(col1)[0]._1", IntegerType, StringType)
+ )
}
}
}