Skip to content

Commit

Permalink
[SPARK-23926][SQL] Extending reverse function to support ArrayType ar…
Browse files Browse the repository at this point in the history
…guments.
  • Loading branch information
mn-mikke authored and mn-mikke committed Apr 17, 2018
1 parent 9c43c96 commit 0aef4f8
Show file tree
Hide file tree
Showing 8 changed files with 256 additions and 33 deletions.
20 changes: 19 additions & 1 deletion python/pyspark/sql/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -1358,7 +1358,6 @@ def hash(*cols):
'uppercase. Words are delimited by whitespace.',
'lower': 'Converts a string column to lower case.',
'upper': 'Converts a string column to upper case.',
'reverse': 'Reverses the string column and returns it as a new string column.',
'ltrim': 'Trim the spaces from left end for the specified string value.',
'rtrim': 'Trim the spaces from right end for the specified string value.',
'trim': 'Trim the spaces from both ends for the specified string column.',
Expand Down Expand Up @@ -2042,6 +2041,25 @@ def sort_array(col, asc=True):
return Column(sc._jvm.functions.sort_array(_to_java_column(col), asc))


@since(1.5)
@ignore_unicode_prefix
def reverse(col):
"""
Collection function: returns a reversed string or an array with reverse order of elements.
:param col: name of column or expression
>>> df = spark.createDataFrame([('Spark SQL',)], ['data'])
>>> df.select(reverse(df.data).alias('s')).collect()
[Row(s=u'LQS krapS')]
>>> df = spark.createDataFrame([([2, 1, 3],) ,([1],) ,([],)], ['data'])
>>> df.select(reverse(df.data).alias('r')).collect()
[Row(r=[3, 1, 2]), Row(r=[1]), Row(r=[])]
"""
sc = SparkContext._active_spark_context
return Column(sc._jvm.functions.reverse(_to_java_column(col)))


@since(2.3)
def map_keys(col):
"""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -336,7 +336,6 @@ object FunctionRegistry {
expression[RegExpReplace]("regexp_replace"),
expression[StringRepeat]("repeat"),
expression[StringReplace]("replace"),
expression[StringReverse]("reverse"),
expression[RLike]("rlike"),
expression[StringRPad]("rpad"),
expression[StringTrimRight]("rtrim"),
Expand Down Expand Up @@ -408,6 +407,7 @@ object FunctionRegistry {
expression[MapValues]("map_values"),
expression[Size]("size"),
expression[SortArray]("sort_array"),
expression[Reverse]("reverse"),
CreateStruct.registryEntry,

// misc functions
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodegenFallback, ExprCode}
import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData, MapData}
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String

/**
* Given an array or map, returns its size. Returns -1 if null.
Expand Down Expand Up @@ -212,6 +213,93 @@ case class SortArray(base: Expression, ascendingOrder: Expression)
override def prettyName: String = "sort_array"
}

/**
* Returns a reversed string or an array with reverse order of elements.
*/
@ExpressionDescription(
usage = "_FUNC_(array) - Returns a reversed string or an array with reverse order of elements.",
examples = """
Examples:
> SELECT _FUNC_('Spark SQL');
LQS krapS
> SELECT _FUNC_(array(2, 1, 4, 3));
[3, 4, 1, 2]
""",
since = "1.5.0",
note = "Reverse logic for arrays is available since 2.4.0."
)
case class Reverse(child: Expression) extends UnaryExpression with ImplicitCastInputTypes {

// Input types are utilized by type coercion in ImplicitTypeCasts.
override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection(StringType, ArrayType))

override def dataType: DataType = child.dataType

lazy val elementType: DataType = dataType.asInstanceOf[ArrayType].elementType

override def nullSafeEval(input: Any): Any = input match {
case a: ArrayData => new GenericArrayData(a.toObjectArray(elementType).reverse)
case s: UTF8String => s.reverse()
}

override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
nullSafeCodeGen(ctx, ev, c => dataType match {
case _: StringType => stringCodeGen(ev, c)
case _: ArrayType => arrayCodeGen(ctx, ev, c)
})
}

private def stringCodeGen(ev: ExprCode, childName: String): String = {
s"${ev.value} = ($childName).reverse();"
}

private def arrayCodeGen(ctx: CodegenContext, ev: ExprCode, childName: String): String = {
val length = ctx.freshName("length")
val javaElementType = ctx.javaType(elementType)
val isPrimitiveType = ctx.isPrimitiveType(elementType)

val initialization = if (isPrimitiveType) {
s"$childName.copy()"
} else {
s"new ${classOf[GenericArrayData].getName()}(new Object[$length])"
}

val numberOfIterations = if (isPrimitiveType) s"$length / 2" else length

val swapAssigments = if (isPrimitiveType) {
val setFunc = "set" + ctx.primitiveTypeName(elementType)
val getCall = (index: String) => ctx.getValue(ev.value, elementType, index)
s"""|boolean isNullAtK = ${ev.value}.isNullAt(k);
|boolean isNullAtL = ${ev.value}.isNullAt(l);
|if(!isNullAtK) {
| $javaElementType el = ${getCall("k")};
| if(!isNullAtL) {
| ${ev.value}.$setFunc(k, ${getCall("l")});
| } else {
| ${ev.value}.setNullAt(k);
| }
| ${ev.value}.$setFunc(l, el);
|} else if (!isNullAtL) {
| ${ev.value}.$setFunc(k, ${getCall("l")});
| ${ev.value}.setNullAt(l);
|}""".stripMargin
} else {
s"${ev.value}.update(k, ${ctx.getValue(childName, elementType, "l")});"
}

s"""
|final int $length = $childName.numElements();
|${ev.value} = $initialization;
|for(int k = 0; k < $numberOfIterations; k++) {
| int l = $length - k - 1;
| $swapAssigments
|}
""".stripMargin
}

override def prettyName: String = "reverse"
}

/**
* Checks if the array (left) has the element (right)
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1504,26 +1504,6 @@ case class StringRepeat(str: Expression, times: Expression)
}
}

/**
* Returns the reversed given string.
*/
@ExpressionDescription(
usage = "_FUNC_(str) - Returns the reversed given string.",
examples = """
Examples:
> SELECT _FUNC_('Spark SQL');
LQS krapS
""")
case class StringReverse(child: Expression) extends UnaryExpression with String2StringExpression {
override def convert(v: UTF8String): UTF8String = v.reverse()

override def prettyName: String = "reverse"

override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
defineCodeGen(ctx, ev, c => s"($c).reverse()")
}
}

/**
* Returns a string consisting of n spaces.
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -105,4 +105,48 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper
checkEvaluation(ArrayContains(a3, Literal("")), null)
checkEvaluation(ArrayContains(a3, Literal.create(null, StringType)), null)
}

test("Reverse") {
// Primitive-type elements
val ai0 = Literal.create(Seq(2, 1, 4, 3), ArrayType(IntegerType))
val ai1 = Literal.create(Seq(2, 1, 3), ArrayType(IntegerType))
val ai2 = Literal.create(Seq(null, 1, null, 3), ArrayType(IntegerType))
val ai3 = Literal.create(Seq(2, null, 4, null), ArrayType(IntegerType))
val ai4 = Literal.create(Seq(null, null, null), ArrayType(IntegerType))
val ai5 = Literal.create(Seq(1), ArrayType(IntegerType))
val ai6 = Literal.create(Seq.empty, ArrayType(IntegerType))
val ai7 = Literal.create(null, ArrayType(IntegerType))

checkEvaluation(Reverse(ai0), Seq(3, 4, 1, 2))
checkEvaluation(Reverse(ai1), Seq(3, 1, 2))
checkEvaluation(Reverse(ai2), Seq(3, null, 1, null))
checkEvaluation(Reverse(ai3), Seq(null, 4, null, 2))
checkEvaluation(Reverse(ai4), Seq(null, null, null))
checkEvaluation(Reverse(ai5), Seq(1))
checkEvaluation(Reverse(ai6), Seq.empty)
checkEvaluation(Reverse(ai7), null)

// Non-primitive-type elements
val as0 = Literal.create(Seq("b", "a", "d", "c"), ArrayType(StringType))
val as1 = Literal.create(Seq("b", "a", "c"), ArrayType(StringType))
val as2 = Literal.create(Seq(null, "a", null, "c"), ArrayType(StringType))
val as3 = Literal.create(Seq("b", null, "d", null), ArrayType(StringType))
val as4 = Literal.create(Seq(null, null, null), ArrayType(StringType))
val as5 = Literal.create(Seq("a"), ArrayType(StringType))
val as6 = Literal.create(Seq.empty, ArrayType(StringType))
val as7 = Literal.create(null, ArrayType(StringType))
val aa = Literal.create(
Seq(Seq("a", "b"), Seq("c", "d"), Seq("e")),
ArrayType(ArrayType(StringType)))

checkEvaluation(Reverse(as0), Seq("c", "d", "a", "b"))
checkEvaluation(Reverse(as1), Seq("c", "a", "b"))
checkEvaluation(Reverse(as2), Seq("c", null, "a", null))
checkEvaluation(Reverse(as3), Seq(null, "d", null, "b"))
checkEvaluation(Reverse(as4), Seq(null, null, null))
checkEvaluation(Reverse(as5), Seq("a"))
checkEvaluation(Reverse(as6), Seq.empty)
checkEvaluation(Reverse(as7), null)
checkEvaluation(Reverse(aa), Seq(Seq("e"), Seq("c", "d"), Seq("a", "b")))
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -629,9 +629,9 @@ class StringExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
test("REVERSE") {
val s = 'a.string.at(0)
val row1 = create_row("abccc")
checkEvaluation(StringReverse(Literal("abccc")), "cccba", row1)
checkEvaluation(StringReverse(s), "cccba", row1)
checkEvaluation(StringReverse(Literal.create(null, StringType)), null, row1)
checkEvaluation(Reverse(Literal("abccc")), "cccba", row1)
checkEvaluation(Reverse(s), "cccba", row1)
checkEvaluation(Reverse(Literal.create(null, StringType)), null, row1)
}

test("SPACE") {
Expand Down
15 changes: 7 additions & 8 deletions sql/core/src/main/scala/org/apache/spark/sql/functions.scala
Original file line number Diff line number Diff line change
Expand Up @@ -2406,14 +2406,6 @@ object functions {
StringRepeat(str.expr, lit(n).expr)
}

/**
* Reverses the string column and returns it as a new string column.
*
* @group string_funcs
* @since 1.5.0
*/
def reverse(str: Column): Column = withExpr { StringReverse(str.expr) }

/**
* Trim the spaces from right end for the specified string value.
*
Expand Down Expand Up @@ -3242,6 +3234,13 @@ object functions {
*/
def sort_array(e: Column, asc: Boolean): Column = withExpr { SortArray(e.expr, lit(asc).expr) }

/**
* Returns a reversed string or an array with reverse order of elements.
* @group collection_funcs
* @since 1.5.0
*/
def reverse(e: Column): Column = withExpr { Reverse(e.expr) }

/**
* Returns an unordered array containing the keys of the map.
* @group collection_funcs
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -413,6 +413,100 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext {
)
}

test("reverse function") {
val dummyFilter = (c: Column) => c.isNull || c.isNotNull // switch codegen on

// String test cases
val oneRowDF = Seq(("Spark", 3215)).toDF("s", "i")

checkAnswer(
oneRowDF.select(reverse('s)),
Seq(Row("krapS"))
)
checkAnswer(
oneRowDF.selectExpr("reverse(s)"),
Seq(Row("krapS"))
)
checkAnswer(
oneRowDF.select(reverse('i)),
Seq(Row("5123"))
)
checkAnswer(
oneRowDF.selectExpr("reverse(i)"),
Seq(Row("5123"))
)
checkAnswer(
oneRowDF.selectExpr("reverse(null)"),
Seq(Row(null))
)

// Array test cases (primitive-type elements)
val idf = Seq(
Seq(1, 9, 8, 7),
Seq(5, 8, 9, 7, 2),
Seq.empty,
null
).toDF("i")

checkAnswer(
idf.select(reverse('i)),
Seq(Row(Seq(7, 8, 9, 1)), Row(Seq(2, 7, 9, 8, 5)), Row(Seq.empty), Row(null))
)
checkAnswer(
idf.filter(dummyFilter('i)).select(reverse('i)),
Seq(Row(Seq(7, 8, 9, 1)), Row(Seq(2, 7, 9, 8, 5)), Row(Seq.empty), Row(null))
)
checkAnswer(
idf.selectExpr("reverse(i)"),
Seq(Row(Seq(7, 8, 9, 1)), Row(Seq(2, 7, 9, 8, 5)), Row(Seq.empty), Row(null))
)
checkAnswer(
oneRowDF.selectExpr("reverse(array(1, null, 2, null))"),
Seq(Row(Seq(null, 2, null, 1)))
)
checkAnswer(
oneRowDF.filter(dummyFilter('i)).selectExpr("reverse(array(1, null, 2, null))"),
Seq(Row(Seq(null, 2, null, 1)))
)

// Array test cases (non-primitive-type elements)
val sdf = Seq(
Seq("c", "a", "b"),
Seq("b", null, "c", null),
Seq.empty,
null
).toDF("s")

checkAnswer(
sdf.select(reverse('s)),
Seq(Row(Seq("b", "a", "c")), Row(Seq(null, "c", null, "b")), Row(Seq.empty), Row(null))
)
checkAnswer(
sdf.filter(dummyFilter('s)).select(reverse('s)),
Seq(Row(Seq("b", "a", "c")), Row(Seq(null, "c", null, "b")), Row(Seq.empty), Row(null))
)
checkAnswer(
sdf.selectExpr("reverse(s)"),
Seq(Row(Seq("b", "a", "c")), Row(Seq(null, "c", null, "b")), Row(Seq.empty), Row(null))
)
checkAnswer(
oneRowDF.selectExpr("reverse(array(array(1, 2), array(3, 4)))"),
Seq(Row(Seq(Seq(3, 4), Seq(1, 2))))
)
checkAnswer(
oneRowDF.filter(dummyFilter('s)).selectExpr("reverse(array(array(1, 2), array(3, 4)))"),
Seq(Row(Seq(Seq(3, 4), Seq(1, 2))))
)

// Error test cases
intercept[AnalysisException] {
oneRowDF.selectExpr("reverse(struct(1, 'a'))")
}
intercept[AnalysisException] {
oneRowDF.selectExpr("reverse(map(1, 'a'))")
}
}

private def assertValuesDoNotChangeAfterCoalesceOrUnion(v: Column): Unit = {
import DataFrameFunctionsSuite.CodegenFallbackExpr
for ((codegenFallback, wholeStage) <- Seq((true, false), (false, false), (false, true))) {
Expand Down

0 comments on commit 0aef4f8

Please sign in to comment.