
Commit

Fixes #16
Detailed research revealed that we were incorrectly inferring field
names when they don't conform to the javaBean specification: the field
name was inferred from the getter, which is not always correct.

This commit therefore adds a new wrapper, KStructField, which carries
not only the field name but also the getter name. That lets us match
fields by getter name instead of by inferred field name, which is much
safer: Kotlin will never allow capitalized and non-capitalized
properties such as `Country` and `country` to co-exist in the same
class, so matching on the getter is the more reliable way to handle
this situation.
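
For illustration, here is a hypothetical example (not code from this commit) of how javaBean-style inference goes wrong for a property that does not follow the naming convention:

    // Hypothetical illustration of the bug this commit fixes.
    import java.beans.Introspector

    data class Address(val Country: String)

    fun main() {
        // Kotlin generates `getCountry()` for the property `Country`.
        val getter = Address::class.java.methods.first { it.name == "getCountry" }
        // javaBean inference decapitalizes the name derived from the getter,
        // yielding "country", which no longer matches the declared "Country".
        val inferred = Introspector.decapitalize(getter.name.removePrefix("get"))
        println(inferred)              // country
        println(inferred == "Country") // false -- lookup by inferred name fails
    }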

Signed-off-by: Pasha Finkelshteyn <asm0dey@jetbrains.com>
asm0dey committed Jun 23, 2020
1 parent 91afeff commit 8757094
Showing 3 changed files with 49 additions and 18 deletions.
@@ -653,10 +653,11 @@ object KotlinReflection extends KotlinReflection {
       val cls = dataType.cls
       val properties = getJavaBeanReadableProperties(cls)
       val fields = properties.map { prop =>
-        val fieldName = prop.getName
-        val maybeField = dataType.dt.fields.find(it => it.name == fieldName)
-
+        val maybeField = dataType.dt.fields.map(_.asInstanceOf[KStructField]).find(it => it.getterName == prop.getReadMethod.getName)
         if (maybeField.isEmpty)
-          throw new IllegalArgumentException(s"Field $fieldName is not found among available fields, which are: ${dataType.dt.fields.map(_.name).mkString(", ")}")
+          throw new IllegalArgumentException(s"Field ${prop.getName} is not found among available fields, which are: ${dataType.dt.fields.map(_.name).mkString(", ")}")
+        val fieldName = maybeField.get.name
         val propClass = maybeField.map(it => it.dataType.asInstanceOf[DataTypeWithClass].cls).get
         val propDt = maybeField.map(it => it.dataType.asInstanceOf[DataTypeWithClass]).get

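The lookup now keys on the getter's JVM name instead of a field name inferred from it. A minimal Kotlin sketch of this matching logic, under assumed shapes (`FieldSketch` is a hypothetical stand-in for KStructField, not the actual Spark type):

    // Hypothetical stand-in: the declared name paired with its getter name.
    data class FieldSketch(val name: String, val getterName: String)

    fun resolve(fields: List<FieldSketch>, readMethodName: String): FieldSketch =
        fields.find { it.getterName == readMethodName }
            ?: throw IllegalArgumentException(
                "Field for getter $readMethodName not found among: " +
                    fields.joinToString { it.name })

    fun main() {
        val fields = listOf(FieldSketch("Country", "getCountry"))
        // Matching on the getter name succeeds even though "Country" could
        // never be inferred back from "getCountry".
        println(resolve(fields, "getCountry").name) // Country
    }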
30 changes: 30 additions & 0 deletions core/src/main/scala/org/apache/spark/sql/KotlinWrappers.scala
@@ -178,6 +178,36 @@ case class KSimpleTypeWrapper(dt: DataType, cls: Class[_], nullable: Boolean) ex
   override private[spark] def asNullable = dt.asNullable
 }
+
+class KStructField(val getterName: String, val delegate: StructField) extends StructField {
+  override private[sql] def buildFormattedString(prefix: String, stringConcat: StringUtils.StringConcat, maxDepth: Int): Unit = delegate.buildFormattedString(prefix, stringConcat, maxDepth)
+
+  override def toString(): String = delegate.toString()
+
+  override private[sql] def jsonValue = delegate.jsonValue
+
+  override def withComment(comment: String): StructField = delegate.withComment(comment)
+
+  override def getComment(): Option[String] = delegate.getComment()
+
+  override def toDDL: String = delegate.toDDL
+
+  override def productElement(n: Int): Any = delegate.productElement(n)
+
+  override def productArity: Int = delegate.productArity
+
+  override def productIterator: Iterator[Any] = delegate.productIterator
+
+  override def productPrefix: String = delegate.productPrefix
+
+  override val dataType: DataType = delegate.dataType
+
+  override def canEqual(that: Any): Boolean = delegate.canEqual(that)
+
+  override val metadata: Metadata = delegate.metadata
+  override val name: String = delegate.name
+  override val nullable: Boolean = delegate.nullable
+}
 
 object helpme {
 
   def listToSeq(i: java.util.List[_]): Seq[_] = Seq(i.toArray: _*)
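KStructField simply pairs a StructField with the JVM getter that reads it, delegating all StructField behaviour to the wrapped instance. A hypothetical construction from Kotlin, assuming the class lives in org.apache.spark.sql as the file path suggests:

    import org.apache.spark.sql.KStructField
    import org.apache.spark.sql.types.DataTypes
    import org.apache.spark.sql.types.Metadata
    import org.apache.spark.sql.types.StructField

    fun main() {
        val delegate = StructField("Country", DataTypes.StringType, false, Metadata.empty())
        // Behaves exactly like `delegate`, but also remembers which getter
        // produces the value, enabling getter-based lookup on the Scala side.
        val field = KStructField("getCountry", delegate)
        println(field.name())       // Country (delegated)
        println(field.getterName()) // getCountry
    }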
30 changes: 15 additions & 15 deletions kotlin-spark-api/src/main/kotlin/org/jetbrains/spark/api/ApiV1.kt
@@ -21,14 +21,14 @@

package org.jetbrains.spark.api

import org.apache.spark.SparkContext
import org.apache.spark.api.java.function.*
import org.apache.spark.sql.*
import org.apache.spark.sql.Encoders.*
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
import org.apache.spark.sql.types.*
import org.jetbrains.spark.extensions.KSparkExtensions
import scala.reflect.ClassTag
import java.beans.PropertyDescriptor
import java.math.BigDecimal
import java.sql.Date
import java.sql.Timestamp
@@ -281,20 +281,20 @@ fun schema(type: KType, map: Map<String, KType> = mapOf()): DataType {
                 mapValueParam.isMarkedNullable
             )
         }
-        else -> KDataTypeWrapper(
-            StructType(
-                klass
-                    .declaredMemberProperties
-                    .filter { it.findAnnotation<Transient>() == null }
-                    .map {
-                        val projectedType = types[it.returnType.toString()] ?: it.returnType
-                        StructField(it.name, schema(projectedType, types), projectedType.isMarkedNullable, Metadata.empty())
-                    }
-                    .toTypedArray()
-            ),
-            klass.java,
-            true
-        )
+        else -> {
+            val structType = StructType(
+                klass
+                    .declaredMemberProperties
+                    .filter { it.findAnnotation<Transient>() == null }
+                    .map {
+                        val projectedType = types[it.returnType.toString()] ?: it.returnType
+                        val propertyDescriptor = PropertyDescriptor(it.name, klass.java, "is" + it.name.capitalize(), null)
+                        KStructField(propertyDescriptor.readMethod.name, StructField(it.name, schema(projectedType, types), projectedType.isMarkedNullable, Metadata.empty()))
+                    }
+                    .toTypedArray()
+            )
+            KDataTypeWrapper(structType, klass.java, true)
+        }
     }
 }
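The PropertyDescriptor call above recovers the real getter name for each property: if the "is<Name>" hint does not resolve, java.beans falls back to the "get<Name>" variant, so both getter styles are found. A standalone illustration with a hypothetical class (not from this commit):

    import java.beans.PropertyDescriptor

    data class Address(val city: String, val Country: String)

    fun main() {
        for (name in listOf("city", "Country")) {
            // "is<Name>" is tried first; when absent, PropertyDescriptor
            // retries "get<Name>", resolving the actual read method.
            val pd = PropertyDescriptor(name, Address::class.java, "is" + name.capitalize(), null)
            println("$name -> ${pd.readMethod.name}") // city -> getCity, Country -> getCountry
        }
    }

It is this resolved read-method name, not a name derived from it, that gets stored in KStructField and later matched by the deserializer on the Scala side.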
