[SPARK-42168][SQL][PYTHON][FOLLOW-UP] Test FlatMapCoGroupsInPandas with Window function #39752

Closed
47 changes: 46 additions & 1 deletion python/pyspark/sql/tests/pandas/test_pandas_cogrouped_map.py
@@ -18,8 +18,9 @@
import unittest
from typing import cast

from pyspark.sql.functions import array, explode, col, lit, udf, pandas_udf
from pyspark.sql.functions import array, explode, col, lit, udf, pandas_udf, sum
from pyspark.sql.types import DoubleType, StructType, StructField, Row
from pyspark.sql.window import Window
from pyspark.errors import IllegalArgumentException, PythonException
from pyspark.testing.sqlutils import (
ReusedSQLTestCase,
@@ -365,6 +366,50 @@ def test_self_join(self):

self.assertEqual(row.asDict(), Row(column=2, value=2).asDict())

def test_with_window_function(self):
# SPARK-42168: a window function with the same partition keys but differing key order
ids = 2
days = 100
vals = 10000
parts = 10

id_df = self.spark.range(ids)
day_df = self.spark.range(days).withColumnRenamed("id", "day")
vals_df = self.spark.range(vals).withColumnRenamed("id", "value")
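# joins without a condition give the full cross product: ids * days * vals rows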
df = id_df.join(day_df).join(vals_df)

left_df = df.withColumnRenamed("value", "left").repartition(parts).cache()
# SPARK-42132: this bug requires us to alias all columns from df here
right_df = df.select(
col("id").alias("id"), col("day").alias("day"), col("value").alias("right")
).repartition(parts).cache()

# note the column order is different to the groupBy("id", "day") column order below
window = Window.partitionBy("day", "id")

left_grouped_df = left_df.groupBy("id", "day")
right_grouped_df = right_df \
.withColumn("day_sum", sum(col("day")).over(window)) \
.groupBy("id", "day")

def cogroup(left: pd.DataFrame, right: pd.DataFrame) -> pd.DataFrame:
return pd.DataFrame([{
"id": left["id"][0] if not left.empty else (
right["id"][0] if not right.empty else None
),
"day": left["day"][0] if not left.empty else (
right["day"][0] if not right.empty else None
),
"lefts": len(left.index),
"rights": len(right.index)
}])

df = left_grouped_df.cogroup(right_grouped_df) \
.applyInPandas(cogroup, schema="id long, day long, lefts integer, rights integer")

actual = df.orderBy("id", "day").take(days)
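# with two ids but only `days` rows taken, the expectation covers the id=0 groups only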
self.assertEqual(actual, [Row(0, day, vals, vals) for day in range(days)])

@staticmethod
def _test_with_key(left, right, isLeft):
def right_assign_key(key, lft, rgt):
sql/core/src/test/scala/org/apache/spark/sql/execution/exchange/EnsureRequirementsSuite.scala
@@ -17,17 +17,22 @@

package org.apache.spark.sql.execution.exchange

import org.apache.spark.api.python.PythonEvalType
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate.Sum
import org.apache.spark.sql.catalyst.plans.Inner
import org.apache.spark.sql.catalyst.plans.physical._
import org.apache.spark.sql.catalyst.statsEstimation.StatsTestPlan
import org.apache.spark.sql.connector.catalog.functions._
import org.apache.spark.sql.execution.{DummySparkPlan, SortExec}
import org.apache.spark.sql.execution.SparkPlan
import org.apache.spark.sql.execution.joins.SortMergeJoinExec
import org.apache.spark.sql.execution.python.FlatMapCoGroupsInPandasExec
import org.apache.spark.sql.execution.window.WindowExec
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.internal.SQLConf.REQUIRE_ALL_CLUSTER_KEYS_FOR_CO_PARTITION
import org.apache.spark.sql.test.SharedSparkSession
import org.apache.spark.sql.types.{IntegerType, StructField, StructType}

class EnsureRequirementsSuite extends SharedSparkSession {
private val exprA = Literal(1)
@@ -1104,6 +1109,57 @@ class EnsureRequirementsSuite extends SharedSparkSession {
}
}

test("SPARK-42168: FlatMapCoGroupInPandas and Window function with differing key order") {
val lKey = AttributeReference("key", IntegerType)()
val lKey2 = AttributeReference("key2", IntegerType)()

val rKey = AttributeReference("key", IntegerType)()
val rKey2 = AttributeReference("key2", IntegerType)()
val rValue = AttributeReference("value", IntegerType)()

val left = DummySparkPlan()
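// the right child computes a window aggregate partitioned by (key2, key): the same keys as the
// cogroup's right keys below, but in reverse order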
val right = WindowExec(
Alias(
WindowExpression(
Sum(rValue).toAggregateExpression(),
WindowSpecDefinition(
Seq(rKey2, rKey),
Nil,
SpecifiedWindowFrame(RowFrame, UnboundedPreceding, UnboundedFollowing)
)
), "sum")() :: Nil,
Seq(rKey2, rKey),
Nil,
DummySparkPlan()
)

val pythonUdf = PythonUDF("pyUDF", null,
StructType(Seq(StructField("value", IntegerType))),
Seq.empty,
PythonEvalType.SQL_COGROUPED_MAP_PANDAS_UDF,
true)

val flatMapCoGroup = FlatMapCoGroupsInPandasExec(
Seq(lKey, lKey2),
Seq(rKey, rKey2),
pythonUdf,
AttributeReference("value", IntegerType)() :: Nil,
left,
right
)

val result = EnsureRequirements.apply(flatMapCoGroup)
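// both children should be sorted by the cogroup key order (key, key2),
// not by the window's partition order (key2, key)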
result match {
case FlatMapCoGroupsInPandasExec(leftKeys, rightKeys, _, _,
SortExec(leftOrder, false, _, _), SortExec(rightOrder, false, _, _)) =>
assert(leftKeys === Seq(lKey, lKey2))
assert(rightKeys === Seq(rKey, rKey2))
assert(leftKeys.map(k => SortOrder(k, Ascending)) === leftOrder)
assert(rightKeys.map(k => SortOrder(k, Ascending)) === rightOrder)
case other => fail(other.toString)
}
}

def bucket(numBuckets: Int, expr: Expression): TransformExpression = {
TransformExpression(BucketFunction, Seq(expr), Some(numBuckets))
}