Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Resolve a possible explodeMap column names collision #582

Merged
merged 1 commit into from
Nov 27, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 15 additions & 5 deletions dataset/src/main/scala/frameless/TypedDataset.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1234,20 +1234,30 @@ class TypedDataset[T] protected[frameless](val dataset: Dataset[T])(implicit val
import org.apache.spark.sql.functions.{explode => sparkExplode, struct => sparkStruct, col => sparkCol}
val df = dataset.toDF()

// preserve the original list of columns
val columns = df.columns.toSeq.map(sparkCol)
// select all columns: every original column plus the [key, value] columns that appear after the map explode
// .withColumn(column.value.name, sparkExplode(df(column.value.name))) in this case would not work
// since the map explode produces two columns
val exploded = df.select(sparkCol("*"), sparkExplode(df(column.value.name)))
val columnNames = df.columns.toSeq
val columnNamesRenamed = columnNames.map(c => s"frameless_$c")

// preserve the original list of renamed columns
val columns = columnNamesRenamed.map(sparkCol)

val columnRenamed = s"frameless_${column.value.name}"
// explode of a map adds "key" and "value" columns into the Row
// this may cause a column naming collision: the row could already contain key / value columns
// we rename the original Row columns to avoid this collision
val dfr = df.toDF(columnNamesRenamed: _*)
val exploded = dfr.select(sparkCol("*"), sparkExplode(dfr(columnRenamed)))
val trans =
exploded
// map explode explodes it into [key, value] columns
// the only way to put it into a column is to create a struct
// TODO: handle org.apache.spark.sql.AnalysisException: Reference 'key / value' is ambiguous, could be: key / value, key / value
.withColumn(column.value.name, sparkStruct(exploded("key"), exploded("value")))
.withColumn(columnRenamed, sparkStruct(exploded("key"), exploded("value")))
// selecting only original columns, we don't need [key, value] columns left in the DataFrame after the map explode
.select(columns: _*)
// rename columns back and form the result
.toDF(columnNames: _*)
.as[Out](TypedExpressionEncoder[Out])
TypedDataset.create[Out](trans)
}
Expand Down
15 changes: 15 additions & 0 deletions dataset/src/test/scala/frameless/ExplodeTests.scala
Original file line number Diff line number Diff line change
Expand Up @@ -77,4 +77,19 @@ class ExplodeTests extends TypedDatasetSuite {
check(forAll(prop[String, Int, Long] _))
check(forAll(prop[Long, String, Int] _))
}

test("explode on maps making sure no key / value naming collision happens") {
  // X3KV already has fields named `key` and `value`, the same names the map explode
  // adds to the row, so this property exercises the collision case directly.
  def prop[K: TypedEncoder: ClassTag, V: TypedEncoder: ClassTag, A: TypedEncoder: ClassTag, B: TypedEncoder: ClassTag](xs: List[X3KV[K, V, Map[A, B]]]): Prop = {
    val dataset = TypedDataset.create(xs)

    // Result produced by frameless: one output row per map entry.
    val actual = dataset.explodeMap('c).collect().run().toVector

    // Reference result computed with plain Scala collections.
    val expected = (for {
      row   <- xs
      entry <- row.c.toList
    } yield (row.key, row.value, entry)).toVector

    actual ?= expected
  }

  check(forAll(prop[String, Int, Long, String] _))
  check(forAll(prop[Long, String, Int, Long] _))
  check(forAll(prop[Int, Long, String, Int] _))
}
}
13 changes: 13 additions & 0 deletions dataset/src/test/scala/frameless/XN.scala
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,19 @@ object X3U {
Ordering.Tuple3[A, B, C].on(x => (x.a, x.b, x.c))
}

/**
 * Three-field test fixture whose first two fields are deliberately named
 * `key` and `value`.
 *
 * @param key   first field
 * @param value second field
 * @param c     third field
 */
case class X3KV[A, B, C](key: A, value: B, c: C)

/** ScalaCheck and ordering instances for [[X3KV]], all derived from the matching `Tuple3` instances. */
object X3KV {
  /** Generates an `X3KV` by generating a triple and wrapping its components. */
  implicit def arbitrary[A: Arbitrary, B: Arbitrary, C: Arbitrary]: Arbitrary[X3KV[A, B, C]] =
    Arbitrary {
      Arbitrary.arbTuple3[A, B, C].arbitrary.map { case (k, v, third) => X3KV[A, B, C](k, v, third) }
    }

  /** Builds a `Cogen` for `X3KV` by viewing it as a `(key, value, c)` triple. */
  implicit def cogen[A, B, C](implicit A: Cogen[A], B: Cogen[B], C: Cogen[C]): Cogen[X3KV[A, B, C]] =
    Cogen.tuple3(A, B, C).contramap { kv: X3KV[A, B, C] => (kv.key, kv.value, kv.c) }

  /** Orders `X3KV` values lexicographically by `key`, then `value`, then `c`. */
  implicit def ordering[A: Ordering, B: Ordering, C: Ordering]: Ordering[X3KV[A, B, C]] =
    Ordering.by[X3KV[A, B, C], (A, B, C)](kv => (kv.key, kv.value, kv.c))
}

case class X4[A, B, C, D](a: A, b: B, c: C, d: D)

object X4 {
Expand Down