Skip to content

Commit

Permalink
fixup
Browse files Browse the repository at this point in the history
  • Loading branch information
HyukjinKwon committed Apr 19, 2024
1 parent 9f1f1bd commit 692a302
Show file tree
Hide file tree
Showing 3 changed files with 33 additions and 11 deletions.
8 changes: 4 additions & 4 deletions python/pyspark/sql/classic/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -1029,7 +1029,7 @@ def __getattr__(self, name: str) -> Column:
return Column(jc)

def __dir__(self) -> List[str]:
attrs = set(super().__dir__())
attrs = set(dir(DataFrame))
attrs.update(filter(lambda s: s.isidentifier(), self.columns))
return sorted(attrs)

Expand Down Expand Up @@ -1953,15 +1953,15 @@ def sampleBy(
def _test() -> None:
import doctest
from pyspark.sql import SparkSession
import pyspark.sql.classic.dataframe
import pyspark.sql.dataframe

globs = pyspark.sql.classic.dataframe.__dict__.copy()
globs = pyspark.sql.dataframe.__dict__.copy()
spark = (
SparkSession.builder.master("local[4]").appName("sql.classic.dataframe tests").getOrCreate()
)
globs["spark"] = spark
(failure_count, test_count) = doctest.testmod(
pyspark.sql.classic.dataframe,
pyspark.sql.dataframe,
globs=globs,
optionflags=doctest.ELLIPSIS | doctest.NORMALIZE_WHITESPACE | doctest.REPORT_NDIFF,
)
Expand Down
29 changes: 25 additions & 4 deletions python/pyspark/sql/connect/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -1659,6 +1659,22 @@ def sampleBy(
session=self._session,
)

def _ipython_key_completions_(self) -> List[str]:
    """Returns the names of columns in this :class:`DataFrame`.

    IPython looks this method up to offer completions for ``df["<TAB>``.

    Returns
    -------
    list of str
        The value of ``self.columns``, returned as-is (no filtering).

    Examples
    --------
    >>> df = spark.createDataFrame([(2, "Alice"), (5, "Bob")], ["age", "name"])
    >>> df._ipython_key_completions_()
    ['age', 'name']

    Column names that are not valid Python identifiers are returned as well.

    >>> df = spark.createDataFrame([(2, "Alice"), (5, "Bob")], ["age 1", "name?1"])
    >>> df._ipython_key_completions_()
    ['age 1', 'name?1']
    """
    return self.columns

def __getattr__(self, name: str) -> "Column":
if name in ["_jseq", "_jdf", "_jmap", "_jcols", "rdd", "toJSON"]:
raise PySparkAttributeError(
Expand Down Expand Up @@ -1738,7 +1754,7 @@ def _col(self, name: str) -> Column:
)

def __dir__(self) -> List[str]:
attrs = set(super().__dir__())
attrs = set(dir(DataFrame))
attrs.update(self.columns)
return sorted(attrs)

Expand Down Expand Up @@ -2149,11 +2165,16 @@ def _test() -> None:
import sys
import doctest
from pyspark.sql import SparkSession as PySparkSession
import pyspark.sql.connect.dataframe
import pyspark.sql.dataframe

os.chdir(os.environ["SPARK_HOME"])

globs = pyspark.sql.connect.dataframe.__dict__.copy()
globs = pyspark.sql.dataframe.__dict__.copy()

del pyspark.sql.dataframe.DataFrame.toJSON.__doc__
del pyspark.sql.dataframe.DataFrame.rdd.__doc__
del pyspark.sql.dataframe.DataFrame.checkpoint.__doc__
del pyspark.sql.dataframe.DataFrame.localCheckpoint.__doc__

globs["spark"] = (
PySparkSession.builder.appName("sql.connect.dataframe tests")
Expand All @@ -2162,7 +2183,7 @@ def _test() -> None:
)

(failure_count, test_count) = doctest.testmod(
pyspark.sql.connect.dataframe,
pyspark.sql.dataframe,
globs=globs,
optionflags=doctest.ELLIPSIS
| doctest.NORMALIZE_WHITESPACE
Expand Down
7 changes: 4 additions & 3 deletions python/pyspark/sql/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,12 +106,13 @@ class DataFrame:
>>> people.filter(people.age > 30).join(
... department, people.deptId == department.id).groupBy(
... department.name, "gender").agg({"salary": "avg", "age": "max"}).show()
... department.name, "gender").agg(
... {"salary": "avg", "age": "max"}).sort("max(age)").show()
+-------+------+-----------+--------+
| name|gender|avg(salary)|max(age)|
+-------+------+-----------+--------+
| ML| F| 150.0| 60|
|PySpark| M| 75.0| 50|
| ML| F| 150.0| 60|
+-------+------+-----------+--------+
Notes
Expand Down Expand Up @@ -5295,7 +5296,7 @@ def _ipython_key_completions_(self) -> List[str]:
>>> df._ipython_key_completions_()
['age 1', 'name?1']
"""
return self.columns
...

@dispatch_df_method
def withColumns(self, *colsMap: Dict[str, Column]) -> "DataFrame":
Expand Down

0 comments on commit 692a302

Please sign in to comment.