From 0aef4f8bc6197fb6660e993555b03e6a27542772 Mon Sep 17 00:00:00 2001 From: mn-mikke Date: Mon, 9 Apr 2018 16:12:16 +0200 Subject: [PATCH] [SPARK-23926][SQL] Extending reverse function to support ArrayType arguments. --- python/pyspark/sql/functions.py | 20 +++- .../catalyst/analysis/FunctionRegistry.scala | 2 +- .../expressions/collectionOperations.scala | 88 +++++++++++++++++ .../expressions/stringExpressions.scala | 20 ---- .../CollectionExpressionsSuite.scala | 44 +++++++++ .../expressions/StringExpressionsSuite.scala | 6 +- .../org/apache/spark/sql/functions.scala | 15 ++- .../spark/sql/DataFrameFunctionsSuite.scala | 94 +++++++++++++++++++ 8 files changed, 256 insertions(+), 33 deletions(-) diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index 52fab7c086efa..a39e2eb2f3254 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -1358,7 +1358,6 @@ def hash(*cols): 'uppercase. Words are delimited by whitespace.', 'lower': 'Converts a string column to lower case.', 'upper': 'Converts a string column to upper case.', - 'reverse': 'Reverses the string column and returns it as a new string column.', 'ltrim': 'Trim the spaces from left end for the specified string value.', 'rtrim': 'Trim the spaces from right end for the specified string value.', 'trim': 'Trim the spaces from both ends for the specified string column.', @@ -2042,6 +2041,25 @@ def sort_array(col, asc=True): return Column(sc._jvm.functions.sort_array(_to_java_column(col), asc)) +@since(1.5) +@ignore_unicode_prefix +def reverse(col): + """ + Collection function: returns a reversed string or an array with reverse order of elements. + + :param col: name of column or expression + + >>> df = spark.createDataFrame([('Spark SQL',)], ['data']) + >>> df.select(reverse(df.data).alias('s')).collect() + [Row(s=u'LQS krapS')] + >>> df = spark.createDataFrame([([2, 1, 3],) ,([1],) ,([],)], ['data']) + >>> df.select(reverse(df.data).alias('r')).collect() + [Row(r=[3, 1, 2]), Row(r=[1]), Row(r=[])] + """ + sc = SparkContext._active_spark_context + return Column(sc._jvm.functions.reverse(_to_java_column(col))) + + @since(2.3) def map_keys(col): """ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index 747016beb06e7..81bb2513ac82e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -336,7 +336,6 @@ object FunctionRegistry { expression[RegExpReplace]("regexp_replace"), expression[StringRepeat]("repeat"), expression[StringReplace]("replace"), - expression[StringReverse]("reverse"), expression[RLike]("rlike"), expression[StringRPad]("rpad"), expression[StringTrimRight]("rtrim"), @@ -408,6 +407,7 @@ object FunctionRegistry { expression[MapValues]("map_values"), expression[Size]("size"), expression[SortArray]("sort_array"), + expression[Reverse]("reverse"), CreateStruct.registryEntry, // misc functions diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index 4270b987d6de0..c4bfb8fce8fdd 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -23,6 +23,7 @@ import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodegenFallback, ExprCode} import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData, MapData} import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.types.UTF8String /** * Given an array or map, returns its size. Returns -1 if null. @@ -212,6 +213,93 @@ case class SortArray(base: Expression, ascendingOrder: Expression) override def prettyName: String = "sort_array" } +/** + * Returns a reversed string or an array with reverse order of elements. + */ +@ExpressionDescription( + usage = "_FUNC_(array) - Returns a reversed string or an array with reverse order of elements.", + examples = """ + Examples: + > SELECT _FUNC_('Spark SQL'); + LQS krapS + > SELECT _FUNC_(array(2, 1, 4, 3)); + [3, 4, 1, 2] + """, + since = "1.5.0", + note = "Reverse logic for arrays is available since 2.4.0." +) +case class Reverse(child: Expression) extends UnaryExpression with ImplicitCastInputTypes { + + // Input types are utilized by type coercion in ImplicitTypeCasts. + override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection(StringType, ArrayType)) + + override def dataType: DataType = child.dataType + + lazy val elementType: DataType = dataType.asInstanceOf[ArrayType].elementType + + override def nullSafeEval(input: Any): Any = input match { + case a: ArrayData => new GenericArrayData(a.toObjectArray(elementType).reverse) + case s: UTF8String => s.reverse() + } + + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + nullSafeCodeGen(ctx, ev, c => dataType match { + case _: StringType => stringCodeGen(ev, c) + case _: ArrayType => arrayCodeGen(ctx, ev, c) + }) + } + + private def stringCodeGen(ev: ExprCode, childName: String): String = { + s"${ev.value} = ($childName).reverse();" + } + + private def arrayCodeGen(ctx: CodegenContext, ev: ExprCode, childName: String): String = { + val length = ctx.freshName("length") + val javaElementType = ctx.javaType(elementType) + val isPrimitiveType = ctx.isPrimitiveType(elementType) + + val initialization = if (isPrimitiveType) { + s"$childName.copy()" + } else { + s"new ${classOf[GenericArrayData].getName()}(new Object[$length])" + } + + val numberOfIterations = if (isPrimitiveType) s"$length / 2" else length + + val swapAssigments = if (isPrimitiveType) { + val setFunc = "set" + ctx.primitiveTypeName(elementType) + val getCall = (index: String) => ctx.getValue(ev.value, elementType, index) + s"""|boolean isNullAtK = ${ev.value}.isNullAt(k); + |boolean isNullAtL = ${ev.value}.isNullAt(l); + |if(!isNullAtK) { + | $javaElementType el = ${getCall("k")}; + | if(!isNullAtL) { + | ${ev.value}.$setFunc(k, ${getCall("l")}); + | } else { + | ${ev.value}.setNullAt(k); + | } + | ${ev.value}.$setFunc(l, el); + |} else if (!isNullAtL) { + | ${ev.value}.$setFunc(k, ${getCall("l")}); + | ${ev.value}.setNullAt(l); + |}""".stripMargin + } else { + s"${ev.value}.update(k, ${ctx.getValue(childName, elementType, "l")});" + } + + s""" + |final int $length = $childName.numElements(); + |${ev.value} = $initialization; + |for(int k = 0; k < $numberOfIterations; k++) { + | int l = $length - k - 1; + | $swapAssigments + |} + """.stripMargin + } + + override def prettyName: String = "reverse" +} + /** * Checks if the array (left) has the element (right) */ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala index d7612e30b4c57..fd5e01b1af524 100755 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala @@ -1504,26 +1504,6 @@ case class StringRepeat(str: Expression, times: Expression) } } -/** - * Returns the reversed given string. - */ -@ExpressionDescription( - usage = "_FUNC_(str) - Returns the reversed given string.", - examples = """ - Examples: - > SELECT _FUNC_('Spark SQL'); - LQS krapS - """) -case class StringReverse(child: Expression) extends UnaryExpression with String2StringExpression { - override def convert(v: UTF8String): UTF8String = v.reverse() - - override def prettyName: String = "reverse" - - override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - defineCodeGen(ctx, ev, c => s"($c).reverse()") - } -} - /** * Returns a string consisting of n spaces. */ diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala index 020687e4b3a27..ac1c18ecf625a 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala @@ -105,4 +105,48 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper checkEvaluation(ArrayContains(a3, Literal("")), null) checkEvaluation(ArrayContains(a3, Literal.create(null, StringType)), null) } + + test("Reverse") { + // Primitive-type elements + val ai0 = Literal.create(Seq(2, 1, 4, 3), ArrayType(IntegerType)) + val ai1 = Literal.create(Seq(2, 1, 3), ArrayType(IntegerType)) + val ai2 = Literal.create(Seq(null, 1, null, 3), ArrayType(IntegerType)) + val ai3 = Literal.create(Seq(2, null, 4, null), ArrayType(IntegerType)) + val ai4 = Literal.create(Seq(null, null, null), ArrayType(IntegerType)) + val ai5 = Literal.create(Seq(1), ArrayType(IntegerType)) + val ai6 = Literal.create(Seq.empty, ArrayType(IntegerType)) + val ai7 = Literal.create(null, ArrayType(IntegerType)) + + checkEvaluation(Reverse(ai0), Seq(3, 4, 1, 2)) + checkEvaluation(Reverse(ai1), Seq(3, 1, 2)) + checkEvaluation(Reverse(ai2), Seq(3, null, 1, null)) + checkEvaluation(Reverse(ai3), Seq(null, 4, null, 2)) + checkEvaluation(Reverse(ai4), Seq(null, null, null)) + checkEvaluation(Reverse(ai5), Seq(1)) + checkEvaluation(Reverse(ai6), Seq.empty) + checkEvaluation(Reverse(ai7), null) + + // Non-primitive-type elements + val as0 = Literal.create(Seq("b", "a", "d", "c"), ArrayType(StringType)) + val as1 = Literal.create(Seq("b", "a", "c"), ArrayType(StringType)) + val as2 = Literal.create(Seq(null, "a", null, "c"), ArrayType(StringType)) + val as3 = Literal.create(Seq("b", null, "d", null), ArrayType(StringType)) + val as4 = Literal.create(Seq(null, null, null), ArrayType(StringType)) + val as5 = Literal.create(Seq("a"), ArrayType(StringType)) + val as6 = Literal.create(Seq.empty, ArrayType(StringType)) + val as7 = Literal.create(null, ArrayType(StringType)) + val aa = Literal.create( + Seq(Seq("a", "b"), Seq("c", "d"), Seq("e")), + ArrayType(ArrayType(StringType))) + + checkEvaluation(Reverse(as0), Seq("c", "d", "a", "b")) + checkEvaluation(Reverse(as1), Seq("c", "a", "b")) + checkEvaluation(Reverse(as2), Seq("c", null, "a", null)) + checkEvaluation(Reverse(as3), Seq(null, "d", null, "b")) + checkEvaluation(Reverse(as4), Seq(null, null, null)) + checkEvaluation(Reverse(as5), Seq("a")) + checkEvaluation(Reverse(as6), Seq.empty) + checkEvaluation(Reverse(as7), null) + checkEvaluation(Reverse(aa), Seq(Seq("e"), Seq("c", "d"), Seq("a", "b"))) + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala index 97ddbeba2c5ca..ae1065786bce8 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala @@ -629,9 +629,9 @@ class StringExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { test("REVERSE") { val s = 'a.string.at(0) val row1 = create_row("abccc") - checkEvaluation(StringReverse(Literal("abccc")), "cccba", row1) - checkEvaluation(StringReverse(s), "cccba", row1) - checkEvaluation(StringReverse(Literal.create(null, StringType)), null, row1) + checkEvaluation(Reverse(Literal("abccc")), "cccba", row1) + checkEvaluation(Reverse(s), "cccba", row1) + checkEvaluation(Reverse(Literal.create(null, StringType)), null, row1) } test("SPACE") { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 0d54c02c3d06f..f6a87187b0fe9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -2406,14 +2406,6 @@ object functions { StringRepeat(str.expr, lit(n).expr) } - /** - * Reverses the string column and returns it as a new string column. - * - * @group string_funcs - * @since 1.5.0 - */ - def reverse(str: Column): Column = withExpr { StringReverse(str.expr) } - /** * Trim the spaces from right end for the specified string value. * @@ -3242,6 +3234,13 @@ object functions { */ def sort_array(e: Column, asc: Boolean): Column = withExpr { SortArray(e.expr, lit(asc).expr) } + /** + * Returns a reversed string or an array with reverse order of elements. + * @group collection_funcs + * @since 1.5.0 + */ + def reverse(e: Column): Column = withExpr { Reverse(e.expr) } + /** * Returns an unordered array containing the keys of the map. * @group collection_funcs diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index 50e475984f458..2103e12b566a0 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -413,6 +413,100 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { ) } + test("reverse function") { + val dummyFilter = (c: Column) => c.isNull || c.isNotNull // switch codegen on + + // String test cases + val oneRowDF = Seq(("Spark", 3215)).toDF("s", "i") + + checkAnswer( + oneRowDF.select(reverse('s)), + Seq(Row("krapS")) + ) + checkAnswer( + oneRowDF.selectExpr("reverse(s)"), + Seq(Row("krapS")) + ) + checkAnswer( + oneRowDF.select(reverse('i)), + Seq(Row("5123")) + ) + checkAnswer( + oneRowDF.selectExpr("reverse(i)"), + Seq(Row("5123")) + ) + checkAnswer( + oneRowDF.selectExpr("reverse(null)"), + Seq(Row(null)) + ) + + // Array test cases (primitive-type elements) + val idf = Seq( + Seq(1, 9, 8, 7), + Seq(5, 8, 9, 7, 2), + Seq.empty, + null + ).toDF("i") + + checkAnswer( + idf.select(reverse('i)), + Seq(Row(Seq(7, 8, 9, 1)), Row(Seq(2, 7, 9, 8, 5)), Row(Seq.empty), Row(null)) + ) + checkAnswer( + idf.filter(dummyFilter('i)).select(reverse('i)), + Seq(Row(Seq(7, 8, 9, 1)), Row(Seq(2, 7, 9, 8, 5)), Row(Seq.empty), Row(null)) + ) + checkAnswer( + idf.selectExpr("reverse(i)"), + Seq(Row(Seq(7, 8, 9, 1)), Row(Seq(2, 7, 9, 8, 5)), Row(Seq.empty), Row(null)) + ) + checkAnswer( + oneRowDF.selectExpr("reverse(array(1, null, 2, null))"), + Seq(Row(Seq(null, 2, null, 1))) + ) + checkAnswer( + oneRowDF.filter(dummyFilter('i)).selectExpr("reverse(array(1, null, 2, null))"), + Seq(Row(Seq(null, 2, null, 1))) + ) + + // Array test cases (non-primitive-type elements) + val sdf = Seq( + Seq("c", "a", "b"), + Seq("b", null, "c", null), + Seq.empty, + null + ).toDF("s") + + checkAnswer( + sdf.select(reverse('s)), + Seq(Row(Seq("b", "a", "c")), Row(Seq(null, "c", null, "b")), Row(Seq.empty), Row(null)) + ) + checkAnswer( + sdf.filter(dummyFilter('s)).select(reverse('s)), + Seq(Row(Seq("b", "a", "c")), Row(Seq(null, "c", null, "b")), Row(Seq.empty), Row(null)) + ) + checkAnswer( + sdf.selectExpr("reverse(s)"), + Seq(Row(Seq("b", "a", "c")), Row(Seq(null, "c", null, "b")), Row(Seq.empty), Row(null)) + ) + checkAnswer( + oneRowDF.selectExpr("reverse(array(array(1, 2), array(3, 4)))"), + Seq(Row(Seq(Seq(3, 4), Seq(1, 2)))) + ) + checkAnswer( + oneRowDF.filter(dummyFilter('s)).selectExpr("reverse(array(array(1, 2), array(3, 4)))"), + Seq(Row(Seq(Seq(3, 4), Seq(1, 2)))) + ) + + // Error test cases + intercept[AnalysisException] { + oneRowDF.selectExpr("reverse(struct(1, 'a'))") + } + intercept[AnalysisException] { + oneRowDF.selectExpr("reverse(map(1, 'a'))") + } + } + private def assertValuesDoNotChangeAfterCoalesceOrUnion(v: Column): Unit = { import DataFrameFunctionsSuite.CodegenFallbackExpr for ((codegenFallback, wholeStage) <- Seq((true, false), (false, false), (false, true))) {