[SPARK-40852][CONNECT][PYTHON] Introduce StatFunction in proto and implement `DataFrame.summary`

### What changes were proposed in this pull request?
Implement `DataFrame.summary`.

There is a set of DataFrame APIs implemented in [`StatFunctions`](https://github.com/apache/spark/blob/9cae423075145d3dd81d53f4b82d4f2af6fe7c15/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala), [`DataFrameStatFunctions`](https://github.com/apache/spark/blob/b69c26833c99337bb17922f21dd72ee3a12e0c0a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala) and [`DataFrameNaFunctions`](https://github.com/apache/spark/blob/5d74ace648422e7a9bff7774ac266372934023b9/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala) which I think cannot be implemented in the Connect client, because they:

1. depend on Catalyst's analysis (most of them);
~~2. are implemented in RDD operations (like `summary`, `approxQuantile`);~~ (resolved by reimplementation)
~~3. internally trigger jobs (like `summary`);~~ (resolved by reimplementation)

This PR introduces a new proto message, `StatFunction`, so that `StatFunctions` methods can be expressed in the Connect protocol and executed on the server; a sketch of the encoding follows.
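As a minimal illustration of what travels over the wire, the client encodes a `summary` call roughly as follows (assuming the generated proto classes are imported as `proto`, as in `plan.py` below; `child` stands for an already-built input `Relation`):

```python
import pyspark.sql.connect.proto as proto

rel = proto.Relation()
# Attach the relation to summarize and select the `summary` variant of the
# `function` oneof; the server dispatches on this oneof and runs
# StatFunctions.summary over the input.
rel.stat_function.input.CopyFrom(child)
rel.stat_function.summary.statistics.extend(["count", "mean", "stddev"])
```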

### Why are the changes needed?
To extend Spark Connect's API coverage.

### Does this PR introduce _any_ user-facing change?
Yes, this adds a new API: `DataFrame.summary` in the Connect Python client.
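For illustration, a minimal client-side usage sketch (assuming `df` is a Connect `DataFrame` already bound to a remote session):

```python
# Compute selected statistics on the server and return them as a new DataFrame.
df.summary("count", "mean", "stddev").show()

# With no arguments, the server-side StatFunctions.summary falls back to its
# defaults: count, mean, stddev, min, approximate 25%/50%/75% percentiles, max.
df.summary().show()
```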

### How was this patch tested?
Added unit tests.

Closes apache#38318 from zhengruifeng/connect_df_summary.

Authored-by: Ruifeng Zheng <ruifengz@apache.org>
Signed-off-by: Wenchen Fan <wenchen@databricks.com>
zhengruifeng authored and SandishKumarHN committed Dec 12, 2022
1 parent 896a465 commit 956e140
Showing 9 changed files with 252 additions and 59 deletions.
20 changes: 20 additions & 0 deletions connector/connect/src/main/protobuf/spark/connect/relations.proto
@@ -48,6 +48,8 @@ message Relation {
SubqueryAlias subquery_alias = 16;
Repartition repartition = 17;

StatFunction stat_function = 100;

Unknown unknown = 999;
}
}
@@ -254,3 +256,21 @@ message Repartition {
// Optional. Default value is false.
bool shuffle = 3;
}

// StatFunction
message StatFunction {
// Required. The input relation.
Relation input = 1;
// Required. The function and its parameters.
oneof function {
Summary summary = 2;

Unknown unknown = 999;
}

// StatFunctions.summary
message Summary {
repeated string statistics = 1;
}
}

@@ -441,6 +441,22 @@ package object dsl {
Repartition.newBuilder().setInput(logicalPlan).setNumPartitions(num).setShuffle(true))
.build()

def summary(statistics: String*): Relation = {
Relation
.newBuilder()
.setStatFunction(
proto.StatFunction
.newBuilder()
.setInput(logicalPlan)
.setSummary(
proto.StatFunction.Summary
.newBuilder()
.addAllStatistics(statistics.toSeq.asJava)
.build())
.build())
.build()
}

private def createSetOperation(
left: Relation,
right: Relation,
@@ -21,7 +21,7 @@ import scala.annotation.elidable.byName
import scala.collection.JavaConverters._

import org.apache.spark.connect.proto
-import org.apache.spark.sql.SparkSession
+import org.apache.spark.sql.{Dataset, SparkSession}
import org.apache.spark.sql.catalyst.AliasIdentifier
import org.apache.spark.sql.catalyst.analysis.{UnresolvedAlias, UnresolvedAttribute, UnresolvedFunction, UnresolvedRelation, UnresolvedStar}
import org.apache.spark.sql.catalyst.expressions
Expand All @@ -32,6 +32,7 @@ import org.apache.spark.sql.catalyst.plans.{logical, FullOuter, Inner, JoinType,
import org.apache.spark.sql.catalyst.plans.logical.{Deduplicate, Except, Intersect, LogicalPlan, Sample, SubqueryAlias, Union}
import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap
import org.apache.spark.sql.execution.QueryExecution
import org.apache.spark.sql.execution.stat.StatFunctions
import org.apache.spark.sql.types._
import org.apache.spark.util.Utils

@@ -73,6 +74,8 @@ class SparkConnectPlanner(plan: proto.Relation, session: SparkSession) {
case proto.Relation.RelTypeCase.SUBQUERY_ALIAS =>
transformSubqueryAlias(rel.getSubqueryAlias)
case proto.Relation.RelTypeCase.REPARTITION => transformRepartition(rel.getRepartition)
case proto.Relation.RelTypeCase.STAT_FUNCTION =>
transformStatFunction(rel.getStatFunction)
case proto.Relation.RelTypeCase.RELTYPE_NOT_SET =>
throw new IndexOutOfBoundsException("Expected Relation to be set, but is empty.")
case _ => throw InvalidPlanInput(s"${rel.getUnknown} not supported.")
@@ -124,6 +127,19 @@ class SparkConnectPlanner(plan: proto.Relation, session: SparkSession) {
logical.Range(start, end, step, numPartitions)
}

private def transformStatFunction(rel: proto.StatFunction): LogicalPlan = {
val child = transformRelation(rel.getInput)

rel.getFunctionCase match {
case proto.StatFunction.FunctionCase.SUMMARY =>
StatFunctions
.summary(Dataset.ofRows(session, child), rel.getSummary.getStatisticsList.asScala.toSeq)
.logicalPlan

case _ => throw InvalidPlanInput(s"StatFunction ${rel.getUnknown} not supported.")
}
}

private def transformDeduplicate(rel: proto.Deduplicate): LogicalPlan = {
if (!rel.hasInput) {
throw InvalidPlanInput("Deduplicate needs a plan input")
@@ -261,6 +261,12 @@ class SparkConnectProtoSuite extends PlanTest with SparkConnectPlanTest {
comparePlans(connectPlan2, sparkPlan2)
}

test("Test summary") {
comparePlans(
connectTestRelation.summary("count", "mean", "stddev"),
sparkTestRelation.summary("count", "mean", "stddev"))
}

private def createLocalRelationProtoByQualifiedAttributes(
attrs: Seq[proto.Expression.QualifiedAttribute]): proto.Relation = {
val localRelationBuilder = proto.LocalRelation.newBuilder()
10 changes: 10 additions & 0 deletions python/pyspark/sql/connect/dataframe.py
@@ -376,6 +376,16 @@ def unionByName(self, other: "DataFrame", allowMissingColumns: bool = False) ->
def where(self, condition: Expression) -> "DataFrame":
return self.filter(condition)

def summary(self, *statistics: str) -> "DataFrame":
_statistics: List[str] = list(statistics)
for s in _statistics:
if not isinstance(s, str):
raise TypeError(f"'statistics' must be list[str], but got {type(s).__name__}")
return DataFrame.withPlan(
plan.StatFunction(child=self._plan, function="summary", statistics=_statistics),
session=self._session,
)

def _get_alias(self) -> Optional[str]:
p = self._plan
while p is not None:
38 changes: 38 additions & 0 deletions python/pyspark/sql/connect/plan.py
@@ -16,6 +16,7 @@
#

from typing import (
Any,
List,
Optional,
Sequence,
@@ -750,3 +751,40 @@ def _repr_html_(self) -> str:
</li>
</uL>
"""


class StatFunction(LogicalPlan):
def __init__(self, child: Optional["LogicalPlan"], function: str, **kwargs: Any) -> None:
super().__init__(child)
assert function in ["summary"]
self.function = function
self.kwargs = kwargs

def plan(self, session: Optional["RemoteSparkSession"]) -> proto.Relation:
assert self._child is not None

plan = proto.Relation()
plan.stat_function.input.CopyFrom(self._child.plan(session))

if self.function == "summary":
plan.stat_function.summary.statistics.extend(self.kwargs.get("statistics", []))
else:
raise Exception(f"Unknown function ${self.function}.")

return plan

def print(self, indent: int = 0) -> str:
i = " " * indent
return f"""{i}<StatFunction function='{self.function}' augments='{self.kwargs}'>"""

def _repr_html_(self) -> str:
return f"""
<ul>
<li>
<b>StatFunction</b><br />
Function: {self.function} <br />
Arguments: {self.kwargs} <br />
{self._child_repr_()}
</li>
</ul>
"""
