diff --git a/dataset/src/main/scala/frameless/TypedColumn.scala b/dataset/src/main/scala/frameless/TypedColumn.scala index 0322a248..0bbaf6fe 100644 --- a/dataset/src/main/scala/frameless/TypedColumn.scala +++ b/dataset/src/main/scala/frameless/TypedColumn.scala @@ -2,9 +2,11 @@ package frameless import frameless.functions.{litAggr, lit => flit} import frameless.syntax._ + import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.types.DecimalType import org.apache.spark.sql.{Column, FramelessInternals} + import shapeless._ import shapeless.ops.record.Selector @@ -27,9 +29,8 @@ sealed class TypedColumn[T, U](expr: Expression)( type ThisType[A, B] = TypedColumn[A, B] - def this(column: Column)(implicit uencoder: TypedEncoder[U]) = { + def this(column: Column)(implicit uencoder: TypedEncoder[U]) = this(FramelessInternals.expr(column)) - } override def typed[W, U1: TypedEncoder](c: Column): TypedColumn[W, U1] = c.typedColumn @@ -140,8 +141,9 @@ abstract class AbstractTypedColumn[T, U] equalsTo(other) /** Inequality test. + * * {{{ - * df.filter( df.col('a) =!= df.col('b) ) + * df.filter(df.col('a) =!= df.col('b)) * }}} * * apache/spark @@ -150,28 +152,28 @@ abstract class AbstractTypedColumn[T, U] typed(Not(equalsTo(other).expr)) /** Inequality test. + * * {{{ - * df.filter( df.col('a) =!= "a" ) + * df.filter(df.col('a) =!= "a") * }}} * * apache/spark */ - def =!=(u: U): ThisType[T, Boolean] = - typed(Not(equalsTo(lit(u)).expr)) + def =!=(u: U): ThisType[T, Boolean] = typed(Not(equalsTo(lit(u)).expr)) /** True if the current expression is an Option and it's None. * * apache/spark */ def isNone(implicit i0: U <:< Option[_]): ThisType[T, Boolean] = - equalsTo[T, T](lit[U](None.asInstanceOf[U])) + typed(IsNull(expr)) /** True if the current expression is an Option and it's not None. * * apache/spark */ def isNotNone(implicit i0: U <:< Option[_]): ThisType[T, Boolean] = - typed(Not(equalsTo(lit(None.asInstanceOf[U])).expr)) + typed(IsNotNull(expr)) /** True if the current expression is a fractional number and is not NaN. * @@ -180,15 +182,43 @@ abstract class AbstractTypedColumn[T, U] def isNaN(implicit n: CatalystNaN[U]): ThisType[T, Boolean] = typed(self.untyped.isNaN) - /** Convert an Optional column by providing a default value + /** + * True if the value for this optional column `exists` as expected + * (see `Option.exists`). + * + * {{{ + * df.col('opt).isSome(_ === someOtherCol) + * }}} + */ + def isSome[V](exists: ThisType[T, V] => ThisType[T, Boolean])(implicit i0: U <:< Option[V]): ThisType[T, Boolean] = someOr[V](exists, false) + + /** + * True if the value for this optional column `exists` as expected, + * or is `None`. (see `Option.forall`). + * + * {{{ + * df.col('opt).isSomeOrNone(_ === someOtherCol) + * }}} + */ + def isSomeOrNone[V](exists: ThisType[T, V] => ThisType[T, Boolean])(implicit i0: U <:< Option[V]): ThisType[T, Boolean] = someOr[V](exists, true) + + private def someOr[V](exists: ThisType[T, V] => ThisType[T, Boolean], default: Boolean)(implicit i0: U <:< Option[V]): ThisType[T, Boolean] = { + val defaultExpr = if (default) Literal.TrueLiteral else Literal.FalseLiteral + + typed(Coalesce(Seq(opt(i0).map(exists).expr, defaultExpr))) + } + + /** Convert an Optional column by providing a default value. + * * {{{ - * df( df('opt).getOrElse(df('defaultValue)) ) + * df(df('opt).getOrElse(df('defaultValue))) * }}} */ def getOrElse[TT, W, Out](default: ThisType[TT, Out])(implicit i0: U =:= Option[Out], i1: With.Aux[T, TT, W]): ThisType[W, Out] = typed(Coalesce(Seq(expr, default.expr)))(default.uencoder) - /** Convert an Optional column by providing a default value + /** Convert an Optional column by providing a default value. + * * {{{ * df( df('opt).getOrElse(defaultConstant) ) * }}} @@ -197,6 +227,7 @@ abstract class AbstractTypedColumn[T, U] getOrElse(lit[Out](default)) /** Sum of this expression and another expression. + * * {{{ * // The following selects the sum of a person's height and weight. * people.select( people.col('height) plus people.col('weight) ) @@ -700,9 +731,10 @@ abstract class AbstractTypedColumn[T, U] or(other) /** Less than. + * * {{{ - * // The following selects people younger than the maxAge column. - * df.select( df('age) < df('maxAge) ) + * // The following selects people younger than the maxAge column. + * df.select(df('age) < df('maxAge) ) * }}} * * @param other another column of the same type @@ -712,9 +744,10 @@ abstract class AbstractTypedColumn[T, U] typed(self.untyped < other.untyped) /** Less than or equal to. + * * {{{ - * // The following selects people younger or equal than the maxAge column. - * df.select( df('age) <= df('maxAge) + * // The following selects people younger or equal than the maxAge column. + * df.select(df('age) <= df('maxAge) * }}} * * @param other another column of the same type diff --git a/dataset/src/main/scala/frameless/functions/package.scala b/dataset/src/main/scala/frameless/functions/package.scala index 4104f0db..4e5fafc9 100644 --- a/dataset/src/main/scala/frameless/functions/package.scala +++ b/dataset/src/main/scala/frameless/functions/package.scala @@ -36,6 +36,7 @@ package object functions extends Udf with UnaryFunctions { if (ScalaReflection.isNativeType(encoder.jvmRepr) && encoder.catalystRepr == encoder.jvmRepr) { val expr = Literal(value, encoder.catalystRepr) + new TypedColumn(expr) } else { val expr = new Literal(value, encoder.jvmRepr) diff --git a/dataset/src/test/scala/frameless/FilterTests.scala b/dataset/src/test/scala/frameless/FilterTests.scala index 3f122cf4..56d5d2ec 100644 --- a/dataset/src/test/scala/frameless/FilterTests.scala +++ b/dataset/src/test/scala/frameless/FilterTests.scala @@ -1,9 +1,11 @@ package frameless +import org.scalatest.matchers.should.Matchers + import org.scalacheck.Prop import org.scalacheck.Prop._ -class FilterTests extends TypedDatasetSuite { +final class FilterTests extends TypedDatasetSuite with Matchers { test("filter('a == lit(b))") { def prop[A: TypedEncoder](elem: A, data: Vector[X1[A]])(implicit ex1: TypedEncoder[X1[A]]): Prop = { val dataset = TypedDataset.create(data) @@ -145,6 +147,24 @@ class FilterTests extends TypedDatasetSuite { check(forAll(prop[Option[X1[X1[Vector[Option[Int]]]]]] _)) } + test("Option content filter") { + val data = (Option(1L), Option(2L)) :: (Option(0L), Option(1L)) :: (None, None) :: Nil + + val ds = TypedDataset.create(data) + + val l = functions.lit[Long, (Option[Long], Option[Long])](0L) + val exists = ds('_1).isSome[Long](_ <= l) + val forall = ds('_1).isSomeOrNone[Long](_ <= l) + + ds.select(exists).collect().run() shouldEqual Seq(false, true, false) + ds.select(forall).collect().run() shouldEqual Seq(false, true, true) + + ds.filter(exists).collect().run() shouldEqual Seq(Option(0L) -> Option(1L)) + + ds.filter(forall).collect().run() shouldEqual Seq( + Option(0L) -> Option(1L), (None -> None)) + } + test("filter with isin values") { def prop[A: TypedEncoder](data: Vector[X1[A]], values: Vector[A])(implicit a : CatalystIsin[A]): Prop = { val ds = TypedDataset.create(data)