diff --git a/dataset/src/main/scala/frameless/TypedDataset.scala b/dataset/src/main/scala/frameless/TypedDataset.scala index 845e52c9..7f459684 100644 --- a/dataset/src/main/scala/frameless/TypedDataset.scala +++ b/dataset/src/main/scala/frameless/TypedDataset.scala @@ -1234,20 +1234,30 @@ class TypedDataset[T] protected[frameless](val dataset: Dataset[T])(implicit val import org.apache.spark.sql.functions.{explode => sparkExplode, struct => sparkStruct, col => sparkCol} val df = dataset.toDF() - // preserve the original list of columns - val columns = df.columns.toSeq.map(sparkCol) // select all columns, all original columns and [key, value] columns appeared after the map explode // .withColumn(column.value.name, sparkExplode(df(column.value.name))) in this case would not work // since the map explode produces two columns - val exploded = df.select(sparkCol("*"), sparkExplode(df(column.value.name))) + val columnNames = df.columns.toSeq + val columnNamesRenamed = columnNames.map(c => s"frameless_$c") + + // preserve the original list of renamed columns + val columns = columnNamesRenamed.map(sparkCol) + + val columnRenamed = s"frameless_${column.value.name}" + // explode of a map adds "key" and "value" columns into the Row + // this may cause a column naming collision: the row could already contain key / value columns + // we rename the original Row columns to avoid this collision + val dfr = df.toDF(columnNamesRenamed: _*) + val exploded = dfr.select(sparkCol("*"), sparkExplode(dfr(columnRenamed))) val trans = exploded // map explode explodes it into [key, value] columns // the only way to put it into a column is to create a struct - // TODO: handle org.apache.spark.sql.AnalysisException: Reference 'key / value' is ambiguous, could be: key / value, key / value - .withColumn(column.value.name, sparkStruct(exploded("key"), exploded("value"))) + .withColumn(columnRenamed, sparkStruct(exploded("key"), exploded("value"))) // selecting only original columns, we don't need [key, 
value] columns left in the DataFrame after the map explode .select(columns: _*) + // rename columns back and form the result + .toDF(columnNames: _*) .as[Out](TypedExpressionEncoder[Out]) TypedDataset.create[Out](trans) } diff --git a/dataset/src/test/scala/frameless/ExplodeTests.scala b/dataset/src/test/scala/frameless/ExplodeTests.scala index 205b4903..3078ceb1 100644 --- a/dataset/src/test/scala/frameless/ExplodeTests.scala +++ b/dataset/src/test/scala/frameless/ExplodeTests.scala @@ -77,4 +77,19 @@ class ExplodeTests extends TypedDatasetSuite { check(forAll(prop[String, Int, Long] _)) check(forAll(prop[Long, String, Int] _)) } + + test("explode on maps making sure no key / value naming collision happens") { + def prop[K: TypedEncoder: ClassTag, V: TypedEncoder: ClassTag, A: TypedEncoder: ClassTag, B: TypedEncoder: ClassTag](xs: List[X3KV[K, V, Map[A, B]]]): Prop = { + val tds = TypedDataset.create(xs) + + val framelessResults = tds.explodeMap('c).collect().run().toVector + val scalaResults = xs.flatMap { x3 => x3.c.toList.map((x3.key, x3.value, _)) }.toVector + + framelessResults ?= scalaResults + } + + check(forAll(prop[String, Int, Long, String] _)) + check(forAll(prop[Long, String, Int, Long] _)) + check(forAll(prop[Int, Long, String, Int] _)) + } } diff --git a/dataset/src/test/scala/frameless/XN.scala b/dataset/src/test/scala/frameless/XN.scala index 0aa7f728..c23d4b45 100644 --- a/dataset/src/test/scala/frameless/XN.scala +++ b/dataset/src/test/scala/frameless/XN.scala @@ -52,6 +52,19 @@ object X3U { Ordering.Tuple3[A, B, C].on(x => (x.a, x.b, x.c)) } +case class X3KV[A, B, C](key: A, value: B, c: C) + +object X3KV { + implicit def arbitrary[A: Arbitrary, B: Arbitrary, C: Arbitrary]: Arbitrary[X3KV[A, B, C]] = + Arbitrary(Arbitrary.arbTuple3[A, B, C].arbitrary.map((X3KV.apply[A, B, C] _).tupled)) + + implicit def cogen[A, B, C](implicit A: Cogen[A], B: Cogen[B], C: Cogen[C]): Cogen[X3KV[A, B, C]] = + Cogen.tuple3(A, B, C).contramap(x => (x.key, x.value, 
x.c)) + + implicit def ordering[A: Ordering, B: Ordering, C: Ordering]: Ordering[X3KV[A, B, C]] = + Ordering.Tuple3[A, B, C].on(x => (x.key, x.value, x.c)) +} + case class X4[A, B, C, D](a: A, b: B, c: C, d: D) object X4 {