[SPARK-46048][PYTHON][CONNECT] Support DataFrame.groupingSets in Python Spark Connect

### What changes were proposed in this pull request?

This PR adds `DataFrame.groupingSets` in Python Spark Connect.
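
For illustration, a minimal usage sketch of the new API (the `spark` session, column names, and data below are assumptions for the example, not part of this patch):

```python
from pyspark.sql import functions as sf

# Assumes an active SparkSession named `spark`; the data is made up.
df = spark.createDataFrame(
    [("a", 1, 10), ("a", 2, 20), ("b", 1, 30)],
    ["k1", "k2", "v"],
)

# Roughly equivalent to SQL: GROUP BY k1, k2 GROUPING SETS ((k1, k2), (k1), ())
df.groupingSets([["k1", "k2"], ["k1"], []], "k1", "k2").agg(
    sf.sum("v").alias("total")
).show()
```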

### Why are the changes needed?

For feature parity with the non-Spark-Connect (classic) PySpark DataFrame API.

### Does this PR introduce _any_ user-facing change?

Yes, it adds the new API `DataFrame.groupingSets` in Python Spark Connect.

### How was this patch tested?

Unit tests were added.

### Was this patch authored or co-authored using generative AI tooling?

No.

Closes #43967 from HyukjinKwon/SPARK-46048.

Authored-by: Hyukjin Kwon <gurwls223@apache.org>
Signed-off-by: Hyukjin Kwon <gurwls223@apache.org>
HyukjinKwon committed Nov 23, 2023
1 parent f6e4a46 commit d14410c
Showing 10 changed files with 262 additions and 100 deletions.
@@ -327,12 +327,16 @@ message Aggregate {
// (Optional) Pivots a column of the current `DataFrame` and performs the specified aggregation.
Pivot pivot = 5;

// (Optional) The grouping sets to use when group_type is GROUP_TYPE_GROUPING_SETS.
repeated GroupingSets grouping_sets = 6;

enum GroupType {
GROUP_TYPE_UNSPECIFIED = 0;
GROUP_TYPE_GROUPBY = 1;
GROUP_TYPE_ROLLUP = 2;
GROUP_TYPE_CUBE = 3;
GROUP_TYPE_PIVOT = 4;
GROUP_TYPE_GROUPING_SETS = 5;
}

message Pivot {
@@ -345,6 +349,11 @@ message Aggregate {
// the distinct values of the column.
repeated Expression.Literal values = 2;
}

message GroupingSets {
// (Required) Individual grouping set
repeated Expression grouping_set = 1;
}
}

// Relation of type [[Sort]].
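
As a rough sketch of the message shape (not part of the patch), a client could populate the new field with the regenerated Python protobuf classes along these lines; the column name is made up:

```python
# Hedged sketch only: building an Aggregate with one grouping set ("k1").
import pyspark.sql.connect.proto as proto

agg = proto.Aggregate()
agg.group_type = proto.Aggregate.GroupType.GROUP_TYPE_GROUPING_SETS
agg.grouping_sets.append(
    proto.Aggregate.GroupingSets(
        grouping_set=[
            proto.Expression(
                unresolved_attribute=proto.Expression.UnresolvedAttribute(
                    unparsed_identifier="k1"
                )
            )
        ]
    )
)
```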
@@ -800,6 +800,27 @@ package object dsl {
Relation.newBuilder().setAggregate(agg.build()).build()
}

def groupingSets(groupingSets: Seq[Seq[Expression]], groupingExprs: Expression*)(
aggregateExprs: Expression*): Relation = {
val agg = Aggregate.newBuilder()
agg.setInput(logicalPlan)
agg.setGroupType(proto.Aggregate.GroupType.GROUP_TYPE_GROUPING_SETS)
for (groupingSet <- groupingSets) {
val groupingSetMsg = Aggregate.GroupingSets.newBuilder()
for (groupCol <- groupingSet) {
groupingSetMsg.addGroupingSet(groupCol)
}
agg.addGroupingSets(groupingSetMsg)
}
for (groupingExpr <- groupingExprs) {
agg.addGroupingExpressions(groupingExpr)
}
for (aggregateExpr <- aggregateExprs) {
agg.addAggregateExpressions(aggregateExpr)
}
Relation.newBuilder().setAggregate(agg.build()).build()
}

def except(otherPlan: Relation, isAll: Boolean): Relation = {
Relation
.newBuilder()
@@ -2445,6 +2445,17 @@ class SparkConnectPlanner(
aggregates = aggExprs,
child = input)

case proto.Aggregate.GroupType.GROUP_TYPE_GROUPING_SETS =>
val groupingSetsExprs = rel.getGroupingSetsList.asScala.toSeq.map { getGroupingSets =>
getGroupingSets.getGroupingSetList.asScala.toSeq.map(transformExpression)
}
logical.Aggregate(
groupingExpressions = Seq(
GroupingSets(
groupingSets = groupingSetsExprs,
userGivenGroupByExprs = groupingExprs)),
aggregateExpressions = aliasedAgg,
child = input)
case other => throw InvalidPlanInput(s"Unknown Group Type $other")
}
}
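
For reference, the plan built in this branch matches what the equivalent SQL grouping-sets query produces; a hedged Python sketch (the table and column names are assumptions):

```python
# Illustrative only: the explicit GROUP BY columns map to userGivenGroupByExprs
# and the parenthesized sets map to the GroupingSets expression above.
spark.sql(
    """
    SELECT k1, k2, sum(v) AS total
    FROM t
    GROUP BY k1, k2 GROUPING SETS ((k1, k2), (k1), ())
    """
).show()
```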
@@ -307,6 +307,18 @@ class SparkConnectProtoSuite extends PlanTest with SparkConnectPlanTest {
comparePlans(connectPlan2, sparkPlan2)
}

test("GroupingSets expressions") {
val connectPlan1 =
connectTestRelation.groupingSets(Seq(Seq("id".protoAttr), Seq.empty), "id".protoAttr)(
proto_min(proto.Expression.newBuilder().setLiteral(toLiteralProto(1)).build())
.as("agg1"))
val sparkPlan1 =
sparkTestRelation
.groupingSets(Seq(Seq(Column("id")), Seq.empty), Column("id"))
.agg(min(lit(1)).as("agg1"))
comparePlans(connectPlan1, sparkPlan1)
}

test("Test as(alias: String)") {
val connectPlan = connectTestRelation.as("target_table")
val sparkPlan = sparkTestRelation.as("target_table")
39 changes: 39 additions & 0 deletions python/pyspark/sql/connect/dataframe.py
@@ -550,6 +550,45 @@ def cube(self, *cols: "ColumnOrName") -> "GroupedData":

cube.__doc__ = PySparkDataFrame.cube.__doc__

def groupingSets(
self, groupingSets: Sequence[Sequence["ColumnOrName"]], *cols: "ColumnOrName"
) -> "GroupedData":
gsets: List[List[Column]] = []
for grouping_set in groupingSets:
gset: List[Column] = []
for c in grouping_set:
if isinstance(c, Column):
gset.append(c)
elif isinstance(c, str):
gset.append(self[c])
else:
raise PySparkTypeError(
error_class="NOT_COLUMN_OR_STR",
message_parameters={
"arg_name": "groupingSets",
"arg_type": type(c).__name__,
},
)
gsets.append(gset)

gcols: List[Column] = []
for c in cols:
if isinstance(c, Column):
gcols.append(c)
elif isinstance(c, str):
gcols.append(self[c])
else:
raise PySparkTypeError(
error_class="NOT_COLUMN_OR_STR",
message_parameters={"arg_name": "cols", "arg_type": type(c).__name__},
)

return GroupedData(
df=self, group_type="grouping_sets", grouping_cols=gcols, grouping_sets=gsets
)

groupingSets.__doc__ = PySparkDataFrame.groupingSets.__doc__

@overload
def head(self) -> Optional[Row]:
...
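
A small sketch of the argument handling in the method above (the DataFrame, names, and data are assumed for illustration):

```python
# Assumes an active SparkSession named `spark`; data is made up.
df = spark.createDataFrame([("a", 1, 10), ("b", 2, 20)], ["k1", "k2", "v"])

df.groupingSets([["k1"], []], "k1")        # str entries resolve via df["k1"]
df.groupingSets([[df.k1], []], df.k1)      # Column objects pass through as-is

try:
    df.groupingSets([[1]], "k1")           # neither Column nor str
except Exception as e:
    print(type(e).__name__)                # PySparkTypeError (NOT_COLUMN_OR_STR)
```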
16 changes: 15 additions & 1 deletion python/pyspark/sql/connect/group.py
@@ -63,13 +63,20 @@ def __init__(
grouping_cols: Sequence["Column"],
pivot_col: Optional["Column"] = None,
pivot_values: Optional[Sequence["LiteralType"]] = None,
grouping_sets: Optional[Sequence[Sequence["Column"]]] = None,
) -> None:
from pyspark.sql.connect.dataframe import DataFrame

assert isinstance(df, DataFrame)
self._df = df

assert isinstance(group_type, str) and group_type in ["groupby", "rollup", "cube", "pivot"]
assert isinstance(group_type, str) and group_type in [
"groupby",
"rollup",
"cube",
"pivot",
"grouping_sets",
]
self._group_type = group_type

assert isinstance(grouping_cols, list) and all(isinstance(g, Column) for g in grouping_cols)
@@ -83,6 +90,11 @@ def __init__(
self._pivot_col = pivot_col
self._pivot_values = pivot_values

self._grouping_sets: Optional[Sequence[Sequence["Column"]]] = None
if group_type == "grouping_sets":
assert grouping_sets is None or isinstance(grouping_sets, list)
self._grouping_sets = grouping_sets

def __repr__(self) -> str:
# the expressions are not resolved here,
# so the string representation can be different from vanilla PySpark.
@@ -130,6 +142,7 @@ def agg(self, *exprs: Union[Column, Dict[str, str]]) -> "DataFrame":
aggregate_cols=aggregate_cols,
pivot_col=self._pivot_col,
pivot_values=self._pivot_values,
grouping_sets=self._grouping_sets,
),
session=self._df._session,
)
@@ -171,6 +184,7 @@ def _numeric_agg(self, function: str, cols: Sequence[str]) -> "DataFrame":
aggregate_cols=[_invoke_function(function, col(c)) for c in agg_cols],
pivot_col=self._pivot_col,
pivot_values=self._pivot_values,
grouping_sets=self._grouping_sets,
),
session=self._df._session,
)
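
A brief sketch of how the grouping sets flow through both aggregation paths above (the session, names, and data are assumptions):

```python
from pyspark.sql import functions as sf

# Assumes an active SparkSession named `spark`; data is made up.
df = spark.createDataFrame([("a", 1, 10), ("a", 2, 20), ("b", 1, 30)], ["k1", "k2", "v"])

gd = df.groupingSets([["k1", "k2"], ["k1"], []], "k1", "k2")
gd.agg(sf.sum("v").alias("total")).show()  # generic agg() path
gd.avg("v").show()                         # numeric shortcut via _numeric_agg()
```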
23 changes: 21 additions & 2 deletions python/pyspark/sql/connect/plan.py
@@ -778,10 +778,17 @@ def __init__(
aggregate_cols: Sequence[Column],
pivot_col: Optional[Column],
pivot_values: Optional[Sequence[Any]],
grouping_sets: Optional[Sequence[Sequence[Column]]],
) -> None:
super().__init__(child)

assert isinstance(group_type, str) and group_type in ["groupby", "rollup", "cube", "pivot"]
assert isinstance(group_type, str) and group_type in [
"groupby",
"rollup",
"cube",
"pivot",
"grouping_sets",
]
self._group_type = group_type

assert isinstance(grouping_cols, list) and all(isinstance(c, Column) for c in grouping_cols)
@@ -795,12 +802,16 @@ def __init__(
if group_type == "pivot":
assert pivot_col is not None and isinstance(pivot_col, Column)
assert pivot_values is None or isinstance(pivot_values, list)
elif group_type == "grouping_sets":
assert grouping_sets is None or isinstance(grouping_sets, list)
else:
assert pivot_col is None
assert pivot_values is None
assert grouping_sets is None

self._pivot_col = pivot_col
self._pivot_values = pivot_values
self._grouping_sets = grouping_sets

def plan(self, session: "SparkConnectClient") -> proto.Relation:
from pyspark.sql.connect.functions import lit
@@ -829,7 +840,15 @@ def plan(self, session: "SparkConnectClient") -> proto.Relation:
plan.aggregate.pivot.values.extend(
[lit(v).to_plan(session).literal for v in self._pivot_values]
)

elif self._group_type == "grouping_sets":
plan.aggregate.group_type = proto.Aggregate.GroupType.GROUP_TYPE_GROUPING_SETS
assert self._grouping_sets is not None
for grouping_set in self._grouping_sets:
plan.aggregate.grouping_sets.append(
proto.Aggregate.GroupingSets(
grouping_set=[c.to_plan(session) for c in grouping_set]
)
)
return plan


194 changes: 98 additions & 96 deletions python/pyspark/sql/connect/proto/relations_pb2.py

Large diffs are not rendered by default.

36 changes: 36 additions & 0 deletions python/pyspark/sql/connect/proto/relations_pb2.pyi
@@ -1380,13 +1380,15 @@ class Aggregate(google.protobuf.message.Message):
GROUP_TYPE_ROLLUP: Aggregate._GroupType.ValueType # 2
GROUP_TYPE_CUBE: Aggregate._GroupType.ValueType # 3
GROUP_TYPE_PIVOT: Aggregate._GroupType.ValueType # 4
GROUP_TYPE_GROUPING_SETS: Aggregate._GroupType.ValueType # 5

class GroupType(_GroupType, metaclass=_GroupTypeEnumTypeWrapper): ...
GROUP_TYPE_UNSPECIFIED: Aggregate.GroupType.ValueType # 0
GROUP_TYPE_GROUPBY: Aggregate.GroupType.ValueType # 1
GROUP_TYPE_ROLLUP: Aggregate.GroupType.ValueType # 2
GROUP_TYPE_CUBE: Aggregate.GroupType.ValueType # 3
GROUP_TYPE_PIVOT: Aggregate.GroupType.ValueType # 4
GROUP_TYPE_GROUPING_SETS: Aggregate.GroupType.ValueType # 5

class Pivot(google.protobuf.message.Message):
DESCRIPTOR: google.protobuf.descriptor.Descriptor
@@ -1423,11 +1425,35 @@ class Aggregate(google.protobuf.message.Message):
self, field_name: typing_extensions.Literal["col", b"col", "values", b"values"]
) -> None: ...

class GroupingSets(google.protobuf.message.Message):
DESCRIPTOR: google.protobuf.descriptor.Descriptor

GROUPING_SET_FIELD_NUMBER: builtins.int
@property
def grouping_set(
self,
) -> google.protobuf.internal.containers.RepeatedCompositeFieldContainer[
pyspark.sql.connect.proto.expressions_pb2.Expression
]:
"""(Required) Individual grouping set"""
def __init__(
self,
*,
grouping_set: collections.abc.Iterable[
pyspark.sql.connect.proto.expressions_pb2.Expression
]
| None = ...,
) -> None: ...
def ClearField(
self, field_name: typing_extensions.Literal["grouping_set", b"grouping_set"]
) -> None: ...

INPUT_FIELD_NUMBER: builtins.int
GROUP_TYPE_FIELD_NUMBER: builtins.int
GROUPING_EXPRESSIONS_FIELD_NUMBER: builtins.int
AGGREGATE_EXPRESSIONS_FIELD_NUMBER: builtins.int
PIVOT_FIELD_NUMBER: builtins.int
GROUPING_SETS_FIELD_NUMBER: builtins.int
@property
def input(self) -> global___Relation:
"""(Required) Input relation for a RelationalGroupedDataset."""
@@ -1450,6 +1476,13 @@ class Aggregate(google.protobuf.message.Message):
@property
def pivot(self) -> global___Aggregate.Pivot:
"""(Optional) Pivots a column of the current `DataFrame` and performs the specified aggregation."""
@property
def grouping_sets(
self,
) -> google.protobuf.internal.containers.RepeatedCompositeFieldContainer[
global___Aggregate.GroupingSets
]:
"""(Optional) List of values that will be translated to columns in the output DataFrame."""
def __init__(
self,
*,
@@ -1464,6 +1497,7 @@ class Aggregate(google.protobuf.message.Message):
]
| None = ...,
pivot: global___Aggregate.Pivot | None = ...,
grouping_sets: collections.abc.Iterable[global___Aggregate.GroupingSets] | None = ...,
) -> None: ...
def HasField(
self, field_name: typing_extensions.Literal["input", b"input", "pivot", b"pivot"]
@@ -1477,6 +1511,8 @@ class Aggregate(google.protobuf.message.Message):
b"group_type",
"grouping_expressions",
b"grouping_expressions",
"grouping_sets",
b"grouping_sets",
"input",
b"input",
"pivot",
1 change: 0 additions & 1 deletion python/pyspark/sql/dataframe.py
@@ -4204,7 +4204,6 @@ def cube(self, *cols: "ColumnOrName") -> "GroupedData":  # type: ignore[misc]

return GroupedData(jgd, self)

# TODO(SPARK-46048): Add it to Python Spark Connect client.
def groupingSets(
self, groupingSets: Sequence[Sequence["ColumnOrName"]], *cols: "ColumnOrName"
) -> "GroupedData":