diff --git a/core/src/main/scala/doric/sem/All.scala b/core/src/main/scala/doric/sem/All.scala
index 5521af41d..01eff9dea 100644
--- a/core/src/main/scala/doric/sem/All.scala
+++ b/core/src/main/scala/doric/sem/All.scala
@@ -1,3 +1,8 @@
 package doric.sem
 
-trait All extends AggregationOps with TransformOps with JoinOps with CollectOps
+trait All
+    extends AggregationOps
+    with TransformOps
+    with JoinOps
+    with CollectOps
+    with SortingOps
diff --git a/core/src/main/scala/doric/sem/SortingOps.scala b/core/src/main/scala/doric/sem/SortingOps.scala
new file mode 100644
index 000000000..58bc4420c
--- /dev/null
+++ b/core/src/main/scala/doric/sem/SortingOps.scala
@@ -0,0 +1,27 @@
+package doric.sem
+
+import cats.implicits._
+import doric.DoricColumn
+import org.apache.spark.sql.DataFrame
+
+private[sem] trait SortingOps {
+
+  implicit class DataframeSortSyntax(df: DataFrame) {
+
+    def sort(col: DoricColumn[_]*): DataFrame =
+      col.toList
+        .traverse(_.elem)
+        .run(df)
+        .map(col => df.sort(col: _*))
+        .returnOrThrow("sort")
+
+    def orderBy(col: DoricColumn[_]*): DataFrame = sort(col: _*)
+
+    def sortWithinPartitions(col: DoricColumn[_]*): DataFrame =
+      col.toList
+        .traverse(_.elem)
+        .run(df)
+        .map(col => df.sortWithinPartitions(col: _*))
+        .returnOrThrow("sortWithinPartitions")
+  }
+}
diff --git a/core/src/main/scala/doric/syntax/CommonColumns.scala b/core/src/main/scala/doric/syntax/CommonColumns.scala
index 14a6a76e3..7a42d56a6 100644
--- a/core/src/main/scala/doric/syntax/CommonColumns.scala
+++ b/core/src/main/scala/doric/syntax/CommonColumns.scala
@@ -215,6 +215,57 @@ private[syntax] trait CommonColumns extends ColGetters[NamedDoricColumn] {
       })
       .toDC
 
+    /**
+      * Sorts a column in ascending order
+      * @return
+      *   A DoricColumn of the provided type
+      * @see [[org.apache.spark.sql.Column.asc]]
+      */
+    def asc: DoricColumn[T] = column.elem.map(col => col.asc).toDC
+
+    /**
+      * Sorts a column in ascending order with null values returned before non-null values
+      * @return
+      *   A DoricColumn of the provided type
+      * @see [[org.apache.spark.sql.Column.asc_nulls_first]]
+      */
+    def ascNullsFirst: DoricColumn[T] =
+      column.elem.map(col => col.asc_nulls_first).toDC
+
+    /**
+      * Sorts a column in ascending order with null values returned after non-null values
+      * @return
+      *   A DoricColumn of the provided type
+      * @see [[org.apache.spark.sql.Column.asc_nulls_last]]
+      */
+    def ascNullsLast: DoricColumn[T] =
+      column.elem.map(col => col.asc_nulls_last).toDC
+
+    /**
+      * Sorts a column in descending order
+      * @return
+      *   A DoricColumn of the provided type
+      * @see [[org.apache.spark.sql.Column.desc]]
+      */
+    def desc: DoricColumn[T] = column.elem.map(col => col.desc).toDC
+
+    /**
+      * Sorts a column in descending order with null values returned before non-null values
+      * @return
+      *   A DoricColumn of the provided type
+      * @see [[org.apache.spark.sql.Column.desc_nulls_first]]
+      */
+    def descNullsFirst: DoricColumn[T] =
+      column.elem.map(col => col.desc_nulls_first).toDC
+
+    /**
+      * Sorts a column in descending order with null values returned after non-null values
+      * @return
+      *   A DoricColumn of the provided type
+      * @see [[org.apache.spark.sql.Column.desc_nulls_last]]
+      */
+    def descNullsLast: DoricColumn[T] =
+      column.elem.map(col => col.desc_nulls_last).toDC
   }
 }
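
A quick usage sketch of the API added above. The data and column names are made up for illustration, and it assumes the standard doric import plus a SparkSession named spark in scope for the implicits:

    import doric._
    import spark.implicits._

    // Hypothetical example data.
    val df = List((2, "b"), (1, "a"), (2, null: String)).toDF("id", "name")

    // Orderings are plain DoricColumn values, so the column names and types
    // are validated against df's schema before Spark sorts anything; a missing
    // or mistyped column surfaces as a doric error rather than a runtime
    // AnalysisException deep inside the plan.
    val sorted = df.orderBy(colInt("id").desc, colString("name").ascNullsFirst)
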
diff --git a/core/src/test/scala/doric/sem/SortingOpsSpec.scala b/core/src/test/scala/doric/sem/SortingOpsSpec.scala
new file mode 100644
index 000000000..dd1e8117f
--- /dev/null
+++ b/core/src/test/scala/doric/sem/SortingOpsSpec.scala
@@ -0,0 +1,54 @@
+package doric.sem
+
+import doric.{DoricTestElements, colInt, colString}
+import org.apache.spark.sql.Row
+
+class SortingOpsSpec extends DoricTestElements {
+  import spark.implicits._
+
+  describe("Sort") {
+    it("sorts a dataframe with sort function on one column") {
+      val df = List((1, "a"), (2, "b"), (3, "c"), (4, "d")).toDF("col1", "col2")
+
+      val res    = df.sort(colInt("col1").desc)
+      val actual = List(Row(4, "d"), Row(3, "c"), Row(2, "b"), Row(1, "a"))
+
+      res.collect().toList should contain theSameElementsInOrderAs actual
+    }
+
+    it("sorts a dataframe with sort function on multiple columns") {
+      val df = List((1, "z"), (2, "n"), (3, "x"), (2, "f")).toDF("col1", "col2")
+
+      val res    = df.sort(colInt("col1").desc, colString("col2").asc)
+      val actual = List(Row(3, "x"), Row(2, "f"), Row(2, "n"), Row(1, "z"))
+
+      res.collect().toList should contain theSameElementsInOrderAs actual
+    }
+  }
+
+  describe("Sort Within Partitions") {
+    it("sorts dataframe partitions with sort function on one column") {
+      val df = List((1, "a"), (2, "b"), (3, "c"), (4, "d"))
+        .toDF("col1", "col2")
+        .repartition(2)
+
+      val res    = df.sortWithinPartitions(colInt("col1").asc)
+      val actual = List(Row(3, "c"), Row(4, "d"), Row(1, "a"), Row(2, "b"))
+
+      res.collect().toList should contain theSameElementsInOrderAs actual
+    }
+
+    it("sorts dataframe partitions with sort function on multiple columns") {
+      val df = List((1, "z"), (2, "n"), (3, "x"), (2, "f"), (2, "z"))
+        .toDF("col1", "col2")
+        .repartition(2)
+
+      val res =
+        df.sortWithinPartitions(colInt("col1").desc, colString("col2").asc)
+      val actual =
+        List(Row(2, "n"), Row(2, "z"), Row(1, "z"), Row(3, "x"), Row(2, "f"))
+
+      res.collect().toList should contain theSameElementsInOrderAs actual
+    }
+  }
+}
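
One note on the sortWithinPartitions tests above: unlike sort/orderBy, sortWithinPartitions only orders rows inside each partition and performs no extra shuffle, so collect() yields sorted runs per partition rather than a totally ordered result. That is why the expected lists in those tests are not globally sorted. A sketch of the distinction, with hypothetical data (the exact run boundaries depend on how repartition hashes the rows):

    val df = List((3, "c"), (1, "a"), (4, "d"), (2, "b"))
      .toDF("col1", "col2")
      .repartition(2)

    // Totally ordered result: a range-partitioned shuffle, then a sort.
    val globallySorted = df.sort(colInt("col1").asc)

    // Each of the two partitions is sorted independently; cheaper, but
    // only partition-local order is guaranteed.
    val locallySorted = df.sortWithinPartitions(colInt("col1").asc)
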
diff --git a/core/src/test/scala/doric/syntax/CommonColumnsSpec.scala b/core/src/test/scala/doric/syntax/CommonColumnsSpec.scala
index 36a0a1051..b90e16290 100644
--- a/core/src/test/scala/doric/syntax/CommonColumnsSpec.scala
+++ b/core/src/test/scala/doric/syntax/CommonColumnsSpec.scala
@@ -318,4 +318,97 @@ class CommonColumnsSpec
     }
   }
 
+  describe("asc doric function") {
+    import spark.implicits._
+
+    it("should sort a df in ascending order") {
+      val df = List(5, 4, 3, 2, 1)
+        .toDF("col1")
+
+      val res = df.orderBy(colInt("col1").asc).as[Int].collect().toList
+      res shouldBe List(1, 2, 3, 4, 5)
+    }
+
+    it("should sort a df in ascending order for more complex types") {
+      val df = List(List(5, 6), List(4, 4, 5), List(3), List(1, 2), List(1))
+        .toDF("col1")
+
+      val res =
+        df.orderBy(colArrayInt("col1").asc).as[List[Int]].collect().toList
+      res shouldBe List(List(1), List(1, 2), List(3), List(4, 4, 5), List(5, 6))
+    }
+  }
+
+  describe("ascNullsFirst doric function") {
+    import spark.implicits._
+
+    it(
+      "should sort a df in ascending order with null values returned before non-nulls"
+    ) {
+      val df = List("5", "4", null, "3", "2", null, "1")
+        .toDF("col1")
+
+      val res =
+        df.orderBy(colString("col1").ascNullsFirst).as[String].collect().toList
+      res shouldBe List(null, null, "1", "2", "3", "4", "5")
+    }
+  }
+
+  describe("ascNullsLast doric function") {
+    import spark.implicits._
+
+    it(
+      "should sort a df in ascending order with null values returned after non-nulls"
+    ) {
+      val df = List("5", "4", null, "3", "2", null, "1")
+        .toDF("col1")
+
+      val res =
+        df.orderBy(colString("col1").ascNullsLast).as[String].collect().toList
+      res shouldBe List("1", "2", "3", "4", "5", null, null)
+    }
+  }
+
+  describe("desc doric function") {
+    import spark.implicits._
+
+    it("should sort a df in descending order") {
+      val df = List(1, 2, 3, 4, 5)
+        .toDF("col1")
+
+      val res = df.orderBy(colInt("col1").desc).as[Int].collect().toList
+      res shouldBe List(5, 4, 3, 2, 1)
+    }
+  }
+
+  describe("descNullsFirst doric function") {
+    import spark.implicits._
+
+    it(
+      "should sort a df in descending order with null values returned before non-nulls"
+    ) {
+      val df = List("1", "2", null, null, "5", "3", null, "4")
+        .toDF("col1")
+
+      val res =
+        df.orderBy(colString("col1").descNullsFirst).as[String].collect().toList
+      res shouldBe List(null, null, null, "5", "4", "3", "2", "1")
+    }
+  }
+
+  describe("descNullsLast doric function") {
+    import spark.implicits._
+
+    it(
+      "should sort a df in descending order with null values returned after non-nulls"
+    ) {
+      val df = List("1", "2", null, null, "5", "3", null, "4")
+        .toDF("col1")
+
+      val res =
+        df.orderBy(colString("col1").descNullsLast).as[String].collect().toList
+      res shouldBe List("5", "4", "3", "2", "1", null, null, null)
+    }
+  }
+
 }
diff --git a/core/src/test/scala/doric/syntax/DateColumnsSpec.scala b/core/src/test/scala/doric/syntax/DateColumnsSpec.scala
index b6f7f8e06..ef68f2930 100644
--- a/core/src/test/scala/doric/syntax/DateColumnsSpec.scala
+++ b/core/src/test/scala/doric/syntax/DateColumnsSpec.scala
@@ -6,6 +6,7 @@ import java.time.{Instant, LocalDate}
 import org.scalatest.EitherValues
 import org.scalatest.matchers.should.Matchers
 import org.apache.spark.sql.{DataFrame, functions => f}
+import org.apache.spark.sql.catalyst.util.DateTimeUtils.dateAddMonths
 
 class DateColumnsSpec
     extends DoricTestElements
@@ -63,10 +64,42 @@ class DateColumnsSpec
     }
 
     it("should subtract months if num months < 0 with literal") {
+      val expectedDate =
+        if (spark.version.take(1).toInt > 2)
+          LocalDate.now().minusMonths(3)
+        else {
+          val additionalDays =
+            dateAddMonths(LocalDate.now().toEpochDay.asInstanceOf[Int], -3)
+          LocalDate.of(1970, 1, 1).plusDays(additionalDays.toLong)
+        }
+
+      df.testColumns2("dateCol", -3)(
+        (d, m) => colDate(d).addMonths(m.lit),
+        (d, m) => f.add_months(f.col(d), m),
+        List(Date.valueOf(expectedDate), null).map(Option(_))
+      )
+    }
+
+    it(
+      "should correctly subtract months if num months < 0 with literal for end of month dates"
+    ) {
+      val localDate = LocalDate.of(2022, 6, 30)
+
+      val df = List(Date.valueOf(localDate), null).toDF("dateCol")
+
+      val expectedDate =
+        if (spark.version.take(1).toInt > 2)
+          localDate.minusMonths(3)
+        else {
+          val additionalDays =
+            dateAddMonths(localDate.toEpochDay.asInstanceOf[Int], -3)
+          LocalDate.of(1970, 1, 1).plusDays(additionalDays.toLong)
+        }
+
       df.testColumns2("dateCol", -3)(
         (d, m) => colDate(d).addMonths(m.lit),
         (d, m) => f.add_months(f.col(d), m),
-        List(Date.valueOf(LocalDate.now.minusMonths(3)), null).map(Option(_))
+        List(Date.valueOf(expectedDate), null).map(Option(_))
       )
     }
   }
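
For context on the DateColumnsSpec change: the expected date is now computed per Spark major version because add_months changed behavior in Spark 3.0. On 2.x, shifting a date that falls on the last day of its month snaps the result to the last day of the target month, while 3.x keeps plain calendar arithmetic; that, presumably, is why the test reuses catalyst's dateAddMonths to reproduce the 2.x expectation. A worked instance for the month-end date the new test uses:

    import java.time.LocalDate

    val input = LocalDate.of(2022, 6, 30) // last day of June

    // Spark 3.x expectation: plain java.time arithmetic.
    val spark3Expected = input.minusMonths(3) // 2022-03-30

    // Spark 2.x add_months pins month-end inputs to the month-end of the
    // result, so the 2.x expectation is 2022-03-31 instead.
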