[SPARK-40852][CONNECT][PYTHON] Introduce StatFunction in proto and implement DataFrame.summary #38318

Closed · wants to merge 2 commits · viewing changes from all commits
20 changes: 20 additions & 0 deletions connector/connect/src/main/protobuf/spark/connect/relations.proto
@@ -48,6 +48,8 @@ message Relation {
    SubqueryAlias subquery_alias = 16;
    Repartition repartition = 17;

    StatFunction stat_function = 100;

    Unknown unknown = 999;
  }
}
@@ -254,3 +256,21 @@ message Repartition {
  // Optional. Default value is false.
  bool shuffle = 3;
}

// StatFunction: computes statistical function results (e.g. summary) over the input relation.
message StatFunction {
  // Required. The input relation.
  Relation input = 1;
  // Required. The function and its parameters.
  oneof function {
    Summary summary = 2;

    Unknown unknown = 999;
  }

  // StatFunctions.summary
  message Summary {
    repeated string statistics = 1;
  }
}

[Review thread on `Unknown unknown = 999;`]

Contributor: Why do we need this?

Contributor: I think that's for enums, but here it is an optional field... cc @amaliujia

Contributor: Question: will we add new functions under this oneof?

Contributor (author): Yes, such as crosstab, cov, corr, etc.

Contributor: OK, then this makes sense.
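For a concrete picture of the wire format, here is a minimal Python sketch that fills in this message through the generated bindings, mirroring the client code in plan.py later in this PR; the import path follows that file, and `child` is a placeholder for a real input relation:

import pyspark.sql.connect.proto as proto

# A sketch: wrap an input relation in a StatFunction carrying a Summary.
child = proto.Relation()  # placeholder; in practice this is a fully built relation
rel = proto.Relation()
rel.stat_function.input.CopyFrom(child)
rel.stat_function.summary.statistics.extend(["count", "mean", "stddev"])
# Setting `summary` selects it in the `function` oneof; future functions such as
# crosstab/cov/corr would become sibling fields, per the thread above.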

@@ -441,6 +441,22 @@ package object dsl {
            Repartition.newBuilder().setInput(logicalPlan).setNumPartitions(num).setShuffle(true))
          .build()

      def summary(statistics: String*): Relation = {
        Relation
          .newBuilder()
          .setStatFunction(
            proto.StatFunction
              .newBuilder()
              .setInput(logicalPlan)
              .setSummary(
                proto.StatFunction.Summary
                  .newBuilder()
                  .addAllStatistics(statistics.toSeq.asJava)
                  .build())
              .build())
          .build()
      }

      private def createSetOperation(
          left: Relation,
          right: Relation,
@@ -21,7 +21,7 @@ import scala.annotation.elidable.byName
import scala.collection.JavaConverters._

import org.apache.spark.connect.proto
-import org.apache.spark.sql.SparkSession
+import org.apache.spark.sql.{Dataset, SparkSession}
import org.apache.spark.sql.catalyst.AliasIdentifier
import org.apache.spark.sql.catalyst.analysis.{UnresolvedAlias, UnresolvedAttribute, UnresolvedFunction, UnresolvedRelation, UnresolvedStar}
import org.apache.spark.sql.catalyst.expressions
@@ -32,6 +32,7 @@ import org.apache.spark.sql.catalyst.plans.{logical, FullOuter, Inner, JoinType,
import org.apache.spark.sql.catalyst.plans.logical.{Deduplicate, Except, Intersect, LogicalPlan, Sample, SubqueryAlias, Union}
import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap
import org.apache.spark.sql.execution.QueryExecution
import org.apache.spark.sql.execution.stat.StatFunctions
import org.apache.spark.sql.types._
import org.apache.spark.util.Utils

@@ -73,6 +74,8 @@ class SparkConnectPlanner(plan: proto.Relation, session: SparkSession) {
      case proto.Relation.RelTypeCase.SUBQUERY_ALIAS =>
        transformSubqueryAlias(rel.getSubqueryAlias)
      case proto.Relation.RelTypeCase.REPARTITION => transformRepartition(rel.getRepartition)
      case proto.Relation.RelTypeCase.STAT_FUNCTION =>
        transformStatFunction(rel.getStatFunction)
      case proto.Relation.RelTypeCase.RELTYPE_NOT_SET =>
        throw new IndexOutOfBoundsException("Expected Relation to be set, but is empty.")
      case _ => throw InvalidPlanInput(s"${rel.getUnknown} not supported.")
@@ -124,6 +127,19 @@ class SparkConnectPlanner(plan: proto.Relation, session: SparkSession) {
    logical.Range(start, end, step, numPartitions)
  }

  private def transformStatFunction(rel: proto.StatFunction): LogicalPlan = {
    val child = transformRelation(rel.getInput)

    rel.getFunctionCase match {
      case proto.StatFunction.FunctionCase.SUMMARY =>
        StatFunctions
          .summary(Dataset.ofRows(session, child), rel.getSummary.getStatisticsList.asScala.toSeq)
          .logicalPlan
      case _ => throw InvalidPlanInput(s"StatFunction ${rel.getUnknown} not supported.")
    }
  }

[Review thread on the SUMMARY case]

Member: This is fine for now, but it's going to truncate the SQL plans, which disables further optimization. We should probably add dedicated plans for def summary in Dataset itself. For now, LGTM.

Contributor (author): Yes, then it will have more optimization space. Let us add a new plan for it. Thanks.

Contributor: +1! I don't know how to add a new plan. It would be very useful to have a PR as an example.

Contributor (author): cc @cloud-fan @HyukjinKwon: since we reimplemented df.summary in 6a0713a, are there any differences in SQL optimization between this method (directly invoking df.summary) and adding a dedicated plan?

Contributor: Some rules may not work, as they don't recognize the new plan.

Contributor: What do you mean by "truncate the SQL plans"? DataFrame transformations just accumulate the logical plan.

Contributor (author): @cloud-fan the old df.summary eagerly computed the statistics and always returned a LocalRelation.

Contributor: Oh, that's an issue. Can it be solved by updating the df.summary implementation?

Contributor (author): Yes, it has been resolved.

Contributor (author): In the new impl …

  private def transformDeduplicate(rel: proto.Deduplicate): LogicalPlan = {
    if (!rel.hasInput) {
      throw InvalidPlanInput("Deduplicate needs a plan input")
@@ -261,6 +261,12 @@ class SparkConnectProtoSuite extends PlanTest with SparkConnectPlanTest {
    comparePlans(connectPlan2, sparkPlan2)
  }

  test("Test summary") {
    comparePlans(
      connectTestRelation.summary("count", "mean", "stddev"),
      sparkTestRelation.summary("count", "mean", "stddev"))
  }

  private def createLocalRelationProtoByQualifiedAttributes(
      attrs: Seq[proto.Expression.QualifiedAttribute]): proto.Relation = {
    val localRelationBuilder = proto.LocalRelation.newBuilder()
10 changes: 10 additions & 0 deletions python/pyspark/sql/connect/dataframe.py
@@ -376,6 +376,16 @@ def unionByName(self, other: "DataFrame", allowMissingColumns: bool = False) ->
    def where(self, condition: Expression) -> "DataFrame":
        return self.filter(condition)

    def summary(self, *statistics: str) -> "DataFrame":
        _statistics: List[str] = list(statistics)
        for s in _statistics:
            if not isinstance(s, str):
                raise TypeError(f"'statistics' must be list[str], but got {type(s).__name__}")
        return DataFrame.withPlan(
            plan.StatFunction(child=self._plan, function="summary", statistics=_statistics),
            session=self._session,
        )

[Review comment on the `_statistics` line]

Contributor (author): Different from

    if len(statistics) == 1 and isinstance(statistics[0], list):
        statistics = statistics[0]

since I think that preprocessing is weird.

    def _get_alias(self) -> Optional[str]:
        p = self._plan
        while p is not None:
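For illustration, a hypothetical usage sketch of the new method; `df` is assumed to be an already-constructed connect DataFrame, and the second call shows the type check above rejecting the list form that classic PySpark accepts:

# A sketch, assuming `df` is a pyspark.sql.connect.dataframe.DataFrame.
summary_df = df.summary("count", "mean", "stddev")  # lazily builds a StatFunction plan

try:
    df.summary(["count", "mean"])  # a single list is not unpacked here
except TypeError as e:
    print(e)  # 'statistics' must be list[str], but got list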
38 changes: 38 additions & 0 deletions python/pyspark/sql/connect/plan.py
@@ -16,6 +16,7 @@
#

from typing import (
    Any,
    List,
    Optional,
    Sequence,
@@ -750,3 +751,40 @@ def _repr_html_(self) -> str:
            </li>
        </ul>
        """


class StatFunction(LogicalPlan):
    """Logical plan node for statistic functions (currently only ``summary``)."""

    def __init__(self, child: Optional["LogicalPlan"], function: str, **kwargs: Any) -> None:
        super().__init__(child)
        assert function in ["summary"]
        self.function = function
        self.kwargs = kwargs

    def plan(self, session: Optional["RemoteSparkSession"]) -> proto.Relation:
        assert self._child is not None

        plan = proto.Relation()
        plan.stat_function.input.CopyFrom(self._child.plan(session))

        if self.function == "summary":
            plan.stat_function.summary.statistics.extend(self.kwargs.get("statistics", []))
        else:
            raise Exception(f"Unknown function {self.function}.")

        return plan

    def print(self, indent: int = 0) -> str:
        i = " " * indent
        return f"""{i}<StatFunction function='{self.function}' arguments='{self.kwargs}'>"""

    def _repr_html_(self) -> str:
        return f"""
        <ul>
            <li>
                <b>StatFunction</b><br />
                Function: {self.function} <br />
                Arguments: {self.kwargs} <br />
                {self._child_repr_()}
            </li>
        </ul>
        """