From a081649fc95d96e551e68707a22b2b008f972954 Mon Sep 17 00:00:00 2001 From: Stefan Kandic Date: Thu, 8 Feb 2024 14:56:18 +0100 Subject: [PATCH 01/46] initial working version --- .../catalyst/analysis/FunctionRegistry.scala | 1 + .../expressions/collectionOperations.scala | 165 ++++++++++++++++++ .../CollectionExpressionsSuite.scala | 40 +++++ 3 files changed, 206 insertions(+) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index b165d20d0b4fa..18f4eadc0cec0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -696,6 +696,7 @@ object FunctionRegistry { expression[MapEntries]("map_entries"), expression[MapFromEntries]("map_from_entries"), expression[MapConcat]("map_concat"), + expression[SortMap]("sort_map"), expression[Size]("size"), expression[Slice]("slice"), expression[Size]("cardinality", true), diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index a090bdf2bebf6..8088c97ded3b7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -888,6 +888,171 @@ case class MapFromEntries(child: Expression) copy(child = newChild) } +@ExpressionDescription( + usage = """ + _FUNC_(map[, ascendingOrder]) - Sorts the input map in ascending or descending order + according to the natural ordering of the map keys. 
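+      Keys of the input map must be of an orderable type, otherwise the query fails during analysis.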
+  """,
+  examples = """
+    Examples:
+      > SELECT _FUNC_(map(3, 'c', 1, 'a', 2, 'b'), true);
+       {1:"a",2:"b",3:"c"}
+  """,
+  group = "map_funcs",
+  since = "4.0.0")
+case class SortMap(base: Expression, ascendingOrder: Expression)
+  extends BinaryExpression with NullIntolerant with QueryErrorsBase {
+
+  def this(e: Expression) = this(e, Literal(true))
+
+  val keyType: DataType = base.dataType.asInstanceOf[MapType].keyType
+  val valueType: DataType = base.dataType.asInstanceOf[MapType].valueType
+
+  override def left: Expression = base
+  override def right: Expression = ascendingOrder
+  override def dataType: DataType = base.dataType
+
+  override def checkInputDataTypes(): TypeCheckResult = base.dataType match {
+    case MapType(kt, _, _) if RowOrdering.isOrderable(kt) =>
+      ascendingOrder match {
+        case Literal(_: Boolean, BooleanType) =>
+          TypeCheckResult.TypeCheckSuccess
+        case _ =>
+          DataTypeMismatch(
+            errorSubClass = "UNEXPECTED_INPUT_TYPE",
+            messageParameters = Map(
+              "paramIndex" -> "2",
+              "requiredType" -> toSQLType(BooleanType),
+              "inputSql" -> toSQLExpr(ascendingOrder),
+              "inputType" -> toSQLType(ascendingOrder.dataType))
+          )
+      }
+    case MapType(_, _, _) =>
+      DataTypeMismatch(
+        errorSubClass = "INVALID_ORDERING_TYPE",
+        messageParameters = Map(
+          "functionName" -> toSQLId(prettyName),
+          "dataType" -> toSQLType(base.dataType)
+        )
+      )
+    case _ =>
+      DataTypeMismatch(
+        errorSubClass = "UNEXPECTED_INPUT_TYPE",
+        messageParameters = Map(
+          "paramIndex" -> "1",
+          "requiredType" -> toSQLType(MapType),
+          "inputSql" -> toSQLExpr(base),
+          "inputType" -> toSQLType(base.dataType))
+      )
+  }
+
+  override def nullSafeEval(array: Any, ascending: Any): Any = {
+    // Put each key and its original index into a TreeMap ordered by key,
+    // then read the entries back to build the sorted key/value arrays.
+
+    val mapData = array.asInstanceOf[MapData]
+    val numElements = mapData.numElements()
+    val keys = mapData.keyArray()
+    val values = mapData.valueArray()
+
+    val ordering = if (ascending.asInstanceOf[Boolean]) {
+      PhysicalDataType.ordering(keyType)
+    } else {
+      PhysicalDataType.ordering(keyType).reverse
+    }
+
+    val treeMap = mutable.TreeMap.empty[Any, Int](ordering)
+    for (i <- 0 until numElements) {
+      treeMap.put(keys.get(i, keyType), i)
+    }
+
+    val newKeys = new Array[Any](numElements)
+    val newValues = new Array[Any](numElements)
+
+    treeMap.zipWithIndex.foreach { case ((_, originalIndex), sortedIndex) =>
+      newKeys(sortedIndex) = keys.get(originalIndex, keyType)
+      newValues(sortedIndex) = values.get(originalIndex, valueType)
+    }
+
+    new ArrayBasedMapData(new GenericArrayData(newKeys), new GenericArrayData(newValues))
+  }
+
+  override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
+    nullSafeCodeGen(ctx, ev, (b, order) => sortCodegen(ctx, ev, b, order))
+  }
+
+  private def sortCodegen(ctx: CodegenContext, ev: ExprCode,
+      base: String, order: String): String = {
+
+    val arrayBasedMapData = classOf[ArrayBasedMapData].getName
+    val genericArrayData = classOf[GenericArrayData].getName
+
+    val numElements = ctx.freshName("numElements")
+    val keys = ctx.freshName("keys")
+    val values = ctx.freshName("values")
+    val treeMap = ctx.freshName("treeMap")
+    val i = ctx.freshName("i")
+    val o1 = ctx.freshName("o1")
+    val o2 = ctx.freshName("o2")
+    val c = ctx.freshName("c")
+    val newKeys = ctx.freshName("newKeys")
+    val newValues = ctx.freshName("newValues")
+    val mapEntry = ctx.freshName("mapEntry")
+    val originalIndex = ctx.freshName("originalIndex")
+
+    val boxedKeyType = CodeGenerator.boxedType(keyType)
+    val javaKeyType = CodeGenerator.javaType(keyType)
+
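+    // Comparator body for the generated java.util.TreeMap: primitive keys
+    // are unboxed from Object before comparing, while non-primitive keys
+    // are cast to their Java type and compared through ctx.genComp.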
val comp = if (CodeGenerator.isPrimitiveType(keyType)) { + val v1 = ctx.freshName("v1") + val v2 = ctx.freshName("v2") + s""" + |$javaKeyType $v1 = (($boxedKeyType) $o1).${javaKeyType}Value(); + |$javaKeyType $v2 = (($boxedKeyType) $o2).${javaKeyType}Value(); + |int $c = ${ctx.genComp(keyType, v1, v2)}; + """.stripMargin + } else { + s"int $c = ${ctx.genComp(keyType, s"(($javaKeyType) $o1)", s"(($javaKeyType) $o2)")};" + } + + s""" + |final int $numElements = $base.numElements(); + |ArrayData $keys = $base.keyArray(); + |ArrayData $values = $base.valueArray(); + | + |java.util.TreeMap<$boxedKeyType, Integer> $treeMap = new java.util.TreeMap<>( + | new java.util.Comparator() { + | @Override public int compare(Object $o1, Object $o2) { + | $comp; + | return $order ? $c : -$c; + | } + | } + |); + | + |for (int $i = 0; $i < $numElements; $i++) { + | $treeMap.put(${CodeGenerator.getValue(keys, keyType, i)}, $i); + |} + | + |Object[] $newKeys = new Object[$numElements]; + |Object[] $newValues = new Object[$numElements]; + | + |int $i = 0; + |for (java.util.Map.Entry<$boxedKeyType, Integer> $mapEntry : $treeMap.entrySet()) { + | int $originalIndex = (Integer) $mapEntry.getValue(); + | $newKeys[$i] = ${CodeGenerator.getValue(keys, keyType, originalIndex)}; + | $newValues[$i] = ${CodeGenerator.getValue(values, valueType, originalIndex)}; + | $i++; + |} + | + |${ev.value} = new $arrayBasedMapData( + | new $genericArrayData($newKeys), new $genericArrayData($newValues)); + |""".stripMargin + } + + override def prettyName: String = "sort_map" + + override protected def withNewChildrenInternal(newLeft: Expression, newRight: Expression) + : SortMap = copy(base = newLeft, ascendingOrder = newRight) +} /** * Common base class for [[SortArray]] and [[ArraySort]]. diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala index 133e27c5b0a66..302703d4497a2 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala @@ -421,6 +421,46 @@ class CollectionExpressionsSuite ) } + test("Sort Map") { + val intKey = Literal.create(Map(2 -> 2, 1 -> 1, 3 -> 3), MapType(IntegerType, IntegerType)) + val boolKey = Literal.create(Map(true -> 2, false -> 1), MapType(BooleanType, IntegerType)) + val stringKey = Literal.create(Map("2" -> 2, "1" -> 1, "3" -> 3), + MapType(StringType, IntegerType)) + val arrayKey = Literal.create(Map(Seq(2) -> 2, Seq(1) -> 1, Seq(3) -> 3), + MapType(ArrayType(IntegerType), IntegerType)) + val nestedArrayKey = Literal.create(Map(Seq(Seq(2)) -> 2, Seq(Seq(1)) -> 1, Seq(Seq(3)) -> 3), + MapType(ArrayType(ArrayType(IntegerType)), IntegerType)) + val structKey = Literal.create( + Map(create_row(2) -> 2, create_row(1) -> 1, create_row(3) -> 3), + MapType(StructType(Seq(StructField("a", IntegerType))), IntegerType)) + + checkEvaluation(new SortMap(intKey), Map(1 -> 1, 2 -> 2, 3 -> 3)) + checkEvaluation(SortMap(intKey, Literal.create(false, BooleanType)), + Map(3 -> 3, 2 -> 2, 1 -> 1)) + + checkEvaluation(new SortMap(boolKey), Map(false -> 1, true -> 2)) + checkEvaluation(SortMap(boolKey, Literal.create(false, BooleanType)), + Map(true -> 2, false -> 1)) + + checkEvaluation(new SortMap(stringKey), Map("1" -> 1, "2" -> 2, "3" -> 3)) + checkEvaluation(SortMap(stringKey, 
Literal.create(false, BooleanType)), + Map("3" -> 3, "2" -> 2, "1" -> 1)) + + checkEvaluation(new SortMap(arrayKey), Map(Seq(1) -> 1, Seq(2) -> 2, Seq(3) -> 3)) + checkEvaluation(SortMap(arrayKey, Literal.create(false, BooleanType)), + Map(Seq(3) -> 3, Seq(2) -> 2, Seq(1) -> 1)) + + checkEvaluation(new SortMap(nestedArrayKey), + Map(Seq(Seq(1)) -> 1, Seq(Seq(2)) -> 2, Seq(Seq(3)) -> 3)) + checkEvaluation(SortMap(nestedArrayKey, Literal.create(false, BooleanType)), + Map(Seq(Seq(3)) -> 3, Seq(Seq(2)) -> 2, Seq(Seq(1)) -> 1)) + + checkEvaluation(new SortMap(structKey), + Map(create_row(1) -> 1, create_row(2) -> 2, create_row(3) -> 3)) + checkEvaluation(SortMap(structKey, Literal.create(false, BooleanType)), + Map(create_row(3) -> 3, create_row(2) -> 2, create_row(1) -> 1)) + } + test("Sort Array") { val a0 = Literal.create(Seq(2, 1, 3), ArrayType(IntegerType)) val a1 = Literal.create(Seq[Integer](), ArrayType(IntegerType)) From 1441549ed1fbee0188c32a6f3c44cb05d2e3470d Mon Sep 17 00:00:00 2001 From: Stefan Kandic Date: Sun, 11 Feb 2024 01:20:34 +0100 Subject: [PATCH 02/46] add golden files --- .../src/test/resources/sql-functions/sql-expression-schema.md | 1 + 1 file changed, 1 insertion(+) diff --git a/sql/core/src/test/resources/sql-functions/sql-expression-schema.md b/sql/core/src/test/resources/sql-functions/sql-expression-schema.md index e20db3b49589c..4714e4f70668b 100644 --- a/sql/core/src/test/resources/sql-functions/sql-expression-schema.md +++ b/sql/core/src/test/resources/sql-functions/sql-expression-schema.md @@ -301,6 +301,7 @@ | org.apache.spark.sql.catalyst.expressions.Size | size | SELECT size(array('b', 'd', 'c', 'a')) | struct | | org.apache.spark.sql.catalyst.expressions.Slice | slice | SELECT slice(array(1, 2, 3, 4), 2, 2) | struct> | | org.apache.spark.sql.catalyst.expressions.SortArray | sort_array | SELECT sort_array(array('b', 'd', null, 'c', 'a'), true) | struct> | +| org.apache.spark.sql.catalyst.expressions.SortMap | sort_map | SELECT sort_map(map(3, 'c', 1, 'a', 2, 'b'), true) | struct> | | org.apache.spark.sql.catalyst.expressions.SoundEx | soundex | SELECT soundex('Miller') | struct | | org.apache.spark.sql.catalyst.expressions.SparkPartitionID | spark_partition_id | SELECT spark_partition_id() | struct | | org.apache.spark.sql.catalyst.expressions.SparkVersion | version | SELECT version() | struct | From 1be06e37e26b27e9f2e66bdc8260cb9d8abf9d81 Mon Sep 17 00:00:00 2001 From: Stefan Kandic Date: Wed, 14 Feb 2024 15:27:50 +0100 Subject: [PATCH 03/46] add map sort to other languages --- R/pkg/NAMESPACE | 1 + R/pkg/R/functions.R | 17 +++++++ R/pkg/R/generics.R | 4 ++ R/pkg/tests/fulltests/test_sparkSQL.R | 6 +++ .../org/apache/spark/sql/functions.scala | 19 ++++++++ .../spark/sql/PlanGenerationTestSuite.scala | 4 ++ .../reference/pyspark.sql/functions.rst | 1 + .../pyspark/sql/connect/functions/builtin.py | 7 +++ python/pyspark/sql/functions/builtin.py | 47 +++++++++++++++++++ python/pyspark/sql/tests/test_functions.py | 7 +++ .../catalyst/analysis/FunctionRegistry.scala | 2 +- .../expressions/collectionOperations.scala | 6 +-- .../CollectionExpressionsSuite.scala | 24 +++++----- .../org/apache/spark/sql/functions.scala | 7 +++ .../sql-functions/sql-expression-schema.md | 2 +- 15 files changed, 137 insertions(+), 17 deletions(-) diff --git a/R/pkg/NAMESPACE b/R/pkg/NAMESPACE index c5668d1739b17..bdbcfa552448b 100644 --- a/R/pkg/NAMESPACE +++ b/R/pkg/NAMESPACE @@ -361,6 +361,7 @@ exportMethods("%<=>%", "map_keys", "map_values", "map_zip_with", + "map_sort", "max", 
"max_by", "md5", diff --git a/R/pkg/R/functions.R b/R/pkg/R/functions.R index 5106a83bd0ec4..e3452d71682cc 100644 --- a/R/pkg/R/functions.R +++ b/R/pkg/R/functions.R @@ -4523,6 +4523,23 @@ setMethod("map_zip_with", ) }) +#' @details +#' \code{sort_array}: Sorts the input map in ascending or descending order according to +#' the natural ordering of the map keys. +#' +#' @rdname column_collection_functions +#' @param asc a logical flag indicating the sorting order. +#' TRUE, sorting is in ascending order. +#' FALSE, sorting is in descending order. +#' @aliases map_sort map_sort,Column-method +#' @note sort_array since 4.0.0 +setMethod("map_sort", + signature(x = "Column"), + function(x, asc = TRUE) { + jc <- callJStatic("org.apache.spark.sql.functions", "map_sort", x@jc, asc) + column(jc) + } + #' @details #' \code{element_at}: Returns element of array at given index in \code{extraction} if #' \code{x} is array. Returns value for the given key in \code{extraction} if \code{x} is map. diff --git a/R/pkg/R/generics.R b/R/pkg/R/generics.R index 26e81733055a6..2004530da88cb 100644 --- a/R/pkg/R/generics.R +++ b/R/pkg/R/generics.R @@ -1216,6 +1216,10 @@ setGeneric("map_values", function(x) { standardGeneric("map_values") }) #' @name NULL setGeneric("map_zip_with", function(x, y, f) { standardGeneric("map_zip_with") }) +#' @rdname column_collection_functions +#' @name NULL +setGeneric("map_sort", function(x, asc = TRUE) { standardGeneric("map_sort") }) + #' @rdname column_aggregate_functions #' @name NULL setGeneric("max_by", function(x, y) { standardGeneric("max_by") }) diff --git a/R/pkg/tests/fulltests/test_sparkSQL.R b/R/pkg/tests/fulltests/test_sparkSQL.R index 630781a57e444..652b81d7b7532 100644 --- a/R/pkg/tests/fulltests/test_sparkSQL.R +++ b/R/pkg/tests/fulltests/test_sparkSQL.R @@ -1646,6 +1646,12 @@ test_that("column functions", { expected_entries <- list(as.environment(list(x = 1, y = 2, a = 3, b = 4))) expect_equal(result, expected_entries) + # Test map_sort + df <- createDataFrame(list(map1 = as.environment(list(c = 3, a = 1, b = 2)))) + result <- collect(select(df, map_concat(df[[1]])))[[1]] + expected_entries <- list(as.environment(list(a = 1, b = 2, c = 3))) + expect_equal(result, expected_entries) + # Test map_entries(), map_keys(), map_values() and element_at() df <- createDataFrame(list(list(map = as.environment(list(x = 1, y = 2))))) result <- collect(select(df, map_entries(df$map)))[[1]] diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/functions.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/functions.scala index 133b7e036cd7c..cad72d7da24aa 100644 --- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/functions.scala +++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/functions.scala @@ -7081,6 +7081,25 @@ object functions { */ def sort_array(e: Column, asc: Boolean): Column = Column.fn("sort_array", e, lit(asc)) + /** + * Sorts the input map in ascending order according to the natural ordering + * of the map keys. + * + * @group map_funcs + * @since 4.0.0 + */ + def map_sort(e: Column): Column = map_sort(e, asc = true) + + + /** + * Sorts the input map in ascending or descending order according to the natural ordering + * of the map keys. + * + * @group map_funcs + * @since 4.0.0 + */ + def map_sort(e: Column, asc: Boolean): Column = Column.fn("map_sort", e, lit(asc)) + /** * Returns the minimum value in the array. NaN is greater than any non-NaN elements for * double/float type. 
NULL elements are skipped. diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/PlanGenerationTestSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/PlanGenerationTestSuite.scala index ee98a1aceea38..6fbee02997275 100644 --- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/PlanGenerationTestSuite.scala +++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/PlanGenerationTestSuite.scala @@ -2525,6 +2525,10 @@ class PlanGenerationTestSuite fn.map_from_entries(fn.transform(fn.col("e"), (x, i) => fn.struct(i, x))) } + functionTest("map_sort") { + fn.map_sort(fn.col("f")) + } + functionTest("arrays_zip") { fn.arrays_zip(fn.col("e"), fn.sequence(lit(1), lit(20))) } diff --git a/python/docs/source/reference/pyspark.sql/functions.rst b/python/docs/source/reference/pyspark.sql/functions.rst index ca20ccfb73c56..438e1e7a9a88d 100644 --- a/python/docs/source/reference/pyspark.sql/functions.rst +++ b/python/docs/source/reference/pyspark.sql/functions.rst @@ -394,6 +394,7 @@ Map Functions map_from_entries map_keys map_values + map_sort str_to_map diff --git a/python/pyspark/sql/connect/functions/builtin.py b/python/pyspark/sql/connect/functions/builtin.py index 72adfec33b1d6..53f3c537cbc41 100644 --- a/python/pyspark/sql/connect/functions/builtin.py +++ b/python/pyspark/sql/connect/functions/builtin.py @@ -2004,6 +2004,13 @@ def map_values(col: "ColumnOrName") -> Column: map_values.__doc__ = pysparkfuncs.map_values.__doc__ +def map_sort(col: "ColumnOrName") -> Column: + return _invoke_function_over_columns("map_sort", col) + + +map_sort.__doc__ = pysparkfuncs.map_sort.__doc__ + + def map_zip_with( col1: "ColumnOrName", col2: "ColumnOrName", diff --git a/python/pyspark/sql/functions/builtin.py b/python/pyspark/sql/functions/builtin.py index 6320f9b922eef..226cca3f87f7b 100644 --- a/python/pyspark/sql/functions/builtin.py +++ b/python/pyspark/sql/functions/builtin.py @@ -16839,6 +16839,53 @@ def map_concat( cols = cols[0] # type: ignore[assignment] return _invoke_function_over_seq_of_columns("map_concat", cols) # type: ignore[arg-type] +@_try_remote_functions +def map_sort(col: "ColumnOrName", asc: bool = True) -> Column: + """ + Map function: Sorts the input map in ascending or descending order according + to the natural ordering of the map keys. + + .. versionadded:: 4.0.0 + + Parameters + ---------- + col : :class:`~pyspark.sql.Column` or str + Name of the column or expression. + asc : bool, optional + Whether to sort in ascending or descending order. If `asc` is True (default), + then the sorting is in ascending order. If False, then in descending order. + + Returns + ------- + :class:`~pyspark.sql.Column` + Sorted map. 
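+
+    Notes
+    -----
+    The map keys must be of an orderable type; maps whose keys are not
+    orderable are rejected during analysis.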
+ + Examples + -------- + Example 1: Sorting a map in ascending order + + >>> import pyspark.sql.functions as sf + >>> df = spark.sql("SELECT map(3, 'c', 1, 'a', 2, 'b') as data") + >>> df.select(sf.map_sort(df.data)).show() + +------------------------+ + | map_sort(data, true)| + +------------------------+ + |{1 -> a, 2 -> b, 3 -> c}| + +------------------------+ + + Example 2: Sorting a map in descending order + + >>> import pyspark.sql.functions as sf + >>> df = spark.sql("SELECT map(3, 'c', 1, 'a', 2, 'b') as data") + >>> df.select(sf.map_sort(df.data, false)).show() + +------------------------+ + | map_sort(data, true)| + +------------------------+ + |{3 -> c, 2 -> b, 1 -> a}| + +------------------------+ + """ + return _invoke_function("map_sort", _to_java_column(col), asc) + @_try_remote_functions def sequence( diff --git a/python/pyspark/sql/tests/test_functions.py b/python/pyspark/sql/tests/test_functions.py index a736832c8ef99..74e8a5f2a90e1 100644 --- a/python/pyspark/sql/tests/test_functions.py +++ b/python/pyspark/sql/tests/test_functions.py @@ -1440,6 +1440,13 @@ def test_map_concat(self): {1: "a", 2: "b", 3: "c"}, ) + def test_map_sort(self): + df = self.spark.sql("SELECT map(3, 'c', 1, 'a', 2, 'b') as map1") + self.assertEqual( + df.select(F.map_sort("map1").alias("map2")).first()[0], + {1: "a", 2: "b", 3: "c"}, + ) + def test_version(self): self.assertIsInstance(self.spark.range(1).select(F.version()).first()[0], str) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index 18f4eadc0cec0..f64f88cfd9b65 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -696,7 +696,7 @@ object FunctionRegistry { expression[MapEntries]("map_entries"), expression[MapFromEntries]("map_from_entries"), expression[MapConcat]("map_concat"), - expression[SortMap]("sort_map"), + expression[MapSort]("map_sort"), expression[Size]("size"), expression[Slice]("slice"), expression[Size]("cardinality", true), diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index 8088c97ded3b7..fae74bb1580ac 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -900,7 +900,7 @@ case class MapFromEntries(child: Expression) """, group = "map_funcs", since = "4.0.0") -case class SortMap(base: Expression, ascendingOrder: Expression) +case class MapSort(base: Expression, ascendingOrder: Expression) extends BinaryExpression with NullIntolerant with QueryErrorsBase { def this(e: Expression) = this(e, Literal(true)) @@ -1048,10 +1048,10 @@ case class SortMap(base: Expression, ascendingOrder: Expression) |""".stripMargin } - override def prettyName: String = "sort_map" + override def prettyName: String = "map_sort" override protected def withNewChildrenInternal(newLeft: Expression, newRight: Expression) - : SortMap = copy(base = newLeft, ascendingOrder = newRight) + : MapSort = copy(base = newLeft, ascendingOrder = newRight) } /** diff --git 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala index 302703d4497a2..3063b83d4dca1 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala @@ -434,30 +434,30 @@ class CollectionExpressionsSuite Map(create_row(2) -> 2, create_row(1) -> 1, create_row(3) -> 3), MapType(StructType(Seq(StructField("a", IntegerType))), IntegerType)) - checkEvaluation(new SortMap(intKey), Map(1 -> 1, 2 -> 2, 3 -> 3)) - checkEvaluation(SortMap(intKey, Literal.create(false, BooleanType)), + checkEvaluation(new MapSort(intKey), Map(1 -> 1, 2 -> 2, 3 -> 3)) + checkEvaluation(MapSort(intKey, Literal.create(false, BooleanType)), Map(3 -> 3, 2 -> 2, 1 -> 1)) - checkEvaluation(new SortMap(boolKey), Map(false -> 1, true -> 2)) - checkEvaluation(SortMap(boolKey, Literal.create(false, BooleanType)), + checkEvaluation(new MapSort(boolKey), Map(false -> 1, true -> 2)) + checkEvaluation(MapSort(boolKey, Literal.create(false, BooleanType)), Map(true -> 2, false -> 1)) - checkEvaluation(new SortMap(stringKey), Map("1" -> 1, "2" -> 2, "3" -> 3)) - checkEvaluation(SortMap(stringKey, Literal.create(false, BooleanType)), + checkEvaluation(new MapSort(stringKey), Map("1" -> 1, "2" -> 2, "3" -> 3)) + checkEvaluation(MapSort(stringKey, Literal.create(false, BooleanType)), Map("3" -> 3, "2" -> 2, "1" -> 1)) - checkEvaluation(new SortMap(arrayKey), Map(Seq(1) -> 1, Seq(2) -> 2, Seq(3) -> 3)) - checkEvaluation(SortMap(arrayKey, Literal.create(false, BooleanType)), + checkEvaluation(new MapSort(arrayKey), Map(Seq(1) -> 1, Seq(2) -> 2, Seq(3) -> 3)) + checkEvaluation(MapSort(arrayKey, Literal.create(false, BooleanType)), Map(Seq(3) -> 3, Seq(2) -> 2, Seq(1) -> 1)) - checkEvaluation(new SortMap(nestedArrayKey), + checkEvaluation(new MapSort(nestedArrayKey), Map(Seq(Seq(1)) -> 1, Seq(Seq(2)) -> 2, Seq(Seq(3)) -> 3)) - checkEvaluation(SortMap(nestedArrayKey, Literal.create(false, BooleanType)), + checkEvaluation(MapSort(nestedArrayKey, Literal.create(false, BooleanType)), Map(Seq(Seq(3)) -> 3, Seq(Seq(2)) -> 2, Seq(Seq(1)) -> 1)) - checkEvaluation(new SortMap(structKey), + checkEvaluation(new MapSort(structKey), Map(create_row(1) -> 1, create_row(2) -> 2, create_row(3) -> 3)) - checkEvaluation(SortMap(structKey, Literal.create(false, BooleanType)), + checkEvaluation(MapSort(structKey, Literal.create(false, BooleanType)), Map(create_row(3) -> 3, create_row(2) -> 2, create_row(1) -> 1)) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 933d0b3f89a7e..cd3d182841278 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -6986,6 +6986,13 @@ object functions { @scala.annotation.varargs def map_concat(cols: Column*): Column = Column.fn("map_concat", cols: _*) + /** + * Sorts the input map in ascending order based on the natural order of map keys. + * @group map_funcs + * @since 4.0.0 + */ + def map_sort(e: Column): Column = Column.fn("map_sort", e) + // scalastyle:off line.size.limit /** * Parses a column containing a CSV string into a `StructType` with the specified schema. 
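
A minimal usage sketch for the Scala API added in the hunk above (an
illustration, not part of the patch; it assumes a SparkSession with
spark.implicits._ in scope):

    import org.apache.spark.sql.functions.map_sort
    val df = Seq(Map(3 -> "c", 1 -> "a", 2 -> "b")).toDF("m")
    df.select(map_sort($"m")).show()           // {1 -> a, 2 -> b, 3 -> c}
    df.selectExpr("map_sort(m, false)").show() // {3 -> c, 2 -> b, 1 -> a}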
diff --git a/sql/core/src/test/resources/sql-functions/sql-expression-schema.md b/sql/core/src/test/resources/sql-functions/sql-expression-schema.md
index 4714e4f70668b..999cd68738484 100644
--- a/sql/core/src/test/resources/sql-functions/sql-expression-schema.md
+++ b/sql/core/src/test/resources/sql-functions/sql-expression-schema.md
@@ -215,6 +215,7 @@
 | org.apache.spark.sql.catalyst.expressions.MapFromArrays | map_from_arrays | SELECT map_from_arrays(array(1.0, 3.0), array('2', '4')) | struct> |
 | org.apache.spark.sql.catalyst.expressions.MapFromEntries | map_from_entries | SELECT map_from_entries(array(struct(1, 'a'), struct(2, 'b'))) | struct> |
 | org.apache.spark.sql.catalyst.expressions.MapKeys | map_keys | SELECT map_keys(map(1, 'a', 2, 'b')) | struct> |
+| org.apache.spark.sql.catalyst.expressions.MapSort | map_sort | SELECT map_sort(map(3, 'c', 1, 'a', 2, 'b'), true) | struct> |
 | org.apache.spark.sql.catalyst.expressions.MapValues | map_values | SELECT map_values(map(1, 'a', 2, 'b')) | struct> |
 | org.apache.spark.sql.catalyst.expressions.MapZipWith | map_zip_with | SELECT map_zip_with(map(1, 'a', 2, 'b'), map(1, 'x', 2, 'y'), (k, v1, v2) -> concat(v1, v2)) | struct> |
 | org.apache.spark.sql.catalyst.expressions.MaskExpressionBuilder | mask | SELECT mask('abcd-EFGH-8765-4321') | struct |
@@ -301,7 +302,6 @@
 | org.apache.spark.sql.catalyst.expressions.Size | size | SELECT size(array('b', 'd', 'c', 'a')) | struct |
 | org.apache.spark.sql.catalyst.expressions.Slice | slice | SELECT slice(array(1, 2, 3, 4), 2, 2) | struct> |
 | org.apache.spark.sql.catalyst.expressions.SortArray | sort_array | SELECT sort_array(array('b', 'd', null, 'c', 'a'), true) | struct> |
-| org.apache.spark.sql.catalyst.expressions.SortMap | sort_map | SELECT sort_map(map(3, 'c', 1, 'a', 2, 'b'), true) | struct> |
 | org.apache.spark.sql.catalyst.expressions.SoundEx | soundex | SELECT soundex('Miller') | struct |
 | org.apache.spark.sql.catalyst.expressions.SparkPartitionID | spark_partition_id | SELECT spark_partition_id() | struct |
 | org.apache.spark.sql.catalyst.expressions.SparkVersion | version | SELECT version() | struct |

From 249e903d596d3803d4c8bfbbe9b6ecce7b29042c Mon Sep 17 00:00:00 2001
From: Stefan Kandic
Date: Wed, 28 Feb 2024 15:35:51 +0100
Subject: [PATCH 04/46] fix typos

---
 R/pkg/R/functions.R                     | 2 +-
 python/pyspark/sql/functions/builtin.py | 2 +-
 2 files changed, 2 insertions(+), 2 deletions(-)

diff --git a/R/pkg/R/functions.R b/R/pkg/R/functions.R
index e3452d71682cc..69ea77b87b2e0 100644
--- a/R/pkg/R/functions.R
+++ b/R/pkg/R/functions.R
@@ -4538,7 +4538,7 @@ setMethod("map_sort",
           function(x, asc = TRUE) {
             jc <- callJStatic("org.apache.spark.sql.functions", "map_sort", x@jc, asc)
             column(jc)
-          }
+          })
 
 #' @details
 #' \code{element_at}: Returns element of array at given index in \code{extraction} if
diff --git a/python/pyspark/sql/functions/builtin.py b/python/pyspark/sql/functions/builtin.py
index 226cca3f87f7b..cc6cf3f0126be 100644
--- a/python/pyspark/sql/functions/builtin.py
+++ b/python/pyspark/sql/functions/builtin.py
@@ -16877,7 +16877,7 @@ def map_sort(col: "ColumnOrName", asc: bool = True) -> Column:
 
     >>> import pyspark.sql.functions as sf
     >>> df = spark.sql("SELECT map(3, 'c', 1, 'a', 2, 'b') as data")
-    >>> df.select(sf.map_sort(df.data, false)).show()
+    >>> df.select(sf.map_sort(df.data, False)).show()
     +------------------------+
     |    map_sort(data, true)|
     +------------------------+

From aaae8835463d25bf3d04a014115ae6177ce2e0c6 Mon Sep 17 00:00:00 2001
From: Stefan Kandic
Date: Wed, 28
Feb 2024 15:43:24 +0100 Subject: [PATCH 05/46] fix scalastyle issue --- .../src/main/scala/org/apache/spark/sql/functions.scala | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/functions.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/functions.scala index cad72d7da24aa..15d8f4253eb92 100644 --- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/functions.scala +++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/functions.scala @@ -7082,18 +7082,16 @@ object functions { def sort_array(e: Column, asc: Boolean): Column = Column.fn("sort_array", e, lit(asc)) /** - * Sorts the input map in ascending order according to the natural ordering - * of the map keys. + * Sorts the input map in ascending order according to the natural ordering of the map keys. * * @group map_funcs * @since 4.0.0 */ def map_sort(e: Column): Column = map_sort(e, asc = true) - /** - * Sorts the input map in ascending or descending order according to the natural ordering - * of the map keys. + * Sorts the input map in ascending or descending order according to the natural ordering of the + * map keys. * * @group map_funcs * @since 4.0.0 From acaf95e3cdd767cd0708c6fe16e416c13ef6600c Mon Sep 17 00:00:00 2001 From: Stefan Kandic Date: Wed, 28 Feb 2024 15:53:00 +0100 Subject: [PATCH 06/46] add proto golden files --- .../queries/function_map_sort.json | 29 ++++++++++++++++++ .../queries/function_map_sort.proto.bin | Bin 0 -> 183 bytes 2 files changed, 29 insertions(+) create mode 100644 connector/connect/common/src/test/resources/query-tests/queries/function_map_sort.json create mode 100644 connector/connect/common/src/test/resources/query-tests/queries/function_map_sort.proto.bin diff --git a/connector/connect/common/src/test/resources/query-tests/queries/function_map_sort.json b/connector/connect/common/src/test/resources/query-tests/queries/function_map_sort.json new file mode 100644 index 0000000000000..81a9788d0fbae --- /dev/null +++ b/connector/connect/common/src/test/resources/query-tests/queries/function_map_sort.json @@ -0,0 +1,29 @@ +{ + "common": { + "planId": "1" + }, + "project": { + "input": { + "common": { + "planId": "0" + }, + "localRelation": { + "schema": "struct\u003cid:bigint,a:int,b:double,d:struct\u003cid:bigint,a:int,b:double\u003e,e:array\u003cint\u003e,f:map\u003cstring,struct\u003cid:bigint,a:int,b:double\u003e\u003e,g:string\u003e" + } + }, + "expressions": [{ + "unresolvedFunction": { + "functionName": "map_sort", + "arguments": [{ + "unresolvedAttribute": { + "unparsedIdentifier": "f" + } + }, { + "literal": { + "boolean": true + } + }] + } + }] + } +} \ No newline at end of file diff --git a/connector/connect/common/src/test/resources/query-tests/queries/function_map_sort.proto.bin b/connector/connect/common/src/test/resources/query-tests/queries/function_map_sort.proto.bin new file mode 100644 index 0000000000000000000000000000000000000000..57b823a5712988205bd9b2b37ce7f274fe5cdf62 GIT binary patch literal 183 zcmd;L5@3|tz{oX;k&8)yA*!2EsDrV%q^LBx#3nPvDk(EPGp|G^(F#N+S*7HcCgr5+ zq*xJ9VW*R7l~`1iSZM>)XQz{9m77>#1Jsk5m##xdtDR0d$atVqJ1I#iaV`#^-uUAD Tq7oriA!aVdG$9r)CJ9CWE( Date: Wed, 28 Feb 2024 17:26:05 +0100 Subject: [PATCH 07/46] fix python function call --- .../pyspark/sql/connect/functions/builtin.py | 4 +-- .../org/apache/spark/sql/functions.scala | 13 +++++++++- .../spark/sql/DataFrameFunctionsSuite.scala | 25 +++++++++++++++++++ 3 files changed, 39 
insertions(+), 3 deletions(-) diff --git a/python/pyspark/sql/connect/functions/builtin.py b/python/pyspark/sql/connect/functions/builtin.py index 53f3c537cbc41..318e6f7887699 100644 --- a/python/pyspark/sql/connect/functions/builtin.py +++ b/python/pyspark/sql/connect/functions/builtin.py @@ -2004,8 +2004,8 @@ def map_values(col: "ColumnOrName") -> Column: map_values.__doc__ = pysparkfuncs.map_values.__doc__ -def map_sort(col: "ColumnOrName") -> Column: - return _invoke_function_over_columns("map_sort", col) +def map_sort(col: "ColumnOrName", asc: bool = True) -> Column: + return _invoke_function("map_sort", _to_col(col), lit(asc)) map_sort.__doc__ = pysparkfuncs.map_sort.__doc__ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index cd3d182841278..2bc6db58333e3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -6988,10 +6988,21 @@ object functions { /** * Sorts the input map in ascending order based on the natural order of map keys. + * + * @group map_funcs + * @since 4.0.0 + */ + def map_sort(e: Column): Column = map_sort(e, asc = true) + // TODO: add test for this + + /** + * Sorts the input map in ascending or descending order according to the natural ordering + * of the map keys. + * * @group map_funcs * @since 4.0.0 */ - def map_sort(e: Column): Column = Column.fn("map_sort", e) + def map_sort(e: Column, asc: Boolean): Column = Column.fn("map_sort", e, lit(asc)) // scalastyle:off line.size.limit /** diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index e42f397cbfc29..cac0107e2443b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -780,6 +780,31 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { ) } + test("map_sort function") { + val df1 = Seq( + Map[Int, Int](2 -> 2, 1 -> 1, 3 -> 3) + ).toDF("a") + + checkAnswer( + df1.selectExpr("map_sort(a)"), + Seq( + Row(Map(1 -> 1, 2 -> 2, 3 -> 3)) + ) + ) + checkAnswer( + df1.selectExpr("map_sort(a, true)"), + Seq( + Row(Map(1 -> 1, 2 -> 2, 3 -> 3)) + ) + ) + checkAnswer( + df1.select(map_sort($"a", asc = false)), + Seq( + Row(Map(3 -> 3, 2 -> 2, 1 -> 1)) + ) + ) + } + test("sort_array/array_sort functions") { val df = Seq( (Array[Int](2, 1, 3), Array("b", "c", "a")), From 7754c14c4deb54acba8e75f76371d5f56f8795f7 Mon Sep 17 00:00:00 2001 From: Stefan Kandic Date: Thu, 29 Feb 2024 09:48:23 +0100 Subject: [PATCH 08/46] fix ci errors --- R/pkg/tests/fulltests/test_sparkSQL.R | 2 +- .../query-tests/explain-results/function_map_sort.explain | 2 ++ python/pyspark/sql/functions/builtin.py | 7 ++++--- 3 files changed, 7 insertions(+), 4 deletions(-) create mode 100644 connector/connect/common/src/test/resources/query-tests/explain-results/function_map_sort.explain diff --git a/R/pkg/tests/fulltests/test_sparkSQL.R b/R/pkg/tests/fulltests/test_sparkSQL.R index 652b81d7b7532..fa87106c1f144 100644 --- a/R/pkg/tests/fulltests/test_sparkSQL.R +++ b/R/pkg/tests/fulltests/test_sparkSQL.R @@ -1648,7 +1648,7 @@ test_that("column functions", { # Test map_sort df <- createDataFrame(list(map1 = as.environment(list(c = 3, a = 1, b = 2)))) - result <- collect(select(df, map_concat(df[[1]])))[[1]] + result <- 
collect(select(df, map_sort(df[[1]])))[[1]] expected_entries <- list(as.environment(list(a = 1, b = 2, c = 3))) expect_equal(result, expected_entries) diff --git a/connector/connect/common/src/test/resources/query-tests/explain-results/function_map_sort.explain b/connector/connect/common/src/test/resources/query-tests/explain-results/function_map_sort.explain new file mode 100644 index 0000000000000..069b2ce65d187 --- /dev/null +++ b/connector/connect/common/src/test/resources/query-tests/explain-results/function_map_sort.explain @@ -0,0 +1,2 @@ +Project [map_sort(f#0, true) AS map_sort(f, true)#0] ++- LocalRelation , [id#0L, a#0, b#0, d#0, e#0, f#0, g#0] diff --git a/python/pyspark/sql/functions/builtin.py b/python/pyspark/sql/functions/builtin.py index cc6cf3f0126be..61bfa4db79d44 100644 --- a/python/pyspark/sql/functions/builtin.py +++ b/python/pyspark/sql/functions/builtin.py @@ -16839,6 +16839,7 @@ def map_concat( cols = cols[0] # type: ignore[assignment] return _invoke_function_over_seq_of_columns("map_concat", cols) # type: ignore[arg-type] + @_try_remote_functions def map_sort(col: "ColumnOrName", asc: bool = True) -> Column: """ @@ -16870,7 +16871,7 @@ def map_sort(col: "ColumnOrName", asc: bool = True) -> Column: +------------------------+ | map_sort(data, true)| +------------------------+ - |{1 -> a, 2 -> b, 3 -> c}| + | {1 -> a, 2 -> b, ...| +------------------------+ Example 2: Sorting a map in descending order @@ -16879,9 +16880,9 @@ def map_sort(col: "ColumnOrName", asc: bool = True) -> Column: >>> df = spark.sql("SELECT map(3, 'c', 1, 'a', 2, 'b') as data") >>> df.select(sf.map_sort(df.data, False)).show() +------------------------+ - | map_sort(data, true)| + | map_sort(data, false)| +------------------------+ - |{3 -> c, 2 -> b, 1 -> a}| + | {3 -> c, 2 -> b, ...| +------------------------+ """ return _invoke_function("map_sort", _to_java_column(col), asc) From f0ebf5dc5a4d9d118babb6865e7d7871d2f44d0b Mon Sep 17 00:00:00 2001 From: Stefan Kandic Date: Thu, 29 Feb 2024 14:09:26 +0100 Subject: [PATCH 09/46] fix ci checks --- R/pkg/R/functions.R | 4 ++-- python/pyspark/sql/functions/builtin.py | 20 ++++++++++---------- 2 files changed, 12 insertions(+), 12 deletions(-) diff --git a/R/pkg/R/functions.R b/R/pkg/R/functions.R index 69ea77b87b2e0..143277eab1417 100644 --- a/R/pkg/R/functions.R +++ b/R/pkg/R/functions.R @@ -4524,7 +4524,7 @@ setMethod("map_zip_with", }) #' @details -#' \code{sort_array}: Sorts the input map in ascending or descending order according to +#' \code{map_sort}: Sorts the input map in ascending or descending order according to #' the natural ordering of the map keys. #' #' @rdname column_collection_functions @@ -4532,7 +4532,7 @@ setMethod("map_zip_with", #' TRUE, sorting is in ascending order. #' FALSE, sorting is in descending order. 
#' @aliases map_sort map_sort,Column-method -#' @note sort_array since 4.0.0 +#' @note map_sort since 4.0.0 setMethod("map_sort", signature(x = "Column"), function(x, asc = TRUE) { diff --git a/python/pyspark/sql/functions/builtin.py b/python/pyspark/sql/functions/builtin.py index 61bfa4db79d44..0167f7fd2be93 100644 --- a/python/pyspark/sql/functions/builtin.py +++ b/python/pyspark/sql/functions/builtin.py @@ -16868,22 +16868,22 @@ def map_sort(col: "ColumnOrName", asc: bool = True) -> Column: >>> import pyspark.sql.functions as sf >>> df = spark.sql("SELECT map(3, 'c', 1, 'a', 2, 'b') as data") >>> df.select(sf.map_sort(df.data)).show() - +------------------------+ - | map_sort(data, true)| - +------------------------+ - | {1 -> a, 2 -> b, ...| - +------------------------+ + +--------------------+ + |map_sort(data, true)| + +--------------------+ + |{1 -> a, 2 -> b, ...| + +--------------------+ Example 2: Sorting a map in descending order >>> import pyspark.sql.functions as sf >>> df = spark.sql("SELECT map(3, 'c', 1, 'a', 2, 'b') as data") >>> df.select(sf.map_sort(df.data, False)).show() - +------------------------+ - | map_sort(data, false)| - +------------------------+ - | {3 -> c, 2 -> b, ...| - +------------------------+ + +---------------------+ + |map_sort(data, false)| + +---------------------+ + | {3 -> c, 2 -> b, ...| + +---------------------+ """ return _invoke_function("map_sort", _to_java_column(col), asc) From 1f78167886a4cc2dee132732a0678d6d503967a0 Mon Sep 17 00:00:00 2001 From: Stevo Mitric Date: Tue, 12 Mar 2024 17:44:38 +0100 Subject: [PATCH 10/46] Optimized map-sort by switching to array sorting --- .../expressions/collectionOperations.scala | 50 ++++++++++--------- 1 file changed, 27 insertions(+), 23 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index fae74bb1580ac..b095fe483f25a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -947,7 +947,8 @@ case class MapSort(base: Expression, ascendingOrder: Expression) } override def nullSafeEval(array: Any, ascending: Any): Any = { - // put keys in a tree map and then read them back to build new k/v arrays + // put keys and their respective indices inside a tuple + // and sort them to extract new order k/v pairs val mapData = array.asInstanceOf[MapData] val numElements = mapData.numElements() @@ -960,17 +961,16 @@ case class MapSort(base: Expression, ascendingOrder: Expression) PhysicalDataType.ordering(keyType).reverse } - val treeMap = mutable.TreeMap.empty[Any, Int](ordering) - for (i <- 0 until numElements) { - treeMap.put(keys.get(i, keyType), i) - } + val sortedKeys = Array + .tabulate(numElements)(i => (keys.get(i, keyType).asInstanceOf[Any], i)) + .sortBy(_._1)(ordering) val newKeys = new Array[Any](numElements) val newValues = new Array[Any](numElements) - treeMap.zipWithIndex.foreach { case ((_, originalIndex), sortedIndex) => - newKeys(sortedIndex) = keys.get(originalIndex, keyType) - newValues(sortedIndex) = values.get(originalIndex, valueType) + sortedKeys.zipWithIndex.foreach { case (elem, index) => + newKeys(index) = keys.get(elem._2, keyType) + newValues(index) = values.get(elem._2, valueType) } new ArrayBasedMapData(new GenericArrayData(newKeys), new 
GenericArrayData(newValues)) @@ -989,19 +989,22 @@ case class MapSort(base: Expression, ascendingOrder: Expression) val numElements = ctx.freshName("numElements") val keys = ctx.freshName("keys") val values = ctx.freshName("values") - val treeMap = ctx.freshName("treeMap") + val sortArray = ctx.freshName("sortArray") val i = ctx.freshName("i") val o1 = ctx.freshName("o1") + val o1entry = ctx.freshName("o1entry") val o2 = ctx.freshName("o2") + val o2entry = ctx.freshName("o2entry") val c = ctx.freshName("c") val newKeys = ctx.freshName("newKeys") val newValues = ctx.freshName("newValues") - val mapEntry = ctx.freshName("mapEntry") val originalIndex = ctx.freshName("originalIndex") val boxedKeyType = CodeGenerator.boxedType(keyType) val javaKeyType = CodeGenerator.javaType(keyType) + val simpleEntryType = s"java.util.AbstractMap.SimpleEntry<$boxedKeyType, Integer>" + val comp = if (CodeGenerator.isPrimitiveType(keyType)) { val v1 = ctx.freshName("v1") val v2 = ctx.freshName("v2") @@ -1019,28 +1022,29 @@ case class MapSort(base: Expression, ascendingOrder: Expression) |ArrayData $keys = $base.keyArray(); |ArrayData $values = $base.valueArray(); | - |java.util.TreeMap<$boxedKeyType, Integer> $treeMap = new java.util.TreeMap<>( - | new java.util.Comparator() { - | @Override public int compare(Object $o1, Object $o2) { - | $comp; - | return $order ? $c : -$c; - | } - | } - |); + |Object[] $sortArray = new Object[$numElements]; | |for (int $i = 0; $i < $numElements; $i++) { - | $treeMap.put(${CodeGenerator.getValue(keys, keyType, i)}, $i); + | $sortArray[$i] = new $simpleEntryType( + | ${CodeGenerator.getValue(keys, keyType, i)}, $i); |} | + |java.util.Arrays.sort($sortArray, new java.util.Comparator() { + | @Override public int compare(Object $o1entry, Object $o2entry) { + | Object $o1 = (($simpleEntryType) $o1entry).getKey(); + | Object $o2 = (($simpleEntryType) $o2entry).getKey(); + | $comp; + | return $order ? 
$c : -$c; + | } + |}); + | |Object[] $newKeys = new Object[$numElements]; |Object[] $newValues = new Object[$numElements]; | - |int $i = 0; - |for (java.util.Map.Entry<$boxedKeyType, Integer> $mapEntry : $treeMap.entrySet()) { - | int $originalIndex = (Integer) $mapEntry.getValue(); + |for (int $i = 0; $i < $numElements; $i++) { + | int $originalIndex = (Integer) ((($simpleEntryType) $sortArray[$i]).getValue()); | $newKeys[$i] = ${CodeGenerator.getValue(keys, keyType, originalIndex)}; | $newValues[$i] = ${CodeGenerator.getValue(values, valueType, originalIndex)}; - | $i++; |} | |${ev.value} = new $arrayBasedMapData( From a5eb4807903f6dc8fbbc972b11359956044bb761 Mon Sep 17 00:00:00 2001 From: Stevo Mitric Date: Wed, 13 Mar 2024 10:17:36 +0100 Subject: [PATCH 11/46] Potential tests fix --- R/pkg/tests/fulltests/test_sparkSQL.R | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/R/pkg/tests/fulltests/test_sparkSQL.R b/R/pkg/tests/fulltests/test_sparkSQL.R index fa87106c1f144..ca3353f4eb899 100644 --- a/R/pkg/tests/fulltests/test_sparkSQL.R +++ b/R/pkg/tests/fulltests/test_sparkSQL.R @@ -1647,7 +1647,7 @@ test_that("column functions", { expect_equal(result, expected_entries) # Test map_sort - df <- createDataFrame(list(map1 = as.environment(list(c = 3, a = 1, b = 2)))) + df <- createDataFrame(list(List(map1 = as.environment(list(c = 3, a = 1, b = 2))))) result <- collect(select(df, map_sort(df[[1]])))[[1]] expected_entries <- list(as.environment(list(a = 1, b = 2, c = 3))) expect_equal(result, expected_entries) From 9497f998b3ebb14a7cfc910100524af640530305 Mon Sep 17 00:00:00 2001 From: Stevo Mitric Date: Wed, 13 Mar 2024 10:53:15 +0100 Subject: [PATCH 12/46] Potential tests fix 2 --- R/pkg/tests/fulltests/test_sparkSQL.R | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/R/pkg/tests/fulltests/test_sparkSQL.R b/R/pkg/tests/fulltests/test_sparkSQL.R index ca3353f4eb899..75fe342d4d487 100644 --- a/R/pkg/tests/fulltests/test_sparkSQL.R +++ b/R/pkg/tests/fulltests/test_sparkSQL.R @@ -1647,7 +1647,7 @@ test_that("column functions", { expect_equal(result, expected_entries) # Test map_sort - df <- createDataFrame(list(List(map1 = as.environment(list(c = 3, a = 1, b = 2))))) + df <- createDataFrame(list(list(map1 = as.environment(list(c = 3, a = 1, b = 2))))) result <- collect(select(df, map_sort(df[[1]])))[[1]] expected_entries <- list(as.environment(list(a = 1, b = 2, c = 3))) expect_equal(result, expected_entries) From 5e3822098ea68430238b71a0fe46bb7db189b363 Mon Sep 17 00:00:00 2001 From: Stevo Mitric Date: Thu, 14 Mar 2024 14:13:14 +0100 Subject: [PATCH 13/46] Allowed group by expression with Maps --- .../sql/catalyst/expressions/ExprUtils.scala | 9 ----- .../expressions/codegen/CodeGenerator.scala | 1 + .../optimizer/NormalizeFloatingNumbers.scala | 3 -- .../catalyst/plans/logical/LogicalPlan.scala | 1 - .../spark/sql/DataFrameAggregateSuite.scala | 36 +++++++++++++++++++ 5 files changed, 37 insertions(+), 13 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExprUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExprUtils.scala index 2bbe730d4cfb8..eaf10973e71de 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExprUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExprUtils.scala @@ -193,15 +193,6 @@ object ExprUtils extends QueryErrorsBase { messageParameters = Map("sqlExpr" -> expr.sql)) } - // Check if the data type 
of expr is orderable. - if (expr.dataType.existsRecursively(_.isInstanceOf[MapType])) { - expr.failAnalysis( - errorClass = "GROUP_EXPRESSION_TYPE_IS_NOT_ORDERABLE", - messageParameters = Map( - "sqlExpr" -> toSQLExpr(expr), - "dataType" -> toSQLType(expr.dataType))) - } - if (!expr.deterministic) { // This is just a sanity check, our analysis rule PullOutNondeterministic should // already pull out those nondeterministic expressions and evaluate them in diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index d922a960fcd80..f7162f94dcf5d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -657,6 +657,7 @@ class CodegenContext extends Logging { case dt: DataType if isPrimitiveType(dt) => s"($c1 > $c2 ? 1 : $c1 < $c2 ? -1 : 0)" case BinaryType => s"org.apache.spark.unsafe.types.ByteArray.compareBinary($c1, $c2)" case CalendarIntervalType => s"$c1.compareTo($c2)" + case map : MapType => "0" case NullType => "0" case array: ArrayType => val elementType = array.elementType diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/NormalizeFloatingNumbers.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/NormalizeFloatingNumbers.scala index f946fe76bde4d..73b80b3cbc8c0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/NormalizeFloatingNumbers.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/NormalizeFloatingNumbers.scala @@ -98,9 +98,6 @@ object NormalizeFloatingNumbers extends Rule[LogicalPlan] { case FloatType | DoubleType => true case StructType(fields) => fields.exists(f => needNormalize(f.dataType)) case ArrayType(et, _) => needNormalize(et) - // Currently MapType is not comparable and analyzer should fail earlier if this case happens. 
- case _: MapType => - throw SparkException.internalError("grouping/join/window partition keys cannot be map type.") case _ => false } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala index e1121d1f9026e..611df854df6f0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala @@ -417,7 +417,6 @@ object LogicalPlanIntegrity { .orElse(LogicalPlanIntegrity.validateExprIdUniqueness(currentPlan)) .orElse(LogicalPlanIntegrity.validateSchemaOutput(previousPlan, currentPlan)) .orElse(LogicalPlanIntegrity.validateNoDanglingReferences(currentPlan)) - .orElse(LogicalPlanIntegrity.validateGroupByTypes(currentPlan)) .orElse(LogicalPlanIntegrity.validateAggregateExpressions(currentPlan)) .map(err => s"${err}\nPrevious schema:${previousPlan.output.mkString(", ")}" + s"\nPrevious plan: ${previousPlan.treeString}") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala index ec589fa772419..3e7293d8508a8 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala @@ -2155,6 +2155,42 @@ class DataFrameAggregateSuite extends QueryTest ) } + test("Support GROUP BY for MapType") { + val numRows = 10 + val configurations = Seq( + // Seq.empty[(String, String)], // hash aggregate is used by default + Seq(SQLConf.CODEGEN_FACTORY_MODE.key -> "CODEGEN_ONLY", + // "spark.sql.TungstenAggregate.testFallbackStartsAt" -> "1, 10"), + // Seq("spark.sql.test.forceApplyObjectHashAggregate" -> "true"), + // Seq( + // "spark.sql.test.forceApplyObjectHashAggregate" -> "true", + // SQLConf.OBJECT_AGG_SORT_BASED_FALLBACK_THRESHOLD.key -> "1"), + "spark.sql.test.forceApplySortAggregate" -> "true") + ) + + // val dfSame = (0 until numRows) + // .map(_ => Tuple1(new MapType(IntegerType, IntegerType, false))) + // .toDF("c0") + + val tableName = "temp" + scala.util.Random.between(1, 10000) + + for (conf <- configurations) { + withSQLConf(conf: _*) { + sql(s"CREATE TABLE $tableName(id INT, arr MAP) USING PARQUET;"); + sql(s"INSERT INTO $tableName VALUES(1, MAP(1,1))") + sql(s"INSERT INTO $tableName VALUES(2, MAP(1,2))") + val res = sql(s"select count(*) from $tableName group by arr") + res.foreach { row => + println(row); + + } + // assert(createAggregate(dfSame).count() == 1) + } + } + + def createAggregate(df: DataFrame): DataFrame = df.groupBy("c0").agg(count("*")) + } + test("SPARK-46536 Support GROUP BY CalendarIntervalType") { val numRows = 50 val configurations = Seq( From 03a752d1368a0cce6e1e071cee849b2fc09d669e Mon Sep 17 00:00:00 2001 From: Stevo Mitric Date: Thu, 14 Mar 2024 16:58:35 +0100 Subject: [PATCH 14/46] replaced map data type with arrays in test --- .../spark/sql/DataFrameAggregateSuite.scala | 21 +++++++++++-------- 1 file changed, 12 insertions(+), 9 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala index 3e7293d8508a8..90c7f147bace3 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala +++ 
b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala @@ -2159,13 +2159,14 @@ class DataFrameAggregateSuite extends QueryTest val numRows = 10 val configurations = Seq( // Seq.empty[(String, String)], // hash aggregate is used by default - Seq(SQLConf.CODEGEN_FACTORY_MODE.key -> "CODEGEN_ONLY", + // Seq(SQLConf.CODEGEN_FACTORY_MODE.key -> "CODEGEN_ONLY", // "spark.sql.TungstenAggregate.testFallbackStartsAt" -> "1, 10"), // Seq("spark.sql.test.forceApplyObjectHashAggregate" -> "true"), // Seq( // "spark.sql.test.forceApplyObjectHashAggregate" -> "true", // SQLConf.OBJECT_AGG_SORT_BASED_FALLBACK_THRESHOLD.key -> "1"), - "spark.sql.test.forceApplySortAggregate" -> "true") + // "spark.sql.test.forceApplySortAggregate" -> "true", + "spark.sql.codegen.wholeStage" -> "false" ) // val dfSame = (0 until numRows) @@ -2174,19 +2175,21 @@ class DataFrameAggregateSuite extends QueryTest val tableName = "temp" + scala.util.Random.between(1, 10000) - for (conf <- configurations) { - withSQLConf(conf: _*) { - sql(s"CREATE TABLE $tableName(id INT, arr MAP) USING PARQUET;"); - sql(s"INSERT INTO $tableName VALUES(1, MAP(1,1))") - sql(s"INSERT INTO $tableName VALUES(2, MAP(1,2))") + // for (conf <- configurations) { + // withSQLConf(conf: _*) { + sql(s"CREATE TABLE $tableName(id INT, arr ARRAY) USING PARQUET;"); + sql(s"INSERT INTO $tableName VALUES(1, ARRAY(1,2))") + sql(s"INSERT INTO $tableName VALUES(2, ARRAY(1,2))") val res = sql(s"select count(*) from $tableName group by arr") res.foreach { row => println(row); } + + // assert(createAggregate(dfSame).count() == 1) - } - } + // } + // } def createAggregate(df: DataFrame): DataFrame = df.groupBy("c0").agg(count("*")) } From b80afed74c02638c1af13461cc2f4adfe37402f2 Mon Sep 17 00:00:00 2001 From: Stevo Mitric Date: Sun, 17 Mar 2024 10:58:52 +0100 Subject: [PATCH 15/46] Added codegen for map ordering --- .../expressions/codegen/CodeGenerator.scala | 63 ++++++++++++++- .../sql/catalyst/optimizer/Optimizer.scala | 23 ++++++ .../catalyst/plans/logical/LogicalPlan.scala | 17 ---- .../sql/catalyst/rules/RuleIdCollection.scala | 1 + .../spark/sql/DataFrameAggregateSuite.scala | 80 +++++++------------ 5 files changed, 117 insertions(+), 67 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index f7162f94dcf5d..9cbc1c8f3ebab 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -657,7 +657,6 @@ class CodegenContext extends Logging { case dt: DataType if isPrimitiveType(dt) => s"($c1 > $c2 ? 1 : $c1 < $c2 ? 
-1 : 0)" case BinaryType => s"org.apache.spark.unsafe.types.ByteArray.compareBinary($c1, $c2)" case CalendarIntervalType => s"$c1.compareTo($c2)" - case map : MapType => "0" case NullType => "0" case array: ArrayType => val elementType = array.elementType @@ -723,12 +722,74 @@ class CodegenContext extends Logging { } """ s"${addNewFunction(compareFunc, funcCode)}($c1, $c2)" + case map: MapType => + val compareFunc = freshName("compareMapData") + val funcCode = genCompMapData(map.keyType, map.valueType, compareFunc) + s"${addNewFunction(compareFunc, funcCode)}($c1, $c2)" case other if other.isInstanceOf[AtomicType] => s"$c1.compare($c2)" case udt: UserDefinedType[_] => genComp(udt.sqlType, c1, c2) case _ => throw QueryExecutionErrors.cannotGenerateCodeForIncomparableTypeError("compare", dataType) } + private def genCompMapData(keyType: DataType, + valueType: DataType, compareFunc : String): String = { + val keyArrayA = freshName("keyArrayA") + val keyArrayB = freshName("keyArrayB") + val valueArrayA = freshName("valueArrayA") + val valueArrayB = freshName("valueArrayB") + val minLength = freshName("minLength") + s""" + public int $compareFunc(MapData a, MapData b) { + int lengthA = a.numElements(); + int lengthB = b.numElements(); + ArrayData $keyArrayA = a.keyArray(); + ArrayData $valueArrayA = a.valueArray(); + ArrayData $keyArrayB = b.keyArray(); + ArrayData $valueArrayB = b.valueArray(); + int $minLength = (lengthA > lengthB) ? lengthB : lengthA; + for (int i = 0; i < $minLength; i++) { + ${genCompElementsAt(keyArrayA, keyArrayB, "i", keyType)} + ${genCompElementsAt(valueArrayA, valueArrayB, "i", valueType)} + } + + if (lengthA < lengthB) { + return -1; + } else if (lengthA > lengthB) { + return 1; + } + return 0; + } + """ + } + + private def genCompElementsAt(arrayA: String, arrayB: String, i: String, + elementType : DataType): String = { + val elementA = freshName("elementA") + val isNullA = freshName("isNullA") + val elementB = freshName("elementB") + val isNullB = freshName("isNullB") + val jt = javaType(elementType); + s""" + boolean $isNullA = $arrayA.isNullAt($i); + boolean $isNullB = $arrayB.isNullAt($i); + if ($isNullA && $isNullB) { + // Nothing + } else if ($isNullA) { + return -1; + } else if ($isNullB) { + return 1; + } else { + $jt $elementA = ${getValue(arrayA, elementType, i)}; + $jt $elementB = ${getValue(arrayB, elementType, i)}; + int comp = ${genComp(elementType, elementA, elementB)}; + if (comp != 0) { + return comp; + } + } + """ + } + /** * Generates code for greater of two expressions. 
 *
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index 46d3043df3eb9..1bd6a391f0abd 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -196,6 +196,7 @@ abstract class Optimizer(catalogManager: CatalogManager)
         ReplaceDeduplicateWithAggregate) ::
     Batch("Aggregate", fixedPoint,
       RemoveLiteralFromGroupExpressions,
+      InsertMapSortInGroupingExpressions,
       RemoveRepetitionFromGroupExpressions) :: Nil ++ operatorOptimizationBatch) :+
     Batch("Clean Up Temporary CTE Info", Once, CleanUpTempCTEInfo) :+
@@ -2469,3 +2470,25 @@ object RemoveRepetitionFromGroupExpressions extends Rule[LogicalPlan] {
     }
   }
 }
+
+/**
+ * Adds MapSort to group expressions containing map columns, as the key/value pairs need to be
+ * in the correct order before grouping:
+ *   SELECT COUNT(*) FROM TABLE GROUP BY map_column =>
+ *   SELECT COUNT(*) FROM TABLE GROUP BY map_sort(map_column)
+ */
+object InsertMapSortInGroupingExpressions extends Rule[LogicalPlan] {
+  override def apply(plan: LogicalPlan): LogicalPlan = plan.transformWithPruning(
+    _.containsPattern(AGGREGATE), ruleId) {
+    case a @ Aggregate(groupingExpr, x, b) =>
+      val newGrouping = groupingExpr.map { expr =>
+        (expr, expr.dataType) match {
+          case (_: MapSort, _) => expr
+          case (_, _: MapType) =>
+            MapSort(expr, Literal.TrueLiteral)
+          case _ => expr
+        }
+      }
+      a.copy(groupingExpressions = newGrouping)
+  }
+}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala
index 611df854df6f0..7b926a51d7aaa 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala
@@ -348,23 +348,6 @@ object LogicalPlanIntegrity {
     }.flatten
   }
 
-  /**
-   * Validate that the grouping key types in Aggregate plans are valid.
-   * Returns an error message if the check fails, or None if it succeeds.
-   */
-  def validateGroupByTypes(plan: LogicalPlan): Option[String] = {
-    plan.collectFirst {
-      case a @ Aggregate(groupingExprs, _, _) =>
-        val badExprs = groupingExprs.filter(_.dataType.isInstanceOf[MapType]).map(_.toString)
-        if (badExprs.nonEmpty) {
-          Some(s"Grouping expressions ${badExprs.mkString(", ")} cannot be of type Map " +
-            s"for plan:\n ${a.treeString}")
-        } else {
-          None
-        }
-    }.flatten
-  }
-
   /**
    * Validate that the aggregation expressions in Aggregate plans are valid.
    * Returns an error message if the check fails, or None if it succeeds.
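For context, a minimal sketch of the behavior the new rule is meant to guarantee (spark-shell style, not part of the patch; the column name and data are illustrative and `import spark.implicits._` is assumed):

// Two rows carrying equal maps whose entries were built in different orders.
val df = Seq(Map(1 -> "a", 2 -> "b"), Map(2 -> "b", 1 -> "a")).toDF("m")

// Logically, the rule rewrites
//   df.groupBy($"m").count()
// into the equivalent of
//   df.groupBy(map_sort($"m")).count()
// so entry order can no longer split equal maps into separate groups, and the
// aggregate returns a single row with count 2.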
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleIdCollection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleIdCollection.scala index 08f728da2e9dd..778d56788e89e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleIdCollection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleIdCollection.scala @@ -126,6 +126,7 @@ object RuleIdCollection { "org.apache.spark.sql.catalyst.optimizer.EliminateSerialization" :: "org.apache.spark.sql.catalyst.optimizer.EliminateWindowPartitions" :: "org.apache.spark.sql.catalyst.optimizer.InferWindowGroupLimit" :: + "org.apache.spark.sql.catalyst.optimizer.InsertMapSortInGroupingExpressions" :: "org.apache.spark.sql.catalyst.optimizer.LikeSimplification" :: "org.apache.spark.sql.catalyst.optimizer.LimitPushDown" :: "org.apache.spark.sql.catalyst.optimizer.LimitPushDownThroughWindow" :: diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala index 90c7f147bace3..d14e12308b961 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala @@ -2155,47 +2155,8 @@ class DataFrameAggregateSuite extends QueryTest ) } - test("Support GROUP BY for MapType") { - val numRows = 10 - val configurations = Seq( - // Seq.empty[(String, String)], // hash aggregate is used by default - // Seq(SQLConf.CODEGEN_FACTORY_MODE.key -> "CODEGEN_ONLY", - // "spark.sql.TungstenAggregate.testFallbackStartsAt" -> "1, 10"), - // Seq("spark.sql.test.forceApplyObjectHashAggregate" -> "true"), - // Seq( - // "spark.sql.test.forceApplyObjectHashAggregate" -> "true", - // SQLConf.OBJECT_AGG_SORT_BASED_FALLBACK_THRESHOLD.key -> "1"), - // "spark.sql.test.forceApplySortAggregate" -> "true", - "spark.sql.codegen.wholeStage" -> "false" - ) - - // val dfSame = (0 until numRows) - // .map(_ => Tuple1(new MapType(IntegerType, IntegerType, false))) - // .toDF("c0") - - val tableName = "temp" + scala.util.Random.between(1, 10000) - - // for (conf <- configurations) { - // withSQLConf(conf: _*) { - sql(s"CREATE TABLE $tableName(id INT, arr ARRAY) USING PARQUET;"); - sql(s"INSERT INTO $tableName VALUES(1, ARRAY(1,2))") - sql(s"INSERT INTO $tableName VALUES(2, ARRAY(1,2))") - val res = sql(s"select count(*) from $tableName group by arr") - res.foreach { row => - println(row); - - } - - - // assert(createAggregate(dfSame).count() == 1) - // } - // } - - def createAggregate(df: DataFrame): DataFrame = df.groupBy("c0").agg(count("*")) - } - - test("SPARK-46536 Support GROUP BY CalendarIntervalType") { - val numRows = 50 + private def assertAggregateOnDataframe(dfSeq: Seq[DataFrame], + expected: Seq[Int], aggregateColumn: String): Unit = { val configurations = Seq( Seq.empty[(String, String)], // hash aggregate is used by default Seq(SQLConf.CODEGEN_FACTORY_MODE.key -> "NO_CODEGEN", @@ -2207,6 +2168,34 @@ class DataFrameAggregateSuite extends QueryTest Seq("spark.sql.test.forceApplySortAggregate" -> "true") ) + for ((df, index) <- dfSeq.zipWithIndex) { + for (conf <- configurations) { + withSQLConf(conf: _*) { + assert(createAggregate(df).count() == expected(index)) + } + } + } + + def createAggregate(df: DataFrame): DataFrame = df.groupBy(aggregateColumn).agg(count("*")) + } + + test("SPARK-47430 Support GROUP BY MapType") { + val numRows = 50 + + val dfSame = (0 until numRows) + 
.map(_ => Tuple1(Map(1 -> 1)))
+      .toDF("m0")
+
+    val dfDifferent = (0 until numRows)
+      .map(i => Tuple1(Map(i -> i)))
+      .toDF("m0")
+
+    assertAggregateOnDataframe(Seq(dfSame, dfDifferent), Seq(1, numRows), "m0")
+  }
+
+  test("SPARK-46536 Support GROUP BY CalendarIntervalType") {
+    val numRows = 50
+
     val dfSame = (0 until numRows)
       .map(_ => Tuple1(new CalendarInterval(1, 2, 3)))
       .toDF("c0")
@@ -2215,14 +2204,7 @@
       .map(i => Tuple1(new CalendarInterval(i, i, i)))
       .toDF("c0")
 
-    for (conf <- configurations) {
-      withSQLConf(conf: _*) {
-        assert(createAggregate(dfSame).count() == 1)
-        assert(createAggregate(dfDifferent).count() == numRows)
-      }
-    }
-
-    def createAggregate(df: DataFrame): DataFrame = df.groupBy("c0").agg(count("*"))
+    assertAggregateOnDataframe(Seq(dfSame, dfDifferent), Seq(1, numRows), "c0")
   }
 
   test("SPARK-46779: Group by subquery with a cached relation") {

From 5e7a033c0d02d4609f728f35bd6f2cc7245c82fe Mon Sep 17 00:00:00 2001
From: Stevo Mitric
Date: Sun, 17 Mar 2024 12:25:56 +0100
Subject: [PATCH 16/46] Removed TODOs and changed paramIndex to ordinal

---
 .../sql/catalyst/expressions/collectionOperations.scala | 6 +++---
 .../src/main/scala/org/apache/spark/sql/functions.scala | 1 -
 2 files changed, 3 insertions(+), 4 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
index b095fe483f25a..f6e05ba624a67 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
@@ -921,7 +921,7 @@ case class MapSort(base: Expression, ascendingOrder: Expression)
         DataTypeMismatch(
           errorSubClass = "UNEXPECTED_INPUT_TYPE",
           messageParameters = Map(
-            "paramIndex" -> "2",
+            "paramIndex" -> ordinalNumber(1),
             "requiredType" -> toSQLType(BooleanType),
             "inputSql" -> toSQLExpr(ascendingOrder),
             "inputType" -> toSQLType(ascendingOrder.dataType))
@@ -939,8 +939,8 @@
       DataTypeMismatch(
         errorSubClass = "UNEXPECTED_INPUT_TYPE",
         messageParameters = Map(
-          "paramIndex" -> "1",
-          "requiredType" -> toSQLType(ArrayType),
+          "paramIndex" -> ordinalNumber(0),
+          "requiredType" -> toSQLType(MapType),
           "inputSql" -> toSQLExpr(base),
           "inputType" -> toSQLType(base.dataType))
       )
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
index 2bc6db58333e3..dc987e1083dc8 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
@@ -6993,7 +6993,6 @@ object functions {
    * @since 4.0.0
    */
   def map_sort(e: Column): Column = map_sort(e, asc = true)
-  // TODO: add test for this
 
   /**
    * Sorts the input map in ascending or descending order according to the natural ordering

From ab70f1e44e1672f08d5ae42cd5cd8c33c5ea1f7f Mon Sep 17 00:00:00 2001
From: Stevo Mitric
Date: Mon, 18 Mar 2024 10:36:05 +0100
Subject: [PATCH 17/46] Shortened map sort function and added more docs

---
 .../expressions/collectionOperations.scala | 27 ++++++++++---------
 1 file changed, 15 insertions(+), 12 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index f6e05ba624a67..f02f444ec7f7d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -891,7 +891,16 @@ case class MapFromEntries(child: Expression) @ExpressionDescription( usage = """ _FUNC_(map[, ascendingOrder]) - Sorts the input map in ascending or descending order - according to the natural ordering of the map keys. + according to the natural ordering of the map keys. The sorting algorithm used is + an adaptive, stable and iterative merge sort algorithm. If the input map is empty, + function returns an empty map. + """, + arguments = + """ + Arguments: + * map - an expression. The map that will be sorted. + * ascendingOrder - an expression. The ordering in which the map will be sorted. + This can be either ascending or descending element order. """, examples = """ Examples: @@ -961,19 +970,13 @@ case class MapSort(base: Expression, ascendingOrder: Expression) PhysicalDataType.ordering(keyType).reverse } - val sortedKeys = Array - .tabulate(numElements)(i => (keys.get(i, keyType).asInstanceOf[Any], i)) + val sortedMap = Array + .tabulate(numElements)(i => (keys.get(i, keyType).asInstanceOf[Any], + values.get(i, valueType).asInstanceOf[Any])) .sortBy(_._1)(ordering) - val newKeys = new Array[Any](numElements) - val newValues = new Array[Any](numElements) - - sortedKeys.zipWithIndex.foreach { case (elem, index) => - newKeys(index) = keys.get(elem._2, keyType) - newValues(index) = values.get(elem._2, valueType) - } - - new ArrayBasedMapData(new GenericArrayData(newKeys), new GenericArrayData(newValues)) + new ArrayBasedMapData(new GenericArrayData(sortedMap.map(_._1)), + new GenericArrayData(sortedMap.map(_._2))) } override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { From e79d65cbba4087d166cc1ec859f7ba6dc01e9aef Mon Sep 17 00:00:00 2001 From: Stevo Mitric Date: Mon, 18 Mar 2024 13:38:03 +0100 Subject: [PATCH 18/46] updated map_sort test suite --- .../spark/sql/DataFrameFunctionsSuite.scala | 66 ++++++++++++++++++- 1 file changed, 65 insertions(+), 1 deletion(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index cac0107e2443b..6034bd5cc9cb2 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -25,7 +25,7 @@ import java.sql.{Date, Timestamp} import scala.util.Random import org.apache.spark.{SPARK_DOC_ROOT, SparkException, SparkRuntimeException} -import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.{ExtendedAnalysisException, InternalRow} import org.apache.spark.sql.catalyst.analysis.{FunctionRegistry, UnresolvedAttribute} import org.apache.spark.sql.catalyst.expressions.{Alias, ArraysZip, AttributeReference, Expression, NamedExpression, UnaryExpression} import org.apache.spark.sql.catalyst.expressions.Cast._ @@ -803,6 +803,70 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { Row(Map(3 -> 3, 2 -> 2, 1 -> 1)) ) ) + + val df2 = Seq(Map.empty[Int, Int]).toDF("a") + + checkAnswer( + df2.selectExpr("map_sort(a, true)"), + Seq(Row(Map())) + ) + + checkError( + exception = intercept[AnalysisException] { + df2.orderBy("a") + }, 
+ errorClass = "DATATYPE_MISMATCH.INVALID_ORDERING_TYPE", + parameters = Map( + "functionName" -> "`sortorder`", + "dataType" -> "\"MAP\"", + "sqlExpr" -> "\"a ASC NULLS FIRST\"") + ) + + checkError( + exception = intercept[SparkRuntimeException] { + sql("SELECT map_sort(map(null, 1))").collect() + }, + errorClass = "NULL_MAP_KEY" + ) + + checkError( + exception = intercept[SparkRuntimeException] { + sql("SELECT map_sort(map(1, 1, 2, 2, 1, 1))").collect() + }, + errorClass = "DUPLICATED_MAP_KEY", + parameters = Map( + "key" -> "1", + "mapKeyDedupPolicy" -> "\"spark.sql.mapKeyDedupPolicy\"" + ) + ) + + checkError( + exception = intercept[ExtendedAnalysisException] { + sql("SELECT map_sort(map(1,1,2,2), \"asc\")").collect() + }, + errorClass = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", + parameters = Map( + "sqlExpr" -> "\"map_sort(map(1, 1, 2, 2), asc)\"", + "paramIndex" -> "second", "inputSql" -> "\"asc\"", + "inputType" -> "\"STRING\"", + "requiredType" -> "\"BOOLEAN\"" + ), + queryContext = Array(ExpectedContext("", "", 7, 35, "map_sort(map(1,1,2,2), \"asc\")")) + ) + + checkError( + exception = intercept[ExtendedAnalysisException] { + sql("SELECT map_sort(map(1,1,2,2), \"asc\")").collect() + }, + errorClass = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", + parameters = Map( + "sqlExpr" -> "\"map_sort(map(1, 1, 2, 2), asc)\"", + "paramIndex" -> "second", "inputSql" -> "\"asc\"", + "inputType" -> "\"STRING\"", + "requiredType" -> "\"BOOLEAN\"" + ), + queryContext = Array(ExpectedContext("", "", 7, 35, "map_sort(map(1,1,2,2), \"asc\")")) + ) } test("sort_array/array_sort functions") { From 28d6f703086ff9296349894d194fe0f87ca30031 Mon Sep 17 00:00:00 2001 From: Stevo Mitric Date: Mon, 18 Mar 2024 15:08:53 +0100 Subject: [PATCH 19/46] Added map normalization and import cleanup --- .../optimizer/NormalizeFloatingNumbers.scala | 14 ++++++++++++++ .../sql/catalyst/plans/logical/LogicalPlan.scala | 2 +- 2 files changed, 15 insertions(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/NormalizeFloatingNumbers.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/NormalizeFloatingNumbers.scala index 73b80b3cbc8c0..40364d3fc4bd9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/NormalizeFloatingNumbers.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/NormalizeFloatingNumbers.scala @@ -98,6 +98,7 @@ object NormalizeFloatingNumbers extends Rule[LogicalPlan] { case FloatType | DoubleType => true case StructType(fields) => fields.exists(f => needNormalize(f.dataType)) case ArrayType(et, _) => needNormalize(et) + case MapType(kt, vt, _) => needNormalize(kt) && needNormalize(vt) case _ => false } @@ -141,6 +142,19 @@ object NormalizeFloatingNumbers extends Rule[LogicalPlan] { val function = normalize(lv) KnownFloatingPointNormalized(ArrayTransform(expr, LambdaFunction(function, Seq(lv)))) + case _ if expr.dataType.isInstanceOf[MapType] => + val MapType(kt, vt, containsNull) = expr.dataType + val lv1 = NamedLambdaVariable("arg", kt, containsNull) + val lv2 = NamedLambdaVariable("arg", vt, containsNull) + val functionL1 = normalize(lv1) + val functionL2 = normalize(lv2) + KnownFloatingPointNormalized( + ArrayTransform( + ArrayTransform(expr, LambdaFunction(functionL1, Seq(lv1))), + LambdaFunction(functionL2, Seq(lv2)), + ) + ) + case _ => throw SparkException.internalError(s"fail to normalize $expr") } diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala index 7b926a51d7aaa..b989233da6740 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala @@ -28,7 +28,7 @@ import org.apache.spark.sql.catalyst.trees.{BinaryLike, LeafLike, TreeNodeTag, U import org.apache.spark.sql.catalyst.types.DataTypeUtils import org.apache.spark.sql.catalyst.util.MetadataColumnHelper import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors} -import org.apache.spark.sql.types.{MapType, StructType} +import org.apache.spark.sql.types.StructType abstract class LogicalPlan From a43535539ba8b3f5a81aa6200ffc791964bd3f33 Mon Sep 17 00:00:00 2001 From: Stevo Mitric Date: Mon, 18 Mar 2024 16:08:19 +0100 Subject: [PATCH 20/46] Update sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala Co-authored-by: Maxim Gekk --- .../scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index 6034bd5cc9cb2..67ebe7a89e124 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -847,7 +847,8 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { errorClass = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", parameters = Map( "sqlExpr" -> "\"map_sort(map(1, 1, 2, 2), asc)\"", - "paramIndex" -> "second", "inputSql" -> "\"asc\"", + "paramIndex" -> "second", + "inputSql" -> "\"asc\"", "inputType" -> "\"STRING\"", "requiredType" -> "\"BOOLEAN\"" ), From c9901d08f83cc60961993fe3a64acebba168ea89 Mon Sep 17 00:00:00 2001 From: Stevo Mitric Date: Mon, 18 Mar 2024 16:08:35 +0100 Subject: [PATCH 21/46] Update sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala Co-authored-by: Maxim Gekk --- .../scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index 67ebe7a89e124..54e1ec53c25c6 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -862,7 +862,8 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { errorClass = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", parameters = Map( "sqlExpr" -> "\"map_sort(map(1, 1, 2, 2), asc)\"", - "paramIndex" -> "second", "inputSql" -> "\"asc\"", + "paramIndex" -> "second", + "inputSql" -> "\"asc\"", "inputType" -> "\"STRING\"", "requiredType" -> "\"BOOLEAN\"" ), From da6a710b7ddb56068562ddccfddc6329ec16f4d9 Mon Sep 17 00:00:00 2001 From: Stevo Mitric Date: Mon, 18 Mar 2024 16:27:48 +0100 Subject: [PATCH 22/46] docs fix --- .../catalyst/expressions/collectionOperations.scala | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
index f02f444ec7f7d..6d1e72b5970e3 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
@@ -891,16 +891,16 @@ case class MapFromEntries(child: Expression)
 @ExpressionDescription(
   usage = """
     _FUNC_(map[, ascendingOrder]) - Sorts the input map in ascending or descending order
-      according to the natural ordering of the map keys. The sorting algorithm used is
-      an adaptive, stable and iterative merge sort algorithm. If the input map is empty,
-      function returns an empty map.
+      according to the natural ordering of the map keys. The algorithm used for sorting is
+      an adaptive, stable and iterative algorithm. If the input map is empty, function
+      returns an empty map.
   """,
   arguments =
     """
     Arguments:
-      * map - an expression. The map that will be sorted.
-      * ascendingOrder - an expression. The ordering in which the map will be sorted.
-        This can be either ascending or descending element order.
+      * map - The map that will be sorted.
+      * ascendingOrder - A boolean value describing the order in which the map will be sorted.
+        This can be either ascending (true) or descending (false).
   """,
   examples = """
     Examples:

From 81008c218917e2b590111b132ed0b7008bedf305 Mon Sep 17 00:00:00 2001
From: Stevo Mitric
Date: Tue, 19 Mar 2024 09:26:16 +0100
Subject: [PATCH 23/46] Updated codegen and removed one test-case

---
 .../expressions/collectionOperations.scala | 15 ++++++++-------
 .../spark/sql/DataFrameFunctionsSuite.scala | 11 -----------
 2 files changed, 8 insertions(+), 18 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
index 6d1e72b5970e3..896620eec9ac8 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
@@ -956,8 +956,8 @@ case class MapSort(base: Expression, ascendingOrder: Expression)
   }
 
   override def nullSafeEval(array: Any, ascending: Any): Any = {
-    // put keys and their respective indices inside a tuple
-    // and sort them to extract new order k/v pairs
+    // put keys and their respective values inside a tuple and sort them
+    // according to the key ordering.
Extract the new sorted k/v pairs to form a sorted map val mapData = array.asInstanceOf[MapData] val numElements = mapData.numElements() @@ -1004,9 +1004,10 @@ case class MapSort(base: Expression, ascendingOrder: Expression) val originalIndex = ctx.freshName("originalIndex") val boxedKeyType = CodeGenerator.boxedType(keyType) + val boxedValueType = CodeGenerator.boxedType(valueType) val javaKeyType = CodeGenerator.javaType(keyType) - val simpleEntryType = s"java.util.AbstractMap.SimpleEntry<$boxedKeyType, Integer>" + val simpleEntryType = s"java.util.AbstractMap.SimpleEntry<$boxedKeyType, $boxedValueType>" val comp = if (CodeGenerator.isPrimitiveType(keyType)) { val v1 = ctx.freshName("v1") @@ -1029,7 +1030,8 @@ case class MapSort(base: Expression, ascendingOrder: Expression) | |for (int $i = 0; $i < $numElements; $i++) { | $sortArray[$i] = new $simpleEntryType( - | ${CodeGenerator.getValue(keys, keyType, i)}, $i); + | ${CodeGenerator.getValue(keys, keyType, i)}, + | ${CodeGenerator.getValue(values, valueType, i)}); |} | |java.util.Arrays.sort($sortArray, new java.util.Comparator() { @@ -1045,9 +1047,8 @@ case class MapSort(base: Expression, ascendingOrder: Expression) |Object[] $newValues = new Object[$numElements]; | |for (int $i = 0; $i < $numElements; $i++) { - | int $originalIndex = (Integer) ((($simpleEntryType) $sortArray[$i]).getValue()); - | $newKeys[$i] = ${CodeGenerator.getValue(keys, keyType, originalIndex)}; - | $newValues[$i] = ${CodeGenerator.getValue(values, valueType, originalIndex)}; + | $newKeys[$i] = (($simpleEntryType) $sortArray[$i]).getKey(); + | $newValues[$i] = (($simpleEntryType) $sortArray[$i]).getValue(); |} | |${ev.value} = new $arrayBasedMapData( diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index 54e1ec53c25c6..e5953e59a51b1 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -829,17 +829,6 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { errorClass = "NULL_MAP_KEY" ) - checkError( - exception = intercept[SparkRuntimeException] { - sql("SELECT map_sort(map(1, 1, 2, 2, 1, 1))").collect() - }, - errorClass = "DUPLICATED_MAP_KEY", - parameters = Map( - "key" -> "1", - "mapKeyDedupPolicy" -> "\"spark.sql.mapKeyDedupPolicy\"" - ) - ) - checkError( exception = intercept[ExtendedAnalysisException] { sql("SELECT map_sort(map(1,1,2,2), \"asc\")").collect() From 86b29c5ca6171529f18e611cda899f301c5aee0c Mon Sep 17 00:00:00 2001 From: Stevo Mitric Date: Tue, 19 Mar 2024 09:33:26 +0100 Subject: [PATCH 24/46] Update python/pyspark/sql/functions/builtin.py Co-authored-by: Ruifeng Zheng --- python/pyspark/sql/functions/builtin.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/pyspark/sql/functions/builtin.py b/python/pyspark/sql/functions/builtin.py index 0167f7fd2be93..0832f73785cd3 100644 --- a/python/pyspark/sql/functions/builtin.py +++ b/python/pyspark/sql/functions/builtin.py @@ -16867,7 +16867,7 @@ def map_sort(col: "ColumnOrName", asc: bool = True) -> Column: >>> import pyspark.sql.functions as sf >>> df = spark.sql("SELECT map(3, 'c', 1, 'a', 2, 'b') as data") - >>> df.select(sf.map_sort(df.data)).show() + >>> df.select(sf.map_sort(df.data)).show(truncate=False) +--------------------+ |map_sort(data, true)| +--------------------+ From c08ab6c027f3eaf4e6240d92c6d946909d0e570c Mon 
Sep 17 00:00:00 2001 From: Stevo Mitric Date: Tue, 19 Mar 2024 09:35:21 +0100 Subject: [PATCH 25/46] Updated 'select.show' to give more info in map_sort desc --- python/pyspark/sql/functions/builtin.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/pyspark/sql/functions/builtin.py b/python/pyspark/sql/functions/builtin.py index 0832f73785cd3..d206197996a94 100644 --- a/python/pyspark/sql/functions/builtin.py +++ b/python/pyspark/sql/functions/builtin.py @@ -16878,7 +16878,7 @@ def map_sort(col: "ColumnOrName", asc: bool = True) -> Column: >>> import pyspark.sql.functions as sf >>> df = spark.sql("SELECT map(3, 'c', 1, 'a', 2, 'b') as data") - >>> df.select(sf.map_sort(df.data, False)).show() + >>> df.select(sf.map_sort(df.data, False)).show(truncate=False) +---------------------+ |map_sort(data, false)| +---------------------+ From 31a797c34925fa651d0d911ce169335dcd75f4c2 Mon Sep 17 00:00:00 2001 From: Stevo Mitric Date: Tue, 19 Mar 2024 21:07:56 +0100 Subject: [PATCH 26/46] Restructured docs, removed unused variable and refactored code --- python/pyspark/sql/functions/builtin.py | 20 +++++++++---------- .../expressions/collectionOperations.scala | 5 ++--- 2 files changed, 12 insertions(+), 13 deletions(-) diff --git a/python/pyspark/sql/functions/builtin.py b/python/pyspark/sql/functions/builtin.py index d206197996a94..8710a7c6bb306 100644 --- a/python/pyspark/sql/functions/builtin.py +++ b/python/pyspark/sql/functions/builtin.py @@ -16868,22 +16868,22 @@ def map_sort(col: "ColumnOrName", asc: bool = True) -> Column: >>> import pyspark.sql.functions as sf >>> df = spark.sql("SELECT map(3, 'c', 1, 'a', 2, 'b') as data") >>> df.select(sf.map_sort(df.data)).show(truncate=False) - +--------------------+ - |map_sort(data, true)| - +--------------------+ - |{1 -> a, 2 -> b, ...| - +--------------------+ + +------------------------+ + |map_sort(data, true) | + +------------------------+ + |{1 -> a, 2 -> b, 3 -> c}| + +------------------------+ Example 2: Sorting a map in descending order >>> import pyspark.sql.functions as sf >>> df = spark.sql("SELECT map(3, 'c', 1, 'a', 2, 'b') as data") >>> df.select(sf.map_sort(df.data, False)).show(truncate=False) - +---------------------+ - |map_sort(data, false)| - +---------------------+ - | {3 -> c, 2 -> b, ...| - +---------------------+ + +------------------------+ + |map_sort(data, false) | + +------------------------+ + |{3 -> c, 2 -> b, 1 -> a}| + +------------------------+ """ return _invoke_function("map_sort", _to_java_column(col), asc) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index 896620eec9ac8..3ed711d477621 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -922,7 +922,7 @@ case class MapSort(base: Expression, ascendingOrder: Expression) override def dataType: DataType = base.dataType override def checkInputDataTypes(): TypeCheckResult = base.dataType match { - case MapType(kt, _, _) if RowOrdering.isOrderable(kt) => + case m: MapType if RowOrdering.isOrderable(m.keyType) => ascendingOrder match { case Literal(_: Boolean, BooleanType) => TypeCheckResult.TypeCheckSuccess @@ -936,7 +936,7 @@ case class MapSort(base: Expression, ascendingOrder: Expression) "inputType" -> 
toSQLType(ascendingOrder.dataType)) ) } - case MapType(_, _, _) => + case _: MapType => DataTypeMismatch( errorSubClass = "INVALID_ORDERING_TYPE", messageParameters = Map( @@ -1001,7 +1001,6 @@ case class MapSort(base: Expression, ascendingOrder: Expression) val c = ctx.freshName("c") val newKeys = ctx.freshName("newKeys") val newValues = ctx.freshName("newValues") - val originalIndex = ctx.freshName("originalIndex") val boxedKeyType = CodeGenerator.boxedType(keyType) val boxedValueType = CodeGenerator.boxedType(valueType) From 69e3b48f7a8a539a3d9a1c968d5ff2e33e4b367c Mon Sep 17 00:00:00 2001 From: Stevo Mitric Date: Thu, 21 Mar 2024 12:20:48 +0100 Subject: [PATCH 27/46] Removed map_sort function but left the MapSort expression --- R/pkg/NAMESPACE | 1 - R/pkg/R/functions.R | 17 ------- R/pkg/tests/fulltests/test_sparkSQL.R | 6 --- .../org/apache/spark/sql/functions.scala | 17 ------- .../spark/sql/PlanGenerationTestSuite.scala | 4 -- .../explain-results/function_map_sort.explain | 2 - .../queries/function_map_sort.json | 29 ----------- .../queries/function_map_sort.proto.bin | Bin 183 -> 0 bytes .../reference/pyspark.sql/functions.rst | 1 - .../pyspark/sql/connect/functions/builtin.py | 7 --- python/pyspark/sql/functions/builtin.py | 48 ------------------ python/pyspark/sql/tests/test_functions.py | 7 --- .../catalyst/analysis/FunctionRegistry.scala | 1 - .../org/apache/spark/sql/functions.scala | 17 ------- .../sql-functions/sql-expression-schema.md | 1 - 15 files changed, 158 deletions(-) delete mode 100644 connector/connect/common/src/test/resources/query-tests/explain-results/function_map_sort.explain delete mode 100644 connector/connect/common/src/test/resources/query-tests/queries/function_map_sort.json delete mode 100644 connector/connect/common/src/test/resources/query-tests/queries/function_map_sort.proto.bin diff --git a/R/pkg/NAMESPACE b/R/pkg/NAMESPACE index bdbcfa552448b..c5668d1739b17 100644 --- a/R/pkg/NAMESPACE +++ b/R/pkg/NAMESPACE @@ -361,7 +361,6 @@ exportMethods("%<=>%", "map_keys", "map_values", "map_zip_with", - "map_sort", "max", "max_by", "md5", diff --git a/R/pkg/R/functions.R b/R/pkg/R/functions.R index 143277eab1417..5106a83bd0ec4 100644 --- a/R/pkg/R/functions.R +++ b/R/pkg/R/functions.R @@ -4523,23 +4523,6 @@ setMethod("map_zip_with", ) }) -#' @details -#' \code{map_sort}: Sorts the input map in ascending or descending order according to -#' the natural ordering of the map keys. -#' -#' @rdname column_collection_functions -#' @param asc a logical flag indicating the sorting order. -#' TRUE, sorting is in ascending order. -#' FALSE, sorting is in descending order. -#' @aliases map_sort map_sort,Column-method -#' @note map_sort since 4.0.0 -setMethod("map_sort", - signature(x = "Column"), - function(x, asc = TRUE) { - jc <- callJStatic("org.apache.spark.sql.functions", "map_sort", x@jc, asc) - column(jc) - }) - #' @details #' \code{element_at}: Returns element of array at given index in \code{extraction} if #' \code{x} is array. Returns value for the given key in \code{extraction} if \code{x} is map. 
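Even with the user-facing API gone, the MapSort expression stays behind for the optimizer rule; a hedged sketch of exercising it directly, following the pattern of the Catalyst unit tests later in this series (at this point in the history MapSort still has the two-argument form with a single-argument auxiliary constructor):

// Build the expression by hand and evaluate it, bypassing the removed function.
val m = Literal.create(Map(2 -> "b", 1 -> "a"), MapType(IntegerType, StringType))
checkEvaluation(new MapSort(m), Map(1 -> "a", 2 -> "b"))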
diff --git a/R/pkg/tests/fulltests/test_sparkSQL.R b/R/pkg/tests/fulltests/test_sparkSQL.R index 75fe342d4d487..630781a57e444 100644 --- a/R/pkg/tests/fulltests/test_sparkSQL.R +++ b/R/pkg/tests/fulltests/test_sparkSQL.R @@ -1646,12 +1646,6 @@ test_that("column functions", { expected_entries <- list(as.environment(list(x = 1, y = 2, a = 3, b = 4))) expect_equal(result, expected_entries) - # Test map_sort - df <- createDataFrame(list(list(map1 = as.environment(list(c = 3, a = 1, b = 2))))) - result <- collect(select(df, map_sort(df[[1]])))[[1]] - expected_entries <- list(as.environment(list(a = 1, b = 2, c = 3))) - expect_equal(result, expected_entries) - # Test map_entries(), map_keys(), map_values() and element_at() df <- createDataFrame(list(list(map = as.environment(list(x = 1, y = 2))))) result <- collect(select(df, map_entries(df$map)))[[1]] diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/functions.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/functions.scala index 15d8f4253eb92..133b7e036cd7c 100644 --- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/functions.scala +++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/functions.scala @@ -7081,23 +7081,6 @@ object functions { */ def sort_array(e: Column, asc: Boolean): Column = Column.fn("sort_array", e, lit(asc)) - /** - * Sorts the input map in ascending order according to the natural ordering of the map keys. - * - * @group map_funcs - * @since 4.0.0 - */ - def map_sort(e: Column): Column = map_sort(e, asc = true) - - /** - * Sorts the input map in ascending or descending order according to the natural ordering of the - * map keys. - * - * @group map_funcs - * @since 4.0.0 - */ - def map_sort(e: Column, asc: Boolean): Column = Column.fn("map_sort", e, lit(asc)) - /** * Returns the minimum value in the array. NaN is greater than any non-NaN elements for * double/float type. NULL elements are skipped. 
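For end users, a sorted map remains reachable by composing functions this series does not remove; a hedged sketch (assuming the long-standing map_entries/array_sort/map_from_entries trio, where array_sort orders the entry structs by key via the natural struct ordering):

// Take the map apart, sort its entries by key, and rebuild it.
df.select(map_from_entries(array_sort(map_entries($"m"))))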
diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/PlanGenerationTestSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/PlanGenerationTestSuite.scala
index 6fbee02997275..ee98a1aceea38 100644
--- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/PlanGenerationTestSuite.scala
+++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/PlanGenerationTestSuite.scala
@@ -2525,10 +2525,6 @@ class PlanGenerationTestSuite
     fn.map_from_entries(fn.transform(fn.col("e"), (x, i) => fn.struct(i, x)))
   }
 
-  functionTest("map_sort") {
-    fn.map_sort(fn.col("f"))
-  }
-
   functionTest("arrays_zip") {
     fn.arrays_zip(fn.col("e"), fn.sequence(lit(1), lit(20)))
   }
diff --git a/connector/connect/common/src/test/resources/query-tests/explain-results/function_map_sort.explain b/connector/connect/common/src/test/resources/query-tests/explain-results/function_map_sort.explain
deleted file mode 100644
index 069b2ce65d187..0000000000000
--- a/connector/connect/common/src/test/resources/query-tests/explain-results/function_map_sort.explain
+++ /dev/null
@@ -1,2 +0,0 @@
-Project [map_sort(f#0, true) AS map_sort(f, true)#0]
-+- LocalRelation , [id#0L, a#0, b#0, d#0, e#0, f#0, g#0]
diff --git a/connector/connect/common/src/test/resources/query-tests/queries/function_map_sort.json b/connector/connect/common/src/test/resources/query-tests/queries/function_map_sort.json
deleted file mode 100644
index 81a9788d0fbae..0000000000000
--- a/connector/connect/common/src/test/resources/query-tests/queries/function_map_sort.json
+++ /dev/null
@@ -1,29 +0,0 @@
-{
-  "common": {
-    "planId": "1"
-  },
-  "project": {
-    "input": {
-      "common": {
-        "planId": "0"
-      },
-      "localRelation": {
-        "schema": "struct\u003cid:bigint,a:int,b:double,d:struct\u003cid:bigint,a:int,b:double\u003e,e:array\u003cint\u003e,f:map\u003cstring,struct\u003cid:bigint,a:int,b:double\u003e\u003e,g:string\u003e"
-      }
-    },
-    "expressions": [{
-      "unresolvedFunction": {
-        "functionName": "map_sort",
-        "arguments": [{
-          "unresolvedAttribute": {
-            "unparsedIdentifier": "f"
-          }
-        }, {
-          "literal": {
-            "boolean": true
-          }
-        }]
-      }
-    }]
-  }
-}
\ No newline at end of file
diff --git a/connector/connect/common/src/test/resources/query-tests/queries/function_map_sort.proto.bin b/connector/connect/common/src/test/resources/query-tests/queries/function_map_sort.proto.bin
deleted file mode 100644
index 57b823a5712988205bd9b2b37ce7f274fe5cdf62..0000000000000000000000000000000000000000
GIT binary patch
literal 0
HcmV?d00001

literal 183
zcmd;L5@3|tz{oX;k&8)yA*!2EsDrV%q^LBx#3nPvDk(EPGp|G^(F#N+S*7HcCgr5+
zq*xJ9VW*R7l~`1iSZM>)XQz{9m77>#1Jsk5m##xdtDR0d$atVqJ1I#iaV`#^-uUAD
Tq7oriA!aVdG$9r)CJ9CWE

[remainder of the binary payload, plus the hunks dropping the map_sort entry from python/docs/source/reference/pyspark.sql/functions.rst and the map_sort wrapper header from python/pyspark/sql/connect/functions/builtin.py, garbled in the source]

 map_values.__doc__ = pysparkfuncs.map_values.__doc__
 
-def map_sort(col: "ColumnOrName", asc: bool = True) -> Column:
-    return _invoke_function("map_sort", _to_col(col), lit(asc))
-
-map_sort.__doc__ = pysparkfuncs.map_sort.__doc__
-
 def map_zip_with(
     col1: "ColumnOrName",
     col2: "ColumnOrName",
diff --git a/python/pyspark/sql/functions/builtin.py b/python/pyspark/sql/functions/builtin.py
index 8710a7c6bb306..6320f9b922eef 100644
--- a/python/pyspark/sql/functions/builtin.py
+++ b/python/pyspark/sql/functions/builtin.py
@@ -16840,54 +16840,6 @@ def map_concat(
     return _invoke_function_over_seq_of_columns("map_concat", cols)  # type: ignore[arg-type]
 
 
-@_try_remote_functions
-def map_sort(col: "ColumnOrName", asc: bool = True) -> Column:
-    """
-    Map function: Sorts the input map in ascending or descending order
according - to the natural ordering of the map keys. - - .. versionadded:: 4.0.0 - - Parameters - ---------- - col : :class:`~pyspark.sql.Column` or str - Name of the column or expression. - asc : bool, optional - Whether to sort in ascending or descending order. If `asc` is True (default), - then the sorting is in ascending order. If False, then in descending order. - - Returns - ------- - :class:`~pyspark.sql.Column` - Sorted map. - - Examples - -------- - Example 1: Sorting a map in ascending order - - >>> import pyspark.sql.functions as sf - >>> df = spark.sql("SELECT map(3, 'c', 1, 'a', 2, 'b') as data") - >>> df.select(sf.map_sort(df.data)).show(truncate=False) - +------------------------+ - |map_sort(data, true) | - +------------------------+ - |{1 -> a, 2 -> b, 3 -> c}| - +------------------------+ - - Example 2: Sorting a map in descending order - - >>> import pyspark.sql.functions as sf - >>> df = spark.sql("SELECT map(3, 'c', 1, 'a', 2, 'b') as data") - >>> df.select(sf.map_sort(df.data, False)).show(truncate=False) - +------------------------+ - |map_sort(data, false) | - +------------------------+ - |{3 -> c, 2 -> b, 1 -> a}| - +------------------------+ - """ - return _invoke_function("map_sort", _to_java_column(col), asc) - - @_try_remote_functions def sequence( start: "ColumnOrName", stop: "ColumnOrName", step: Optional["ColumnOrName"] = None diff --git a/python/pyspark/sql/tests/test_functions.py b/python/pyspark/sql/tests/test_functions.py index 74e8a5f2a90e1..a736832c8ef99 100644 --- a/python/pyspark/sql/tests/test_functions.py +++ b/python/pyspark/sql/tests/test_functions.py @@ -1440,13 +1440,6 @@ def test_map_concat(self): {1: "a", 2: "b", 3: "c"}, ) - def test_map_sort(self): - df = self.spark.sql("SELECT map(3, 'c', 1, 'a', 2, 'b') as map1") - self.assertEqual( - df.select(F.map_sort("map1").alias("map2")).first()[0], - {1: "a", 2: "b", 3: "c"}, - ) - def test_version(self): self.assertIsInstance(self.spark.range(1).select(F.version()).first()[0], str) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index f64f88cfd9b65..b165d20d0b4fa 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -696,7 +696,6 @@ object FunctionRegistry { expression[MapEntries]("map_entries"), expression[MapFromEntries]("map_from_entries"), expression[MapConcat]("map_concat"), - expression[MapSort]("map_sort"), expression[Size]("size"), expression[Slice]("slice"), expression[Size]("cardinality", true), diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index dc987e1083dc8..933d0b3f89a7e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -6986,23 +6986,6 @@ object functions { @scala.annotation.varargs def map_concat(cols: Column*): Column = Column.fn("map_concat", cols: _*) - /** - * Sorts the input map in ascending order based on the natural order of map keys. - * - * @group map_funcs - * @since 4.0.0 - */ - def map_sort(e: Column): Column = map_sort(e, asc = true) - - /** - * Sorts the input map in ascending or descending order according to the natural ordering - * of the map keys. 
-   *
-   * @group map_funcs
-   * @since 4.0.0
-   */
-  def map_sort(e: Column, asc: Boolean): Column = Column.fn("map_sort", e, lit(asc))
-
   // scalastyle:off line.size.limit
   /**
    * Parses a column containing a CSV string into a `StructType` with the specified schema.
diff --git a/sql/core/src/test/resources/sql-functions/sql-expression-schema.md b/sql/core/src/test/resources/sql-functions/sql-expression-schema.md
index 999cd68738484..e20db3b49589c 100644
--- a/sql/core/src/test/resources/sql-functions/sql-expression-schema.md
+++ b/sql/core/src/test/resources/sql-functions/sql-expression-schema.md
@@ -215,7 +215,6 @@
 | org.apache.spark.sql.catalyst.expressions.MapFromArrays | map_from_arrays | SELECT map_from_arrays(array(1.0, 3.0), array('2', '4')) | struct> |
 | org.apache.spark.sql.catalyst.expressions.MapFromEntries | map_from_entries | SELECT map_from_entries(array(struct(1, 'a'), struct(2, 'b'))) | struct> |
 | org.apache.spark.sql.catalyst.expressions.MapKeys | map_keys | SELECT map_keys(map(1, 'a', 2, 'b')) | struct> |
-| org.apache.spark.sql.catalyst.expressions.MapSort | map_sort | SELECT map_sort(map(3, 'c', 1, 'a', 2, 'b'), true) | struct> |
 | org.apache.spark.sql.catalyst.expressions.MapValues | map_values | SELECT map_values(map(1, 'a', 2, 'b')) | struct> |
 | org.apache.spark.sql.catalyst.expressions.MapZipWith | map_zip_with | SELECT map_zip_with(map(1, 'a', 2, 'b'), map(1, 'x', 2, 'y'), (k, v1, v2) -> concat(v1, v2)) | struct> |
 | org.apache.spark.sql.catalyst.expressions.MaskExpressionBuilder | mask | SELECT mask('abcd-EFGH-8765-4321') | struct |

From 8d9ac51d95669efe6a4253e0071fb2ec665ae2d6 Mon Sep 17 00:00:00 2001
From: Stevo Mitric
Date: Thu, 21 Mar 2024 13:37:35 +0100
Subject: [PATCH 28/46] additional removals

---
 R/pkg/R/generics.R | 4 -
 .../spark/sql/DataFrameFunctionsSuite.scala | 82 +------------------
 2 files changed, 1 insertion(+), 85 deletions(-)

diff --git a/R/pkg/R/generics.R b/R/pkg/R/generics.R
index 2004530da88cb..26e81733055a6 100644
--- a/R/pkg/R/generics.R
+++ b/R/pkg/R/generics.R
@@ -1216,10 +1216,6 @@ setGeneric("map_values", function(x) { standardGeneric("map_values") })
 #' @name NULL
 setGeneric("map_zip_with", function(x, y, f) { standardGeneric("map_zip_with") })
 
-#' @rdname column_collection_functions
-#' @name NULL
-setGeneric("map_sort", function(x, asc = TRUE) { standardGeneric("map_sort") })
-
 #' @rdname column_aggregate_functions
 #' @name NULL
 setGeneric("max_by", function(x, y) { standardGeneric("max_by") })
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
index e5953e59a51b1..e42f397cbfc29 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
@@ -25,7 +25,7 @@ import java.sql.{Date, Timestamp}
 import scala.util.Random
 
 import org.apache.spark.{SPARK_DOC_ROOT, SparkException, SparkRuntimeException}
-import org.apache.spark.sql.catalyst.{ExtendedAnalysisException, InternalRow}
+import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.analysis.{FunctionRegistry, UnresolvedAttribute}
 import org.apache.spark.sql.catalyst.expressions.{Alias, ArraysZip, AttributeReference, Expression, NamedExpression, UnaryExpression}
 import org.apache.spark.sql.catalyst.expressions.Cast._
@@ -780,86 +780,6 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession {
     )
   }
 
-  test("map_sort function") {
-    val df1 = Seq(
-      Map[Int, Int](2 -> 2, 1 -> 1, 3 -> 3)
-    ).toDF("a")
-
-    checkAnswer(
-      df1.selectExpr("map_sort(a)"),
-      Seq(
-        Row(Map(1 -> 1, 2 -> 2, 3 -> 3))
-      )
-    )
-    checkAnswer(
-      df1.selectExpr("map_sort(a, true)"),
-      Seq(
-        Row(Map(1 -> 1, 2 -> 2, 3 -> 3))
-      )
-    )
-    checkAnswer(
-      df1.select(map_sort($"a", asc = false)),
-      Seq(
-        Row(Map(3 -> 3, 2 -> 2, 1 -> 1))
-      )
-    )
-
-    val df2 = Seq(Map.empty[Int, Int]).toDF("a")
-
-    checkAnswer(
-      df2.selectExpr("map_sort(a, true)"),
-      Seq(Row(Map()))
-    )
-
-    checkError(
-      exception = intercept[AnalysisException] {
-        df2.orderBy("a")
-      },
-      errorClass = "DATATYPE_MISMATCH.INVALID_ORDERING_TYPE",
-      parameters = Map(
-        "functionName" -> "`sortorder`",
-        "dataType" -> "\"MAP\"",
-        "sqlExpr" -> "\"a ASC NULLS FIRST\"")
-    )
-
-    checkError(
-      exception = intercept[SparkRuntimeException] {
-        sql("SELECT map_sort(map(null, 1))").collect()
-      },
-      errorClass = "NULL_MAP_KEY"
-    )
-
-    checkError(
-      exception = intercept[ExtendedAnalysisException] {
-        sql("SELECT map_sort(map(1,1,2,2), \"asc\")").collect()
-      },
-      errorClass = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE",
-      parameters = Map(
-        "sqlExpr" -> "\"map_sort(map(1, 1, 2, 2), asc)\"",
-        "paramIndex" -> "second",
-        "inputSql" -> "\"asc\"",
-        "inputType" -> "\"STRING\"",
-        "requiredType" -> "\"BOOLEAN\""
-      ),
-      queryContext = Array(ExpectedContext("", "", 7, 35, "map_sort(map(1,1,2,2), \"asc\")"))
-    )
-
-    checkError(
-      exception = intercept[ExtendedAnalysisException] {
-        sql("SELECT map_sort(map(1,1,2,2), \"asc\")").collect()
-      },
-      errorClass = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE",
-      parameters = Map(
-        "sqlExpr" -> "\"map_sort(map(1, 1, 2, 2), asc)\"",
-        "paramIndex" -> "second",
-        "inputSql" -> "\"asc\"",
-        "inputType" -> "\"STRING\"",
-        "requiredType" -> "\"BOOLEAN\""
-      ),
-      queryContext = Array(ExpectedContext("", "", 7, 35, "map_sort(map(1,1,2,2), \"asc\")"))
-    )
-  }
-
   test("sort_array/array_sort functions") {

From 2951bcc189ceee08526b3d119586aaa72028bb00 Mon Sep 17 00:00:00 2001
From: Stevo Mitric
Date: Thu, 21 Mar 2024 14:14:52 +0100
Subject: [PATCH 29/46] removed ExpressionDescription

---
 .../expressions/collectionOperations.scala | 21 -------------------
 1 file changed, 21 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
index 3ed711d477621..98ba1cad68309 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
@@ -888,27 +888,6 @@ case class MapFromEntries(child: Expression)
     copy(child = newChild)
 }
 
-@ExpressionDescription(
-  usage = """
-    _FUNC_(map[, ascendingOrder]) - Sorts the input map in ascending or descending order
-      according to the natural ordering of the map keys. The algorithm used for sorting is
-      an adaptive, stable and iterative algorithm. If the input map is empty, function
-      returns an empty map.
-  """,
-  arguments =
-    """
-    Arguments:
-      * map - The map that will be sorted.
-      * ascendingOrder - A boolean value describing the order in which the map will be sorted.
-        This can be either ascending (true) or descending (false).
- """, - examples = """ - Examples: - > SELECT _FUNC_(map(3, 'c', 1, 'a', 2, 'b'), true); - {1:"a",2:"b",3:"c"} - """, - group = "map_funcs", - since = "4.0.0") case class MapSort(base: Expression, ascendingOrder: Expression) extends BinaryExpression with NullIntolerant with QueryErrorsBase { From 0fc3c6a63b8257356fa4d782a62e158c5b6de914 Mon Sep 17 00:00:00 2001 From: Stevo Mitric Date: Thu, 21 Mar 2024 14:43:15 +0100 Subject: [PATCH 30/46] Moved ordering outside of comapre function --- .../sql/catalyst/expressions/collectionOperations.scala | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index 98ba1cad68309..7de5eb755ebec 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -980,6 +980,7 @@ case class MapSort(base: Expression, ascendingOrder: Expression) val c = ctx.freshName("c") val newKeys = ctx.freshName("newKeys") val newValues = ctx.freshName("newValues") + val sortOrder = ctx.freshName("sortOrder") val boxedKeyType = CodeGenerator.boxedType(keyType) val boxedValueType = CodeGenerator.boxedType(valueType) @@ -1011,13 +1012,13 @@ case class MapSort(base: Expression, ascendingOrder: Expression) | ${CodeGenerator.getValue(keys, keyType, i)}, | ${CodeGenerator.getValue(values, valueType, i)}); |} - | + |final int $sortOrder = $order ? 1 : -1; |java.util.Arrays.sort($sortArray, new java.util.Comparator() { | @Override public int compare(Object $o1entry, Object $o2entry) { | Object $o1 = (($simpleEntryType) $o1entry).getKey(); | Object $o2 = (($simpleEntryType) $o2entry).getKey(); | $comp; - | return $order ? 
-      |    return $order ? $c : -$c;
+      |    return $sortOrder * $c;
       |  }
       |});
       |

From 0c7d21a36e4e2eaee5ea4db67a59d69901c4d31e Mon Sep 17 00:00:00 2001
From: Stevo Mitric
Date: Thu, 21 Mar 2024 22:56:11 +0100
Subject: [PATCH 31/46] Removed ordering type

---
 .../expressions/collectionOperations.scala | 47 +++++--------------
 .../CollectionExpressionsSuite.scala | 29 +++---------
 2 files changed, 19 insertions(+), 57 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
index 7de5eb755ebec..27225b4ac74a8 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
@@ -888,33 +888,19 @@ case class MapFromEntries(child: Expression)
     copy(child = newChild)
 }
 
-case class MapSort(base: Expression, ascendingOrder: Expression)
-  extends BinaryExpression with NullIntolerant with QueryErrorsBase {
-
-  def this(e: Expression) = this(e, Literal(true))
+case class MapSort(base: Expression)
+  extends UnaryExpression with NullIntolerant with QueryErrorsBase {
 
   val keyType: DataType = base.dataType.asInstanceOf[MapType].keyType
   val valueType: DataType = base.dataType.asInstanceOf[MapType].valueType
 
-  override def left: Expression = base
-  override def right: Expression = ascendingOrder
+  override def child: Expression = base
+
   override def dataType: DataType = base.dataType
 
   override def checkInputDataTypes(): TypeCheckResult = base.dataType match {
     case m: MapType if RowOrdering.isOrderable(m.keyType) =>
-      ascendingOrder match {
-        case Literal(_: Boolean, BooleanType) =>
-          TypeCheckResult.TypeCheckSuccess
-        case _ =>
-          DataTypeMismatch(
-            errorSubClass = "UNEXPECTED_INPUT_TYPE",
-            messageParameters = Map(
-              "paramIndex" -> ordinalNumber(1),
-              "requiredType" -> toSQLType(BooleanType),
-              "inputSql" -> toSQLExpr(ascendingOrder),
-              "inputType" -> toSQLType(ascendingOrder.dataType))
-          )
-      }
+      TypeCheckResult.TypeCheckSuccess
     case _: MapType =>
       DataTypeMismatch(
         errorSubClass = "INVALID_ORDERING_TYPE",
         messageParameters = Map(
           "functionName" -> toSQLId(prettyName),
@@ -934,7 +920,7 @@
     )
   }
 
-  override def nullSafeEval(array: Any, ascending: Any): Any = {
+  override def nullSafeEval(array: Any): Any = {
     // put keys and their respective values inside a tuple and sort them
     // according to the key ordering.
Extract the new sorted k/v pairs to form a sorted map @@ -943,11 +929,7 @@ case class MapSort(base: Expression, ascendingOrder: Expression) val keys = mapData.keyArray() val values = mapData.valueArray() - val ordering = if (ascending.asInstanceOf[Boolean]) { - PhysicalDataType.ordering(keyType) - } else { - PhysicalDataType.ordering(keyType).reverse - } + val ordering = PhysicalDataType.ordering(keyType) val sortedMap = Array .tabulate(numElements)(i => (keys.get(i, keyType).asInstanceOf[Any], @@ -959,11 +941,11 @@ case class MapSort(base: Expression, ascendingOrder: Expression) } override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - nullSafeCodeGen(ctx, ev, (b, order) => sortCodegen(ctx, ev, b, order)) + nullSafeCodeGen(ctx, ev, b => sortCodegen(ctx, ev, b)) } private def sortCodegen(ctx: CodegenContext, ev: ExprCode, - base: String, order: String): String = { + base: String): String = { val arrayBasedMapData = classOf[ArrayBasedMapData].getName val genericArrayData = classOf[GenericArrayData].getName @@ -980,7 +962,6 @@ case class MapSort(base: Expression, ascendingOrder: Expression) val c = ctx.freshName("c") val newKeys = ctx.freshName("newKeys") val newValues = ctx.freshName("newValues") - val sortOrder = ctx.freshName("sortOrder") val boxedKeyType = CodeGenerator.boxedType(keyType) val boxedValueType = CodeGenerator.boxedType(valueType) @@ -1012,13 +993,13 @@ case class MapSort(base: Expression, ascendingOrder: Expression) | ${CodeGenerator.getValue(keys, keyType, i)}, | ${CodeGenerator.getValue(values, valueType, i)}); |} - |final int $sortOrder = $order ? 1 : -1; + | |java.util.Arrays.sort($sortArray, new java.util.Comparator() { | @Override public int compare(Object $o1entry, Object $o2entry) { | Object $o1 = (($simpleEntryType) $o1entry).getKey(); | Object $o2 = (($simpleEntryType) $o2entry).getKey(); | $comp; - | return $sortOrder * $c; + | return $c; | } |}); | @@ -1035,10 +1016,8 @@ case class MapSort(base: Expression, ascendingOrder: Expression) |""".stripMargin } - override def prettyName: String = "map_sort" - - override protected def withNewChildrenInternal(newLeft: Expression, newRight: Expression) - : MapSort = copy(base = newLeft, ascendingOrder = newRight) + override protected def withNewChildInternal(newChild: Expression) + : MapSort = copy(base = newChild) } /** diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala index 3063b83d4dca1..d14118eb3f1d2 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala @@ -434,31 +434,14 @@ class CollectionExpressionsSuite Map(create_row(2) -> 2, create_row(1) -> 1, create_row(3) -> 3), MapType(StructType(Seq(StructField("a", IntegerType))), IntegerType)) - checkEvaluation(new MapSort(intKey), Map(1 -> 1, 2 -> 2, 3 -> 3)) - checkEvaluation(MapSort(intKey, Literal.create(false, BooleanType)), - Map(3 -> 3, 2 -> 2, 1 -> 1)) - - checkEvaluation(new MapSort(boolKey), Map(false -> 1, true -> 2)) - checkEvaluation(MapSort(boolKey, Literal.create(false, BooleanType)), - Map(true -> 2, false -> 1)) - - checkEvaluation(new MapSort(stringKey), Map("1" -> 1, "2" -> 2, "3" -> 3)) - checkEvaluation(MapSort(stringKey, Literal.create(false, BooleanType)), - Map("3" -> 3, "2" -> 2, "1" -> 1)) - - 
checkEvaluation(new MapSort(arrayKey), Map(Seq(1) -> 1, Seq(2) -> 2, Seq(3) -> 3))
-    checkEvaluation(MapSort(arrayKey, Literal.create(false, BooleanType)),
-      Map(Seq(3) -> 3, Seq(2) -> 2, Seq(1) -> 1))
-
-    checkEvaluation(new MapSort(nestedArrayKey),
+    checkEvaluation(MapSort(intKey), Map(1 -> 1, 2 -> 2, 3 -> 3))
+    checkEvaluation(MapSort(boolKey), Map(false -> 1, true -> 2))
+    checkEvaluation(MapSort(stringKey), Map("1" -> 1, "2" -> 2, "3" -> 3))
+    checkEvaluation(MapSort(arrayKey), Map(Seq(1) -> 1, Seq(2) -> 2, Seq(3) -> 3))
+    checkEvaluation(MapSort(nestedArrayKey),
       Map(Seq(Seq(1)) -> 1, Seq(Seq(2)) -> 2, Seq(Seq(3)) -> 3))
-    checkEvaluation(MapSort(nestedArrayKey, Literal.create(false, BooleanType)),
-      Map(Seq(Seq(3)) -> 3, Seq(Seq(2)) -> 2, Seq(Seq(1)) -> 1))
-
-    checkEvaluation(new MapSort(structKey),
+    checkEvaluation(MapSort(structKey),
       Map(create_row(1) -> 1, create_row(2) -> 2, create_row(3) -> 3))
-    checkEvaluation(MapSort(structKey, Literal.create(false, BooleanType)),
-      Map(create_row(3) -> 3, create_row(2) -> 2, create_row(1) -> 1))
   }
 
   test("Sort Array") {

From 671dabe0b1467b5fde2d7680a4d10d07c8f1b9c9 Mon Sep 17 00:00:00 2001
From: Stevo Mitric
Date: Fri, 22 Mar 2024 09:47:32 +0100
Subject: [PATCH 32/46] Removed second parameter from MapSort expression invocation

---
 .../org/apache/spark/sql/catalyst/optimizer/Optimizer.scala | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index dd7919314e9bb..337e1d0fc7385 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -2485,7 +2485,7 @@ object InsertMapSortInGroupingExpressions extends Rule[LogicalPlan] {
         (expr, expr.dataType) match {
           case (_: MapSort, _) => expr
           case (_, _: MapType) =>
-            MapSort(expr, Literal.TrueLiteral)
+            MapSort(expr)
           case _ => expr
         }
       }

From 7cc928a700f06113557f9f3b7d056ce48bedb7a3 Mon Sep 17 00:00:00 2001
From: Stevo Mitric
Date: Fri, 22 Mar 2024 10:33:49 +0100
Subject: [PATCH 33/46] Fix scalastyle issues

---
 .../sql/catalyst/optimizer/NormalizeFloatingNumbers.scala | 4 +---
 1 file changed, 1 insertion(+), 3 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/NormalizeFloatingNumbers.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/NormalizeFloatingNumbers.scala
index 40364d3fc4bd9..fd6b5d56485b8 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/NormalizeFloatingNumbers.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/NormalizeFloatingNumbers.scala
@@ -151,9 +151,7 @@ object NormalizeFloatingNumbers extends Rule[LogicalPlan] {
       KnownFloatingPointNormalized(
         ArrayTransform(
           ArrayTransform(expr, LambdaFunction(functionL1, Seq(lv1))),
-          LambdaFunction(functionL2, Seq(lv2)),
-        )
-      )
+          LambdaFunction(functionL2, Seq(lv2))))
 
     case _ => throw SparkException.internalError(s"fail to normalize $expr")
   }

From d84b2b537d70153aa62acaeefc6d12d39e7ec5e2 Mon Sep 17 00:00:00 2001
From: Stevo Mitric
Date: Fri, 22 Mar 2024 15:21:37 +0100
Subject: [PATCH 34/46] Updated normalised functionality

---
 .../optimizer/NormalizeFloatingNumbers.scala  | 17 +++++++----------
 .../spark/sql/DataFrameAggregateSuite.scala   |  8 ++++++--
 2 files changed, 13 insertions(+), 12 deletions(-)

diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/NormalizeFloatingNumbers.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/NormalizeFloatingNumbers.scala index fd6b5d56485b8..2fcc689b9df2b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/NormalizeFloatingNumbers.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/NormalizeFloatingNumbers.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.catalyst.optimizer import org.apache.spark.SparkException -import org.apache.spark.sql.catalyst.expressions.{Alias, And, ArrayTransform, CaseWhen, Coalesce, CreateArray, CreateMap, CreateNamedStruct, EqualTo, ExpectsInputTypes, Expression, GetStructField, If, IsNull, KnownFloatingPointNormalized, LambdaFunction, Literal, NamedLambdaVariable, UnaryExpression} +import org.apache.spark.sql.catalyst.expressions.{Alias, And, ArrayTransform, CaseWhen, Coalesce, CreateArray, CreateMap, CreateNamedStruct, EqualTo, ExpectsInputTypes, Expression, GetStructField, If, IsNull, KnownFloatingPointNormalized, LambdaFunction, Literal, NamedLambdaVariable, TransformValues, UnaryExpression} import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode} import org.apache.spark.sql.catalyst.planning.ExtractEquiJoinKeys import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Window} @@ -98,7 +98,7 @@ object NormalizeFloatingNumbers extends Rule[LogicalPlan] { case FloatType | DoubleType => true case StructType(fields) => fields.exists(f => needNormalize(f.dataType)) case ArrayType(et, _) => needNormalize(et) - case MapType(kt, vt, _) => needNormalize(kt) && needNormalize(vt) + case MapType(_, vt, _) => needNormalize(vt) case _ => false } @@ -144,14 +144,11 @@ object NormalizeFloatingNumbers extends Rule[LogicalPlan] { case _ if expr.dataType.isInstanceOf[MapType] => val MapType(kt, vt, containsNull) = expr.dataType - val lv1 = NamedLambdaVariable("arg", kt, containsNull) - val lv2 = NamedLambdaVariable("arg", vt, containsNull) - val functionL1 = normalize(lv1) - val functionL2 = normalize(lv2) - KnownFloatingPointNormalized( - ArrayTransform( - ArrayTransform(expr, LambdaFunction(functionL1, Seq(lv1))), - LambdaFunction(functionL2, Seq(lv2)))) + val keys = NamedLambdaVariable("arg", kt, containsNull) + val values = NamedLambdaVariable("arg", vt, containsNull) + val function = normalize(values) + KnownFloatingPointNormalized(TransformValues(expr, + LambdaFunction(function, Seq(keys, values)))) case _ => throw SparkException.internalError(s"fail to normalize $expr") } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala index 294d3e79f2ab5..f5e30a2303e5e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala @@ -2182,15 +2182,19 @@ class DataFrameAggregateSuite extends QueryTest test("SPARK-47430 Support GROUP BY MapType") { val numRows = 50 - val dfSame = (0 until numRows) + val dfSameInt = (0 until numRows) .map(_ => Tuple1(Map(1 -> 1))) .toDF("m0") + val dfSameFloat = (0 until numRows) + .map(i => Tuple1(Map(if (i % 2 == 0) 1 -> 0.0 else 1 -> -0.0 ))) + .toDF("m0") + val dfDifferent = (0 until numRows) .map(i => Tuple1(Map(i -> i))) .toDF("m0") - assertAggregateOnDataframe(Seq(dfSame, dfDifferent), Seq(1, numRows), "m0") + assertAggregateOnDataframe(Seq(dfSameInt, 
dfSameFloat, dfDifferent), Seq(1, 1, numRows), "m0")
   }
 
   test("SPARK-46536 Support GROUP BY CalendarIntervalType") {

From 9f1035df915221be04c0c531f20f719ea0061848 Mon Sep 17 00:00:00 2001
From: Stevo Mitric
Date: Fri, 22 Mar 2024 15:26:34 +0100
Subject: [PATCH 35/46] Refactored aggregation rule in optimizer

---
 .../spark/sql/catalyst/optimizer/Optimizer.scala | 11 +++++------
 1 file changed, 5 insertions(+), 6 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index 337e1d0fc7385..41bf61a99f212 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -2480,13 +2480,12 @@ object InsertMapSortInGroupingExpressions extends Rule[LogicalPlan] {
   override def apply(plan: LogicalPlan): LogicalPlan = plan.transformWithPruning(
     _.containsPattern(AGGREGATE), ruleId) {
-    case a @ Aggregate(groupingExpr, x, b) =>
+    case a @ Aggregate(groupingExpr, _, _) =>
       val newGrouping = groupingExpr.map { expr =>
-        (expr, expr.dataType) match {
-          case (_: MapSort, _) => expr
-          case (_, _: MapType) =>
-            MapSort(expr)
-          case _ => expr
+        if (!expr.isInstanceOf[MapSort] && expr.dataType.isInstanceOf[MapType]) {
+          MapSort(expr)
+        } else {
+          expr
         }
       }
       a.copy(groupingExpressions = newGrouping)

From d184d4891c767e26dd25ed007a648a4381b44ba3 Mon Sep 17 00:00:00 2001
From: Stevo Mitric
Date: Mon, 25 Mar 2024 13:01:42 +0100
Subject: [PATCH 36/46] Removed GROUP_EXPRESSION_TYPE_IS_NOT_ORDERABLE error and fixed a test

---
 .../main/resources/error/error-classes.json  |  6 --
 .../analysis/AnalysisErrorSuite.scala         | 64 ++-----------------
 2 files changed, 7 insertions(+), 63 deletions(-)

diff --git a/common/utils/src/main/resources/error/error-classes.json b/common/utils/src/main/resources/error/error-classes.json
index 091f24d44f66c..ef3d113e914d6 100644
--- a/common/utils/src/main/resources/error/error-classes.json
+++ b/common/utils/src/main/resources/error/error-classes.json
@@ -1373,12 +1373,6 @@
     ],
     "sqlState" : "42805"
   },
-  "GROUP_EXPRESSION_TYPE_IS_NOT_ORDERABLE" : {
-    "message" : [
-      "The expression <sqlExpr> cannot be used as a grouping expression because its data type <dataType> is not an orderable data type."
-    ],
-    "sqlState" : "42822"
-  },
   "HLL_INVALID_INPUT_SKETCH_BUFFER" : {
     "message" : [
       "Invalid call to <function>; only valid HLL sketch buffers are supported as inputs (such as those produced by the `hll_sketch_agg` function)."
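With InsertMapSortInGroupingExpressions canonicalizing map-typed grouping keys, the GROUP_EXPRESSION_TYPE_IS_NOT_ORDERABLE error class above becomes unreachable, which is why this patch deletes it instead of rewording it. A minimal end-to-end sketch of the behaviour the series enables follows; it is illustrative only and not part of the patch, the object and column names are hypothetical, and the exact output formatting may differ:

import org.apache.spark.sql.SparkSession

object GroupByMapTypeExample {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").appName("group-by-map").getOrCreate()
    import spark.implicits._

    // Two rows whose maps carry identical entries in different insertion order.
    val df = Seq(Map(1 -> "a", 2 -> "b"), Map(2 -> "b", 1 -> "a")).toDF("m")

    // Once the optimizer wraps the key in MapSort, both rows share one
    // canonical representation, so a single group with count 2 is expected.
    df.groupBy($"m").count().show()

    spark.stop()
  }
}

Before this change, the same query would have been rejected during analysis with the error class removed above.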
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala index 8366e8a22d428..5b4c2c0bc8fe2 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala @@ -17,8 +17,6 @@ package org.apache.spark.sql.catalyst.analysis -import org.scalatest.Assertions._ - import org.apache.spark.SparkException import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.dsl.expressions._ @@ -28,7 +26,6 @@ import org.apache.spark.sql.catalyst.expressions.aggregate.{Count, Max} import org.apache.spark.sql.catalyst.parser.CatalystSqlParser import org.apache.spark.sql.catalyst.plans.{AsOfJoinDirection, Cross, Inner, LeftOuter, RightOuter} import org.apache.spark.sql.catalyst.plans.logical._ -import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, GenericArrayData, MapData} import org.apache.spark.sql.errors.DataTypeErrorsBase import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ @@ -59,32 +56,6 @@ private[sql] case class UngroupableData(data: Map[Int, Int]) { def getData: Map[Int, Int] = data } -private[sql] class UngroupableUDT extends UserDefinedType[UngroupableData] { - - override def sqlType: DataType = MapType(IntegerType, IntegerType) - - override def serialize(ungroupableData: UngroupableData): MapData = { - val keyArray = new GenericArrayData(ungroupableData.data.keys.toSeq) - val valueArray = new GenericArrayData(ungroupableData.data.values.toSeq) - new ArrayBasedMapData(keyArray, valueArray) - } - - override def deserialize(datum: Any): UngroupableData = { - datum match { - case data: MapData => - val keyArray = data.keyArray().array - val valueArray = data.valueArray().array - assert(keyArray.length == valueArray.length) - val mapData = keyArray.zip(valueArray).toMap.asInstanceOf[Map[Int, Int]] - UngroupableData(mapData) - } - } - - override def userClass: Class[UngroupableData] = classOf[UngroupableData] - - private[spark] override def asNullable: UngroupableUDT = this -} - case class TestFunction( children: Seq[Expression], inputTypes: Seq[AbstractDataType]) @@ -1005,8 +976,7 @@ class AnalysisErrorSuite extends AnalysisTest with DataTypeErrorsBase { } test("check grouping expression data types") { - def checkDataType( - dataType: DataType, shouldSuccess: Boolean, dataTypeMsg: String = ""): Unit = { + def checkDataType(dataType: DataType): Unit = { val plan = Aggregate( AttributeReference("a", dataType)(exprId = ExprId(2)) :: Nil, @@ -1015,18 +985,7 @@ class AnalysisErrorSuite extends AnalysisTest with DataTypeErrorsBase { AttributeReference("a", dataType)(exprId = ExprId(2)), AttributeReference("b", IntegerType)(exprId = ExprId(1)))) - if (shouldSuccess) { - assertAnalysisSuccess(plan, true) - } else { - assertAnalysisErrorClass( - inputPlan = plan, - expectedErrorClass = "GROUP_EXPRESSION_TYPE_IS_NOT_ORDERABLE", - expectedMessageParameters = Map( - "sqlExpr" -> "\"a\"", - "dataType" -> dataTypeMsg - ) - ) - } + assertAnalysisSuccess(plan, true) } val supportedDataTypes = Seq( @@ -1036,6 +995,10 @@ class AnalysisErrorSuite extends AnalysisTest with DataTypeErrorsBase { FloatType, DoubleType, DecimalType(25, 5), DecimalType(6, 5), DateType, TimestampType, ArrayType(IntegerType), + MapType(StringType, LongType), + new StructType() + .add("f1", FloatType, 
nullable = true)
+        .add("f2", MapType(StringType, LongType), nullable = true),
       new StructType()
         .add("f1", FloatType, nullable = true)
         .add("f2", StringType, nullable = true),
@@ -1044,20 +1007,7 @@ class AnalysisErrorSuite extends AnalysisTest with DataTypeErrorsBase {
         .add("f2", ArrayType(BooleanType, containsNull = true), nullable = true),
       new GroupableUDT())
     supportedDataTypes.foreach { dataType =>
-      checkDataType(dataType, shouldSuccess = true)
-    }
-
-    val unsupportedDataTypes = Seq(
-      MapType(StringType, LongType),
-      new StructType()
-        .add("f1", FloatType, nullable = true)
-        .add("f2", MapType(StringType, LongType), nullable = true),
-      new UngroupableUDT())
-    val expectedDataTypeParameters =
-      Seq("\"MAP<STRING, BIGINT>\"", "\"STRUCT<f1: FLOAT, f2: MAP<STRING, BIGINT>>\"")
-    unsupportedDataTypes.zip(expectedDataTypeParameters).foreach {
-      case (dataType, dataTypeMsg) =>
-        checkDataType(dataType, shouldSuccess = false, dataTypeMsg)
+      checkDataType(dataType)
     }
   }

From 04d68cc0f5e548fbd5e5764055d05e1409f132f7 Mon Sep 17 00:00:00 2001
From: Stevo Mitric
Date: Mon, 25 Mar 2024 13:06:07 +0100
Subject: [PATCH 37/46] Update sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala

Co-authored-by: Wenchen Fan

---
 .../sql/catalyst/expressions/codegen/CodeGenerator.scala | 6 ++++--
 1 file changed, 4 insertions(+), 2 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
index 1f18869b5b5d5..f9dbff54d6453 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
@@ -732,8 +732,10 @@ class CodegenContext extends Logging {
     throw QueryExecutionErrors.cannotGenerateCodeForIncomparableTypeError("compare", dataType)
   }
 
-  private def genCompMapData(keyType: DataType,
-    valueType: DataType, compareFunc : String): String = {
+  private def genCompMapData(
+      keyType: DataType,
+      valueType: DataType,
+      compareFunc : String): String = {
     val keyArrayA = freshName("keyArrayA")
     val keyArrayB = freshName("keyArrayB")
     val valueArrayA = freshName("valueArrayA")

From ebb3325532db263c8186d2a169b3762378bef778 Mon Sep 17 00:00:00 2001
From: Stevo Mitric
Date: Mon, 25 Mar 2024 13:13:36 +0100
Subject: [PATCH 38/46] added scala stripMargin indentation control

---
 .../expressions/codegen/CodeGenerator.scala   | 76 +++++++++----------
 1 file changed, 38 insertions(+), 38 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
index f9dbff54d6453..bf6a6c7f65a9c 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
@@ -742,27 +742,27 @@ class CodegenContext extends Logging {
     val valueArrayB = freshName("valueArrayB")
     val minLength = freshName("minLength")
     s"""
-      public int $compareFunc(MapData a, MapData b) {
-        int lengthA = a.numElements();
-        int lengthB = b.numElements();
-        ArrayData $keyArrayA = a.keyArray();
-        ArrayData $valueArrayA = a.valueArray();
-        ArrayData $keyArrayB = b.keyArray();
-        ArrayData $valueArrayB = b.valueArray();
-        int $minLength = (lengthA > lengthB) ? 
lengthB : lengthA; - for (int i = 0; i < $minLength; i++) { - ${genCompElementsAt(keyArrayA, keyArrayB, "i", keyType)} - ${genCompElementsAt(valueArrayA, valueArrayB, "i", valueType)} - } - - if (lengthA < lengthB) { - return -1; - } else if (lengthA > lengthB) { - return 1; - } - return 0; - } - """ + |public int $compareFunc(MapData a, MapData b) { + | int lengthA = a.numElements(); + | int lengthB = b.numElements(); + | ArrayData $keyArrayA = a.keyArray(); + | ArrayData $valueArrayA = a.valueArray(); + | ArrayData $keyArrayB = b.keyArray(); + | ArrayData $valueArrayB = b.valueArray(); + | int $minLength = (lengthA > lengthB) ? lengthB : lengthA; + | for (int i = 0; i < $minLength; i++) { + | ${genCompElementsAt(keyArrayA, keyArrayB, "i", keyType)} + | ${genCompElementsAt(valueArrayA, valueArrayB, "i", valueType)} + | } + | + | if (lengthA < lengthB) { + | return -1; + | } else if (lengthA > lengthB) { + | return 1; + | } + | return 0; + |} + """.stripMargin } private def genCompElementsAt(arrayA: String, arrayB: String, i: String, @@ -773,23 +773,23 @@ class CodegenContext extends Logging { val isNullB = freshName("isNullB") val jt = javaType(elementType); s""" - boolean $isNullA = $arrayA.isNullAt($i); - boolean $isNullB = $arrayB.isNullAt($i); - if ($isNullA && $isNullB) { - // Nothing - } else if ($isNullA) { - return -1; - } else if ($isNullB) { - return 1; - } else { - $jt $elementA = ${getValue(arrayA, elementType, i)}; - $jt $elementB = ${getValue(arrayB, elementType, i)}; - int comp = ${genComp(elementType, elementA, elementB)}; - if (comp != 0) { - return comp; - } - } - """ + |boolean $isNullA = $arrayA.isNullAt($i); + |boolean $isNullB = $arrayB.isNullAt($i); + |if ($isNullA && $isNullB) { + | // Nothing + |} else if ($isNullA) { + | return -1; + |} else if ($isNullB) { + | return 1; + |} else { + | $jt $elementA = ${getValue(arrayA, elementType, i)}; + | $jt $elementB = ${getValue(arrayB, elementType, i)}; + | int comp = ${genComp(elementType, elementA, elementB)}; + | if (comp != 0) { + | return comp; + | } + |} + """.stripMargin } /** From 185f7f1f5b1a24f642bab79f9f042a5810217295 Mon Sep 17 00:00:00 2001 From: Stevo Mitric Date: Mon, 25 Mar 2024 13:21:26 +0100 Subject: [PATCH 39/46] Replaced comparison in array with genCompElementsAt --- .../expressions/codegen/CodeGenerator.scala | 22 +------------------ 1 file changed, 1 insertion(+), 21 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index bf6a6c7f65a9c..1d8e7d3460b5c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -660,13 +660,8 @@ class CodegenContext extends Logging { case NullType => "0" case array: ArrayType => val elementType = array.elementType - val elementA = freshName("elementA") - val isNullA = freshName("isNullA") - val elementB = freshName("elementB") - val isNullB = freshName("isNullB") val compareFunc = freshName("compareArray") val minLength = freshName("minLength") - val jt = javaType(elementType) val funcCode: String = s""" public int $compareFunc(ArrayData a, ArrayData b) { @@ -679,22 +674,7 @@ class CodegenContext extends Logging { int lengthB = b.numElements(); int $minLength = (lengthA > lengthB) ? 
lengthB : lengthA; for (int i = 0; i < $minLength; i++) { - boolean $isNullA = a.isNullAt(i); - boolean $isNullB = b.isNullAt(i); - if ($isNullA && $isNullB) { - // Nothing - } else if ($isNullA) { - return -1; - } else if ($isNullB) { - return 1; - } else { - $jt $elementA = ${getValue("a", elementType, "i")}; - $jt $elementB = ${getValue("b", elementType, "i")}; - int comp = ${genComp(elementType, elementA, elementB)}; - if (comp != 0) { - return comp; - } - } + ${genCompElementsAt("a", "b", "i", elementType)} } if (lengthA < lengthB) { From 3137f6a1ad5483c27411986b657c336fbc08f6fd Mon Sep 17 00:00:00 2001 From: Stevo Mitric Date: Mon, 25 Mar 2024 13:39:12 +0100 Subject: [PATCH 40/46] Refactor optimizer rule for InsertMapSortInGroupingExpressions --- .../InsertMapSortInGroupingExpressions.scala | 45 +++++++++++++++++++ .../sql/catalyst/optimizer/Optimizer.scala | 26 ++--------- 2 files changed, 48 insertions(+), 23 deletions(-) create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/InsertMapSortInGroupingExpressions.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/InsertMapSortInGroupingExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/InsertMapSortInGroupingExpressions.scala new file mode 100644 index 0000000000000..3d883fa9d9ae2 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/InsertMapSortInGroupingExpressions.scala @@ -0,0 +1,45 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+
+package org.apache.spark.sql.catalyst.optimizer
+
+import org.apache.spark.sql.catalyst.expressions.MapSort
+import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LogicalPlan}
+import org.apache.spark.sql.catalyst.rules.Rule
+import org.apache.spark.sql.catalyst.trees.TreePattern.AGGREGATE
+import org.apache.spark.sql.types.MapType
+
+/**
+ * Adds MapSort to group expressions containing map columns, as the key/value pairs need to be
+ * in the correct order before grouping:
+ * SELECT COUNT(*) FROM TABLE GROUP BY map_column =>
+ * SELECT COUNT(*) FROM TABLE GROUP BY map_sort(map_column)
+ */
+object InsertMapSortInGroupingExpressions extends Rule[LogicalPlan] {
+  override def apply(plan: LogicalPlan): LogicalPlan = plan.transformWithPruning(
+    _.containsPattern(AGGREGATE), ruleId) {
+    case a @ Aggregate(groupingExpr, _, _) =>
+      val newGrouping = groupingExpr.map { expr =>
+        if (!expr.isInstanceOf[MapSort] && expr.dataType.isInstanceOf[MapType]) {
+          MapSort(expr)
+        } else {
+          expr
+        }
+      }
+      a.copy(groupingExpressions = newGrouping)
+  }
+}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index 337e1d0fc7385..835a93f81777c 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -196,7 +196,6 @@ abstract class Optimizer(catalogManager: CatalogManager)
       ReplaceDeduplicateWithAggregate) ::
     Batch("Aggregate", fixedPoint,
       RemoveLiteralFromGroupExpressions,
-      InsertMapSortInGroupingExpressions,
       RemoveRepetitionFromGroupExpressions) :: Nil ++ operatorOptimizationBatch) :+
     Batch("Clean Up Temporary CTE Info", Once, CleanUpTempCTEInfo) :+
@@ -245,7 +244,9 @@
       RemoveRedundantAliases,
       RemoveNoopOperators) :+
     // This batch must be executed after the `RewriteSubquery` batch, which creates joins.
-    Batch("NormalizeFloatingNumbers", Once, NormalizeFloatingNumbers) :+
+    Batch("NormalizeFloatingNumbers", Once,
+      InsertMapSortInGroupingExpressions,
+      NormalizeFloatingNumbers) :+
     Batch("ReplaceUpdateFieldsExpression", Once, ReplaceUpdateFieldsExpression)
 
   // remove any batches with no rules. this may happen when subclasses do not add optional rules.
@@ -2470,24 +2471,3 @@ object RemoveRepetitionFromGroupExpressions extends Rule[LogicalPlan] { } } } - -/** - * Adds MapSort to group expressions containing map columns, as the key/value paris need to be - * in the correct order before grouping: - * SELECT COUNT(*) FROM TABLE GROUP BY map_column => - * SELECT COUNT(*) FROM TABLE GROUP BY map_sort(map_column) - */ -object InsertMapSortInGroupingExpressions extends Rule[LogicalPlan] { - override def apply(plan: LogicalPlan): LogicalPlan = plan.transformWithPruning( - _.containsPattern(AGGREGATE), ruleId) { - case a @ Aggregate(groupingExpr, _, _) => - val newGrouping = groupingExpr.map { expr => - if (!expr.isInstanceOf[MapSort] && expr.dataType.isInstanceOf[MapType]) { - MapSort(expr) - } else { - expr - } - } - a.copy(groupingExpressions = newGrouping) - } -} From 14fdcd2bfb9f4bb94982479e09a328a4a268dcd3 Mon Sep 17 00:00:00 2001 From: Stevo Mitric Date: Tue, 26 Mar 2024 11:22:31 +0100 Subject: [PATCH 41/46] Update sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala Co-authored-by: Wenchen Fan --- .../spark/sql/catalyst/expressions/codegen/CodeGenerator.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index 1d8e7d3460b5c..d0bf1f7ed8e69 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -715,7 +715,7 @@ class CodegenContext extends Logging { private def genCompMapData( keyType: DataType, valueType: DataType, - compareFunc : String): String = { + compareFunc: String): String = { val keyArrayA = freshName("keyArrayA") val keyArrayB = freshName("keyArrayB") val valueArrayA = freshName("valueArrayA") From 7fe7b7e409d5cc5c5242e972dcc812006312a3c4 Mon Sep 17 00:00:00 2001 From: Stevo Mitric Date: Tue, 26 Mar 2024 11:31:42 +0100 Subject: [PATCH 42/46] Refactored code-gen and separated optimizer rule in separate batch --- .../expressions/codegen/CodeGenerator.scala | 21 +++++++------------ .../sql/catalyst/optimizer/Optimizer.scala | 6 +++--- 2 files changed, 11 insertions(+), 16 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index d0bf1f7ed8e69..d123a2460c493 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -716,23 +716,18 @@ class CodegenContext extends Logging { keyType: DataType, valueType: DataType, compareFunc: String): String = { - val keyArrayA = freshName("keyArrayA") - val keyArrayB = freshName("keyArrayB") - val valueArrayA = freshName("valueArrayA") - val valueArrayB = freshName("valueArrayB") - val minLength = freshName("minLength") s""" |public int $compareFunc(MapData a, MapData b) { | int lengthA = a.numElements(); | int lengthB = b.numElements(); - | ArrayData $keyArrayA = a.keyArray(); - | ArrayData $valueArrayA = a.valueArray(); - | ArrayData $keyArrayB = b.keyArray(); - | ArrayData $valueArrayB = b.valueArray(); - | int $minLength = (lengthA > lengthB) ? 
lengthB : lengthA; - | for (int i = 0; i < $minLength; i++) { - | ${genCompElementsAt(keyArrayA, keyArrayB, "i", keyType)} - | ${genCompElementsAt(valueArrayA, valueArrayB, "i", valueType)} + | ArrayData keyArrayA = a.keyArray(); + | ArrayData valueArrayA = a.valueArray(); + | ArrayData keyArrayB = b.keyArray(); + | ArrayData valueArrayB = b.valueArray(); + | int minLength = (lengthA > lengthB) ? lengthB : lengthA; + | for (int i = 0; i < minLength; i++) { + | ${genCompElementsAt("keyArrayA", "keyArrayB", "i", keyType)} + | ${genCompElementsAt("valueArrayA", "valueArrayB", "i", valueType)} | } | | if (lengthA < lengthB) { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index 835a93f81777c..6461de7ab2343 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -243,10 +243,10 @@ abstract class Optimizer(catalogManager: CatalogManager) CollapseProject, RemoveRedundantAliases, RemoveNoopOperators) :+ + Batch("InsertMapSortInGroupingExpressions", Once, + InsertMapSortInGroupingExpressions) :+ // This batch must be executed after the `RewriteSubquery` batch, which creates joins. - Batch("NormalizeFloatingNumbers", Once, - InsertMapSortInGroupingExpressions, - NormalizeFloatingNumbers) :+ + Batch("NormalizeFloatingNumbers", Once, NormalizeFloatingNumbers) :+ Batch("ReplaceUpdateFieldsExpression", Once, ReplaceUpdateFieldsExpression) // remove any batches with no rules. this may happen when subclasses do not add optional rules. From 807604539252f683526affea8e83b6bae9cbe889 Mon Sep 17 00:00:00 2001 From: Stevo Mitric Date: Tue, 26 Mar 2024 11:36:55 +0100 Subject: [PATCH 43/46] refactored tests --- .../spark/sql/DataFrameAggregateSuite.scala | 21 +++++++++---------- 1 file changed, 10 insertions(+), 11 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala index f5e30a2303e5e..aad62ad1b3834 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala @@ -2155,8 +2155,8 @@ class DataFrameAggregateSuite extends QueryTest ) } - private def assertAggregateOnDataframe(dfSeq: Seq[DataFrame], - expected: Seq[Int], aggregateColumn: String): Unit = { + private def assertAggregateOnDataframe(df: DataFrame, + expected: Int, aggregateColumn: String): Unit = { val configurations = Seq( Seq.empty[(String, String)], // hash aggregate is used by default Seq(SQLConf.CODEGEN_FACTORY_MODE.key -> "NO_CODEGEN", @@ -2168,11 +2168,9 @@ class DataFrameAggregateSuite extends QueryTest Seq("spark.sql.test.forceApplySortAggregate" -> "true") ) - for ((df, index) <- dfSeq.zipWithIndex) { - for (conf <- configurations) { - withSQLConf(conf: _*) { - assert(createAggregate(df).count() == expected(index)) - } + for (conf <- configurations) { + withSQLConf(conf: _*) { + assert(createAggregate(df).count() == expected) } } @@ -2185,16 +2183,17 @@ class DataFrameAggregateSuite extends QueryTest val dfSameInt = (0 until numRows) .map(_ => Tuple1(Map(1 -> 1))) .toDF("m0") + assertAggregateOnDataframe(dfSameInt, 1, "m0") val dfSameFloat = (0 until numRows) .map(i => Tuple1(Map(if (i % 2 == 0) 1 -> 0.0 else 1 -> -0.0 ))) .toDF("m0") + 
assertAggregateOnDataframe(dfSameInt, 1, "m0")
 
     val dfDifferent = (0 until numRows)
       .map(i => Tuple1(Map(i -> i)))
       .toDF("m0")
-
-    assertAggregateOnDataframe(Seq(dfSameInt, dfSameFloat, dfDifferent), Seq(1, 1, numRows), "m0")
+    assertAggregateOnDataframe(dfSameInt, numRows, "m0")
   }
 
   test("SPARK-46536 Support GROUP BY CalendarIntervalType") {
@@ -2203,12 +2202,12 @@ class DataFrameAggregateSuite extends QueryTest
     val dfSame = (0 until numRows)
       .map(_ => Tuple1(new CalendarInterval(1, 2, 3)))
       .toDF("c0")
+    assertAggregateOnDataframe(dfSame, 1, "c0")
 
     val dfDifferent = (0 until numRows)
       .map(i => Tuple1(new CalendarInterval(i, i, i)))
       .toDF("c0")
-
-    assertAggregateOnDataframe(Seq(dfSame, dfDifferent), Seq(1, numRows), "c0")
+    assertAggregateOnDataframe(dfDifferent, numRows, "c0")
   }
 
   test("SPARK-46779: Group by subquery with a cached relation") {

From 3c573e046a93d32a1687777e1a67e9c26062b6ae Mon Sep 17 00:00:00 2001
From: Stevo Mitric
Date: Tue, 26 Mar 2024 16:17:07 +0100
Subject: [PATCH 44/46] Regenerated sql-error-conditions.md

---
 docs/sql-error-conditions.md | 6 ------
 1 file changed, 6 deletions(-)

diff --git a/docs/sql-error-conditions.md b/docs/sql-error-conditions.md
index bab64caa38886..3fb6f19942b3b 100644
--- a/docs/sql-error-conditions.md
+++ b/docs/sql-error-conditions.md
@@ -852,12 +852,6 @@ GROUP BY `<index>` refers to an expression `<aggExpr>` that contains an aggregat
 
 GROUP BY position `<index>` is not in select list (valid range is [1, `<size>`]).
 
-### GROUP_EXPRESSION_TYPE_IS_NOT_ORDERABLE
-
-[SQLSTATE: 42822](sql-error-conditions-sqlstates.html#class-42-syntax-error-or-access-rule-violation)
-
-The expression `<sqlExpr>` cannot be used as a grouping expression because its data type `<dataType>` is not an orderable data type.
-
 ### HLL_INVALID_INPUT_SKETCH_BUFFER
 
 [SQLSTATE: 22546](sql-error-conditions-sqlstates.html#class-22-data-exception)

From c6050c0892f52f9cfc12993599aeac098fda651e Mon Sep 17 00:00:00 2001
From: Stevo Mitric
Date: Tue, 26 Mar 2024 16:53:16 +0100
Subject: [PATCH 45/46] Removed a test that checks for Map as an invalid grouping type

---
 .../catalyst/optimizer/OptimizerSuite.scala   | 24 -------------------
 1 file changed, 24 deletions(-)

diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizerSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizerSuite.scala
index 590fb323000b9..48cdbbe7be539 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizerSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizerSuite.scala
@@ -103,30 +103,6 @@ class OptimizerSuite extends PlanTest {
     assert(message1.contains("are dangling"))
   }
 
-  test("Optimizer per rule validation catches invalid grouping types") {
-    val analyzed = LocalRelation(Symbol("a").map(IntegerType, IntegerType))
-      .select(Symbol("a")).analyze
-
-    /**
-     * A dummy optimizer rule for testing that invalid grouping types are not allowed.
- */ - object InvalidGroupingType extends Rule[LogicalPlan] { - def apply(plan: LogicalPlan): LogicalPlan = { - Aggregate(plan.output, plan.output, plan) - } - } - - val optimizer = new SimpleTestOptimizer() { - override def defaultBatches: Seq[Batch] = - Batch("test", FixedPoint(1), - InvalidGroupingType) :: Nil - } - val message1 = intercept[SparkException] { - optimizer.execute(analyzed) - }.getMessage - assert(message1.contains("cannot be of type Map")) - } - test("Optimizer per rule validation catches invalid aggregation expressions") { val analyzed = LocalRelation(Symbol("a").long, Symbol("b").long) .select(Symbol("a"), Symbol("b")).analyze From 3eac76c49cfc88beff3b0057c9dd563dbb95c92f Mon Sep 17 00:00:00 2001 From: Stevo Mitric Date: Wed, 27 Mar 2024 09:26:12 +0100 Subject: [PATCH 46/46] Fixed map-group-by test --- .../scala/org/apache/spark/sql/DataFrameAggregateSuite.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala index aad62ad1b3834..7c81ac7244878 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala @@ -2188,12 +2188,12 @@ class DataFrameAggregateSuite extends QueryTest val dfSameFloat = (0 until numRows) .map(i => Tuple1(Map(if (i % 2 == 0) 1 -> 0.0 else 1 -> -0.0 ))) .toDF("m0") - assertAggregateOnDataframe(dfSameInt, 1, "m0") + assertAggregateOnDataframe(dfSameFloat, 1, "m0") val dfDifferent = (0 until numRows) .map(i => Tuple1(Map(i -> i))) .toDF("m0") - assertAggregateOnDataframe(dfSameInt, numRows, "m0") + assertAggregateOnDataframe(dfDifferent, numRows, "m0") } test("SPARK-46536 Support GROUP BY CalendarIntervalType") {
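Taken together, the series lands MapSort as an implicit canonicalization step rather than a user-facing SQL function: each map-typed grouping key is wrapped exactly once, just before floating-point normalization, and the analyzer-level restriction on map grouping keys is retired. The stand-alone sketch below mirrors that rewrite logic; the types and names are simplified stand-ins for illustration only, not Catalyst's actual Expression hierarchy:

// Simplified stand-ins for Catalyst's Expression and MapType (hypothetical).
sealed trait Expr { def isMap: Boolean }
final case class Column(name: String, isMap: Boolean) extends Expr
final case class MapSortExpr(child: Expr) extends Expr { val isMap = true }

object InsertMapSortSketch {
  // Mirrors the final rule: leave non-map keys and already-wrapped keys alone,
  // and wrap raw map-typed keys so that equal maps compare equal when grouping.
  def insertMapSort(grouping: Seq[Expr]): Seq[Expr] = grouping.map {
    case ms: MapSortExpr => ms
    case e if e.isMap    => MapSortExpr(e)
    case e               => e
  }

  def main(args: Array[String]): Unit = {
    val keys = Seq(Column("m0", isMap = true), Column("id", isMap = false))
    // Prints: List(MapSortExpr(Column(m0,true)), Column(id,false))
    println(insertMapSort(keys))
  }
}

Note that the rewrite is idempotent: applying it twice is a no-op, which is exactly the guard the real rule implements with its MapSort check before wrapping.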