Skip to content

Commit

Permalink
[SPARK-35278][SQL] Invoke should find the method with correct number …
Browse files Browse the repository at this point in the history
…of parameters

### What changes were proposed in this pull request?

This patch fixes `Invoke` expression when the target object has more than one method with the given method name.

### Why are the changes needed?

`Invoke` will find out the method on the target object with given method name. If there are more than one method with the name, currently it is undeterministic which method will be used. We should add the condition of parameter number when finding the method.

### Does this PR introduce _any_ user-facing change?

Yes, fixed a bug when using `Invoke` on a object where more than one method with the given method name.

### How was this patch tested?

Unit test.

Closes apache#32404 from viirya/verify-invoke-param-len.

Authored-by: Liang-Chi Hsieh <viirya@gmail.com>
Signed-off-by: Liang-Chi Hsieh <viirya@gmail.com>
  • Loading branch information
viirya committed May 1, 2021
1 parent 72e238a commit 6ce1b16
Show file tree
Hide file tree
Showing 2 changed files with 55 additions and 5 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -326,11 +326,30 @@ case class Invoke(

@transient lazy val method = targetObject.dataType match {
case ObjectType(cls) =>
val m = cls.getMethods.find(_.getName == encodedFunctionName)
if (m.isEmpty) {
sys.error(s"Couldn't find $encodedFunctionName on $cls")
} else {
m
// Looking with function name + argument classes first.
try {
Some(cls.getMethod(encodedFunctionName, argClasses: _*))
} catch {
case _: NoSuchMethodException =>
// For some cases, e.g. arg class is Object, `getMethod` cannot find the method.
// We look at function name + argument length
val m = cls.getMethods.filter { m =>
m.getName == encodedFunctionName && m.getParameterCount == arguments.length
}
if (m.isEmpty) {
sys.error(s"Couldn't find $encodedFunctionName on $cls")
} else if (m.length > 1) {
// More than one matched method signature. Exclude synthetic one, e.g. generic one.
val realMethods = m.filter(!_.isSynthetic)
if (realMethods.length > 1) {
// Ambiguous case, we don't know which method to choose, just fail it.
sys.error(s"Found ${realMethods.length} $encodedFunctionName on $cls")
} else {
Some(realMethods.head)
}
} else {
Some(m.head)
}
}
case _ => None
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -618,6 +618,29 @@ class ObjectExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
checkExceptionInExpression[ArithmeticException](
StaticInvoke(mathCls, IntegerType, "addExact", Seq(Literal(Int.MaxValue), Literal(1))), "")
}

test("SPARK-35278: invoke should find method with correct number of parameters") {
val strClsType = ObjectType(classOf[String])
checkExceptionInExpression[StringIndexOutOfBoundsException](
Invoke(Literal("a", strClsType), "substring", strClsType, Seq(Literal(3))), "")

checkObjectExprEvaluation(
Invoke(Literal("a", strClsType), "substring", strClsType, Seq(Literal(0))), "a")

checkExceptionInExpression[StringIndexOutOfBoundsException](
Invoke(Literal("a", strClsType), "substring", strClsType, Seq(Literal(0), Literal(3))), "")

checkObjectExprEvaluation(
Invoke(Literal("a", strClsType), "substring", strClsType, Seq(Literal(0), Literal(1))), "a")
}

test("SPARK-35278: invoke should correctly invoke override method") {
val clsType = ObjectType(classOf[ConcreteClass])
val obj = new ConcreteClass

checkObjectExprEvaluation(
Invoke(Literal(obj, clsType), "testFunc", IntegerType, Seq(Literal(1))), 0)
}
}

class TestBean extends Serializable {
Expand All @@ -628,3 +651,11 @@ class TestBean extends Serializable {
def setNonPrimitive(i: AnyRef): Unit =
assert(i != null, "this setter should not be called with null.")
}

abstract class BaseClass[T] {
def testFunc(param: T): T
}

class ConcreteClass extends BaseClass[Int] with Serializable {
override def testFunc(param: Int): Int = param - 1
}

0 comments on commit 6ce1b16

Please sign in to comment.