[SPARK-41600][SPARK-41623][SPARK-41612][CONNECT] Implement Catalog.cacheTable, isCached and uncacheTable

### What changes were proposed in this pull request?

This PR adds the three APIs below to Spark Connect; a short usage sketch follows the list.
- `Catalog.isCached`
- `Catalog.cacheTable`
- `Catalog.uncacheTable`
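For context, a minimal sketch of the restored API surface. This works against a regular session; with Spark Connect you would build the session via `SparkSession.builder.remote(...)` instead. The table name is illustrative:

```python
from pyspark.sql import SparkSession

# Classic session shown here; a Connect session exposes the same Catalog API.
spark = SparkSession.builder.appName("catalog-cache-demo").getOrCreate()

spark.range(10).createOrReplaceTempView("people")
assert not spark.catalog.isCached("people")
spark.catalog.cacheTable("people")    # marks the table for caching
assert spark.catalog.isCached("people")
spark.catalog.uncacheTable("people")  # drops it from the cache again
assert not spark.catalog.isCached("people")
```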

### Why are the changes needed?

These were initially left out because of a design concern about their behaviour. However, we should provide the same API compatibility and behaviour as regular PySpark in any event, so they are now proposed back.

### Does this PR introduce _any_ user-facing change?

No to end users.
Yes to developers, because it adds three new APIs to Spark Connect.

### How was this patch tested?

Unit tests were added.
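The tests themselves are not reproduced on this page; below is a sketch of the kind of parity check such a test typically performs. The fixture names `classic_spark` and `connect_spark` are hypothetical, not from this PR:

```python
# Hypothetical parity test: both session kinds should behave identically.
def test_cache_table_parity(classic_spark, connect_spark):  # fixtures assumed
    for spark in (classic_spark, connect_spark):
        spark.range(10).createOrReplaceTempView("t")
        assert not spark.catalog.isCached("t")
        spark.catalog.cacheTable("t")
        assert spark.catalog.isCached("t")
        spark.catalog.uncacheTable("t")
        assert not spark.catalog.isCached("t")
        spark.catalog.dropTempView("t")
```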

Closes apache#39919 from HyukjinKwon/SPARK-41600-SPARK-41623-SPARK-41612.

Authored-by: Hyukjin Kwon <gurwls223@apache.org>
Signed-off-by: Hyukjin Kwon <gurwls223@apache.org>
HyukjinKwon committed Feb 7, 2023
1 parent 58b6535 commit 54b5cf6
Showing 11 changed files with 293 additions and 235 deletions.
@@ -44,12 +44,9 @@ message Catalog {
     DropTempView drop_temp_view = 15;
     DropGlobalTempView drop_global_temp_view = 16;
     RecoverPartitions recover_partitions = 17;
-    // TODO(SPARK-41612): Support Catalog.isCached
-    // IsCached is_cached = 18;
-    // TODO(SPARK-41600): Support Catalog.cacheTable
-    // CacheTable cache_table = 19;
-    // TODO(SPARK-41623): Support Catalog.uncacheTable
-    // UncacheTable uncache_table = 20;
+    IsCached is_cached = 18;
+    CacheTable cache_table = 19;
+    UncacheTable uncache_table = 20;
     ClearCache clear_cache = 21;
     RefreshTable refresh_table = 22;
     RefreshByPath refresh_by_path = 23;
@@ -185,26 +182,23 @@ message RecoverPartitions {
   string table_name = 1;
 }
 
-// TODO(SPARK-41612): Support Catalog.isCached
-//// See `spark.catalog.isCached`
-//message IsCached {
-//  // (Required)
-//  string table_name = 1;
-//}
-//
-// TODO(SPARK-41600): Support Catalog.cacheTable
-//// See `spark.catalog.cacheTable`
-//message CacheTable {
-//  // (Required)
-//  string table_name = 1;
-//}
-//
-// TODO(SPARK-41623): Support Catalog.uncacheTable
-//// See `spark.catalog.uncacheTable`
-//message UncacheTable {
-//  // (Required)
-//  string table_name = 1;
-//}
+// See `spark.catalog.isCached`
+message IsCached {
+  // (Required)
+  string table_name = 1;
+}
+
+// See `spark.catalog.cacheTable`
+message CacheTable {
+  // (Required)
+  string table_name = 1;
+}
+
+// See `spark.catalog.uncacheTable`
+message UncacheTable {
+  // (Required)
+  string table_name = 1;
+}
 
 // See `spark.catalog.clearCache`
 message ClearCache { }
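Each new message carries only a table name, mirroring the `spark.catalog` signatures. A minimal sketch of building one of these messages from Python, using the same generated bindings and import alias that `plan.py` uses below; the oneof name `cat_type` is inferred from the planner's `CatTypeCase`, and the table name is illustrative:

```python
import pyspark.sql.connect.proto as proto

# A Relation whose catalog oneof selects cache_table; the server-side planner
# dispatches on this case (CatTypeCase.CACHE_TABLE) as shown below.
rel = proto.Relation(
    catalog=proto.Catalog(cache_table=proto.CacheTable(table_name="people"))
)
print(rel.catalog.WhichOneof("cat_type"))  # -> "cache_table"
print(rel.catalog.cache_table.table_name)  # -> "people"
```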
@@ -166,13 +166,10 @@ class SparkConnectPlanner(val session: SparkSession) {
         transformDropGlobalTempView(catalog.getDropGlobalTempView)
       case proto.Catalog.CatTypeCase.RECOVER_PARTITIONS =>
         transformRecoverPartitions(catalog.getRecoverPartitions)
-      // TODO(SPARK-41612): Support Catalog.isCached
-      // case proto.Catalog.CatTypeCase.IS_CACHED => transformIsCached(catalog.getIsCached)
-      // TODO(SPARK-41600): Support Catalog.cacheTable
-      // case proto.Catalog.CatTypeCase.CACHE_TABLE => transformCacheTable(catalog.getCacheTable)
-      // TODO(SPARK-41623): Support Catalog.uncacheTable
-      // case proto.Catalog.CatTypeCase.UNCACHE_TABLE =>
-      //   transformUncacheTable(catalog.getUncacheTable)
+      case proto.Catalog.CatTypeCase.IS_CACHED => transformIsCached(catalog.getIsCached)
+      case proto.Catalog.CatTypeCase.CACHE_TABLE => transformCacheTable(catalog.getCacheTable)
+      case proto.Catalog.CatTypeCase.UNCACHE_TABLE =>
+        transformUncacheTable(catalog.getUncacheTable)
       case proto.Catalog.CatTypeCase.CLEAR_CACHE => transformClearCache(catalog.getClearCache)
       case proto.Catalog.CatTypeCase.REFRESH_TABLE =>
        transformRefreshTable(catalog.getRefreshTable)
@@ -1791,25 +1788,22 @@ class SparkConnectPlanner(val session: SparkSession) {
     emptyLocalRelation
   }
 
-  // TODO(SPARK-41612): Support Catalog.isCached
-  // private def transformIsCached(getIsCached: proto.IsCached): LogicalPlan = {
-  //   session
-  //     .createDataset(session.catalog.isCached(getIsCached.getTableName) :: Nil)(
-  //       Encoders.scalaBoolean)
-  //     .logicalPlan
-  // }
-  //
-  // TODO(SPARK-41600): Support Catalog.cacheTable
-  // private def transformCacheTable(getCacheTable: proto.CacheTable): LogicalPlan = {
-  //   session.catalog.cacheTable(getCacheTable.getTableName)
-  //   emptyLocalRelation
-  // }
-  //
-  // TODO(SPARK-41623): Support Catalog.uncacheTable
-  // private def transformUncacheTable(getUncacheTable: proto.UncacheTable): LogicalPlan = {
-  //   session.catalog.uncacheTable(getUncacheTable.getTableName)
-  //   emptyLocalRelation
-  // }
+  private def transformIsCached(getIsCached: proto.IsCached): LogicalPlan = {
+    session
+      .createDataset(session.catalog.isCached(getIsCached.getTableName) :: Nil)(
+        Encoders.scalaBoolean)
+      .logicalPlan
+  }
+
+  private def transformCacheTable(getCacheTable: proto.CacheTable): LogicalPlan = {
+    session.catalog.cacheTable(getCacheTable.getTableName)
+    emptyLocalRelation
+  }
+
+  private def transformUncacheTable(getUncacheTable: proto.UncacheTable): LogicalPlan = {
+    session.catalog.uncacheTable(getUncacheTable.getTableName)
+    emptyLocalRelation
+  }
 
   private def transformClearCache(getClearCache: proto.ClearCache): LogicalPlan = {
     session.catalog.clearCache()
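Note the return-value convention the planner uses: `cacheTable` and `uncacheTable` are pure side effects, so they return `emptyLocalRelation`, while `isCached` ships its result back as a one-row, one-column boolean dataset. A small sketch of the client-side unwrapping that convention implies, mirroring the `pdf.iloc[0].iloc[0]` access in the Connect catalog further down; the DataFrame here is a stand-in for the collected result:

```python
import pandas as pd

# Stand-in for the single-row result the server returns for IsCached.
pdf = pd.DataFrame({"value": [True]})
is_cached = bool(pdf.iloc[0].iloc[0])  # first row, first column -> the boolean
assert is_cached
```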
15 changes: 12 additions & 3 deletions python/pyspark/sql/catalog.py
@@ -934,6 +934,9 @@ def isCached(self, tableName: str) -> bool:
 
         .. versionadded:: 2.0.0
 
+        .. versionchanged:: 3.4.0
+            Support Spark Connect.
+
         Parameters
         ----------
         tableName : str
@@ -956,7 +959,7 @@ def isCached(self, tableName: str) -> bool:
 
         Throw an analysis exception when the table does not exist.
 
-        >>> spark.catalog.isCached("not_existing_table")
+        >>> spark.catalog.isCached("not_existing_table")  # doctest: +SKIP
         Traceback (most recent call last):
             ...
         AnalysisException: ...
@@ -975,6 +978,9 @@ def cacheTable(self, tableName: str) -> None:
 
         .. versionadded:: 2.0.0
 
+        .. versionchanged:: 3.4.0
+            Support Spark Connect.
+
         Parameters
         ----------
         tableName : str
@@ -991,7 +997,7 @@ def cacheTable(self, tableName: str) -> None:
 
         Throw an analysis exception when the table does not exist.
 
-        >>> spark.catalog.cacheTable("not_existing_table")
+        >>> spark.catalog.cacheTable("not_existing_table")  # doctest: +SKIP
         Traceback (most recent call last):
             ...
         AnalysisException: ...
@@ -1009,6 +1015,9 @@ def uncacheTable(self, tableName: str) -> None:
 
         .. versionadded:: 2.0.0
 
+        .. versionchanged:: 3.4.0
+            Support Spark Connect.
+
         Parameters
         ----------
         tableName : str
@@ -1028,7 +1037,7 @@ def uncacheTable(self, tableName: str) -> None:
 
         Throw an analysis exception when the table does not exist.
 
-        >>> spark.catalog.uncacheTable("not_existing_table")  # doctest: +IGNORE_EXCEPTION_DETAIL
+        >>> spark.catalog.uncacheTable("not_existing_table")  # doctest: +SKIP
         Traceback (most recent call last):
             ...
         AnalysisException: ...
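The skipped doctests above all document the same failure path. A runnable sketch of that behaviour, assuming an active `spark` session and importing `AnalysisException` from `pyspark.errors` (its home as of 3.4):

```python
from pyspark.errors import AnalysisException

try:
    spark.catalog.cacheTable("not_existing_table")
except AnalysisException as exc:
    # The table cannot be resolved, so the analyzer rejects the request.
    print(f"caught {type(exc).__name__}: table does not exist")
```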
50 changes: 17 additions & 33 deletions python/pyspark/sql/connect/catalog.py
@@ -270,25 +270,22 @@ def dropGlobalTempView(self, viewName: str) -> bool:
 
     dropGlobalTempView.__doc__ = PySparkCatalog.dropGlobalTempView.__doc__
 
-    # TODO(SPARK-41612): Support Catalog.isCached
-    # def isCached(self, tableName: str) -> bool:
-    #     pdf = self._catalog_to_pandas(plan.IsCached(table_name=tableName))
-    #     assert pdf is not None
-    #     return pdf.iloc[0].iloc[0]
-    #
-    # isCached.__doc__ = PySparkCatalog.isCached.__doc__
-    #
-    # TODO(SPARK-41600): Support Catalog.cacheTable
-    # def cacheTable(self, tableName: str) -> None:
-    #     self._catalog_to_pandas(plan.CacheTable(table_name=tableName))
-    #
-    # cacheTable.__doc__ = PySparkCatalog.cacheTable.__doc__
-    #
-    # TODO(SPARK-41623): Support Catalog.uncacheTable
-    # def uncacheTable(self, tableName: str) -> None:
-    #     self._catalog_to_pandas(plan.UncacheTable(table_name=tableName))
-    #
-    # uncacheTable.__doc__ = PySparkCatalog.uncacheTable.__doc__
+    def isCached(self, tableName: str) -> bool:
+        pdf = self._catalog_to_pandas(plan.IsCached(table_name=tableName))
+        assert pdf is not None
+        return pdf.iloc[0].iloc[0]
+
+    isCached.__doc__ = PySparkCatalog.isCached.__doc__
+
+    def cacheTable(self, tableName: str) -> None:
+        self._catalog_to_pandas(plan.CacheTable(table_name=tableName))
+
+    cacheTable.__doc__ = PySparkCatalog.cacheTable.__doc__
+
+    def uncacheTable(self, tableName: str) -> None:
+        self._catalog_to_pandas(plan.UncacheTable(table_name=tableName))
+
+    uncacheTable.__doc__ = PySparkCatalog.uncacheTable.__doc__
 
     def clearCache(self) -> None:
         self._catalog_to_pandas(plan.ClearCache())
@@ -310,15 +307,6 @@ def refreshByPath(self, path: str) -> None:
 
     refreshByPath.__doc__ = PySparkCatalog.refreshByPath.__doc__
 
-    def isCached(self, *args: Any, **kwargs: Any) -> None:
-        raise NotImplementedError("isCached() is not implemented.")
-
-    def cacheTable(self, *args: Any, **kwargs: Any) -> None:
-        raise NotImplementedError("cacheTable() is not implemented.")
-
-    def uncacheTable(self, *args: Any, **kwargs: Any) -> None:
-        raise NotImplementedError("uncacheTable() is not implemented.")
-
     def registerFunction(self, *args: Any, **kwargs: Any) -> None:
         raise NotImplementedError("registerFunction() is not implemented.")
@@ -337,11 +325,7 @@ def _test() -> None:
         PySparkSession.builder.appName("sql.connect.catalog tests").remote("local[4]").getOrCreate()
     )
 
-    # TODO(SPARK-41612): Support Catalog.isCached
-    # TODO(SPARK-41600): Support Catalog.cacheTable
-    del pyspark.sql.connect.catalog.Catalog.clearCache.__doc__
-    del pyspark.sql.connect.catalog.Catalog.refreshTable.__doc__
-    del pyspark.sql.connect.catalog.Catalog.refreshByPath.__doc__
     # TODO(SPARK-41818): java.lang.ClassNotFoundException) .DefaultSource
     del pyspark.sql.connect.catalog.Catalog.recoverPartitions.__doc__
 
     (failure_count, test_count) = doctest.testmod(
74 changes: 38 additions & 36 deletions python/pyspark/sql/connect/plan.py
@@ -1729,45 +1729,47 @@ def __init__(self, table_name: str) -> None:
         self._table_name = table_name
 
     def plan(self, session: "SparkConnectClient") -> proto.Relation:
-        plan = proto.Relation(catalog=proto.Catalog(recover_partitions=proto.RecoverPartitions()))
-        plan.catalog.recover_partitions.table_name = self._table_name
+        plan = proto.Relation(
+            catalog=proto.Catalog(
+                recover_partitions=proto.RecoverPartitions(table_name=self._table_name)
+            )
+        )
         return plan
 
 
-# TODO(SPARK-41612): Support Catalog.isCached
-# class IsCached(LogicalPlan):
-#     def __init__(self, table_name: str) -> None:
-#         super().__init__(None)
-#         self._table_name = table_name
-#
-#     def plan(self, session: "SparkConnectClient") -> proto.Relation:
-#         plan = proto.Relation(catalog=proto.Catalog(is_cached=proto.IsCached()))
-#         plan.catalog.is_cached.table_name = self._table_name
-#         return plan
-#
-#
-# TODO(SPARK-41600): Support Catalog.cacheTable
-# class CacheTable(LogicalPlan):
-#     def __init__(self, table_name: str) -> None:
-#         super().__init__(None)
-#         self._table_name = table_name
-#
-#     def plan(self, session: "SparkConnectClient") -> proto.Relation:
-#         plan = proto.Relation(catalog=proto.Catalog(cache_table=proto.CacheTable()))
-#         plan.catalog.cache_table.table_name = self._table_name
-#         return plan
-#
-#
-# TODO(SPARK-41623): Support Catalog.uncacheTable
-# class UncacheTable(LogicalPlan):
-#     def __init__(self, table_name: str) -> None:
-#         super().__init__(None)
-#         self._table_name = table_name
-#
-#     def plan(self, session: "SparkConnectClient") -> proto.Relation:
-#         plan = proto.Relation(catalog=proto.Catalog(uncache_table=proto.UncacheTable()))
-#         plan.catalog.uncache_table.table_name = self._table_name
-#         return plan
+class IsCached(LogicalPlan):
+    def __init__(self, table_name: str) -> None:
+        super().__init__(None)
+        self._table_name = table_name
+
+    def plan(self, session: "SparkConnectClient") -> proto.Relation:
+        plan = proto.Relation(
+            catalog=proto.Catalog(is_cached=proto.IsCached(table_name=self._table_name))
+        )
+        return plan
+
+
+class CacheTable(LogicalPlan):
+    def __init__(self, table_name: str) -> None:
+        super().__init__(None)
+        self._table_name = table_name
+
+    def plan(self, session: "SparkConnectClient") -> proto.Relation:
+        plan = proto.Relation(
+            catalog=proto.Catalog(cache_table=proto.CacheTable(table_name=self._table_name))
+        )
+        return plan
+
+
+class UncacheTable(LogicalPlan):
+    def __init__(self, table_name: str) -> None:
+        super().__init__(None)
+        self._table_name = table_name
+
+    def plan(self, session: "SparkConnectClient") -> proto.Relation:
+        plan = proto.Relation(catalog=proto.Catalog(uncache_table=proto.UncacheTable()))
+        plan.catalog.uncache_table.table_name = self._table_name
+        return plan
 
 
 class ClearCache(LogicalPlan):
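Tying the client pieces together: `Catalog.cacheTable` builds one of these `LogicalPlan` nodes and executes it through `_catalog_to_pandas`. A standalone sketch of what the node produces; the `session` argument is unused by these `plan()` implementations, so `None` suffices here, and it assumes a PySpark build with Connect support installed:

```python
from pyspark.sql.connect import plan

node = plan.CacheTable(table_name="people")  # illustrative table name
rel = node.plan(session=None)                # builds the proto.Relation
print(rel.catalog.WhichOneof("cat_type"))    # -> "cache_table"
print(rel.catalog.cache_table.table_name)    # -> "people"
```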