
Separate out APIs; this will facilitate adding write options later
szehon-ho committed Jul 12, 2024
1 parent 554b8a5 commit c589736
Showing 3 changed files with 28 additions and 40 deletions.
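In short, the one-shot Dataset.update(assignments, condition) call is split into a small builder: update(assignments) now returns an UpdateWriter[T], the condition moves to a chainable where(), and nothing runs until the terminal update(). A minimal before/after sketch, assuming an active SparkSession named spark, import spark.implicits._ for the $ interpolator, import org.apache.spark.sql.functions.lit, and a writable table named source:

    // Before this commit: a single call both defined and executed the update.
    spark.table("source")
      .update(Map("salary" -> lit(200)), $"salary" === 100)

    // After this commit: update() returns an UpdateWriter[T]; the condition is
    // attached with where(), and the terminal update() executes the command.
    spark.table("source")
      .update(Map("salary" -> lit(200)))
      .where($"salary" === 100)
      .update()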
37 changes: 7 additions & 30 deletions sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -4143,49 +4143,26 @@ class Dataset[T] private[sql](
    * {{{
    *   spark.table("source")
    *     .update(
-   *       Map("salary" -> lit(200)),
-   *       $"salary" === 100
-   *     )
-   * }}}
-   * @param assignments A Map of column names to Column expressions representing the updates
-   *                    to be applied.
-   * @param condition the update condition.
-   * @group basic
-   * @since 4.0.0
-   */
-  def update(assignments: Map[String, Column], condition: Column): Unit = {
-    if (isStreaming) {
-      logicalPlan.failAnalysis(
-        errorClass = "CALL_ON_STREAMING_DATASET_UNSUPPORTED",
-        messageParameters = Map("methodName" -> toSQLId("update")))
-    }
-
-    new UpdateWriter(this, assignments, Some(condition.expr)).update()
-  }
-
-  /**
-   * Update all rows in a target table.
+   *       Map("salary" -> lit(200))
+   *     )
+   *     .where($"salary" === 100)
+   *     .update()
    *
-   * Scala Example:
-   * {{{
-   *   spark.table("source")
-   *     .update(Map(
-   *       "salary" -> lit(200)
-   *     ))
    * }}}
    * @param assignments A Map of column names to Column expressions representing the updates
    *                    to be applied.
-   * @param condition the update condition.
    * @group basic
    * @since 4.0.0
    */
-  def update(assignments: Map[String, Column]): Unit = {
+  def update(assignments: Map[String, Column]): UpdateWriter[T] = {
     if (isStreaming) {
       logicalPlan.failAnalysis(
         errorClass = "CALL_ON_STREAMING_DATASET_UNSUPPORTED",
         messageParameters = Map("methodName" -> toSQLId("update")))
     }

-    new UpdateWriter(this, assignments, None).update()
+    new UpdateWriter(this, assignments)
   }

   /**
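Two behavioral points fall out of the hunk above: the streaming check still runs eagerly when update(assignments) is called, so misuse fails before any chaining, while the write itself is deferred to the terminal UpdateWriter.update(). A short sketch under the same assumptions as the example above (spark, lit, and a table named source):

    // Builds an UpdateWriter but performs no write yet; only the terminal
    // update() call plans and executes the UPDATE command.
    val writer = spark.table("source").update(Map("salary" -> lit(200)))

    // With no where() call the condition stays None, so this updates every
    // row, matching the old no-condition overload it replaces.
    writer.update()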
24 changes: 16 additions & 8 deletions sql/core/src/main/scala/org/apache/spark/sql/UpdateWriter.scala
@@ -23,37 +23,45 @@ import org.apache.spark.sql.catalyst.plans.logical.{Assignment, UpdateTable}
 import org.apache.spark.sql.functions.expr

 /**
- * `UpdateWriter` provides methods to define and execute an update action based
- * on the specified conditions.
+ * `UpdateWriter` provides methods to define and execute an update action on a target Dataset.
  *
  * @tparam T the type of data in the Dataset.
  * @param ds the Dataset to update.
  * @param assignments A Map of column names to Column expressions representing the updates
  *                    to be applied.
- * @param condition the update condition.
  *
  * @since 4.0.0
  */
 @Experimental
 class UpdateWriter[T] private[sql](
     ds: Dataset[T],
-    assignments: Map[String, Column],
-    condition: Option[Expression]) {
+    assignments: Map[String, Column]) {

   private val df: DataFrame = ds.toDF()

   private val sparkSession = ds.sparkSession

   private val logicalPlan = df.queryExecution.logical

+  private var expression: Option[Expression] = None
+
+  /**
+   * Limits the update to rows matching the specified condition.
+   *
+   * @param condition the update condition
+   * @return this `UpdateWriter` instance
+   */
+  def where(condition: Column): UpdateWriter[T] = {
+    this.expression = Some(condition.expr)
+    this
+  }
+
   /**
    * Executes the update operation.
    */
   def update(): Unit = {
     val update = UpdateTable(
       logicalPlan,
       assignments.map(x => Assignment(expr(x._1).expr, x._2.expr)).toSeq,
-      condition)
+      expression)
     val qe = sparkSession.sessionState.executePlan(update)
     qe.assertCommandExecuted()
   }
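The commit message points at the motivation: with a builder in place, write options can later accumulate on UpdateWriter the same way the where() condition does. Nothing in this commit adds them; the option method and writeOptions field below are purely hypothetical illustration:

    import scala.collection.mutable
    import org.apache.spark.sql.{Column, Dataset}
    import org.apache.spark.sql.catalyst.expressions.Expression

    // Hypothetical sketch, not part of this commit: options collect on the
    // builder and would be consumed when update() assembles the logical plan.
    class UpdateWriter[T] private[sql](
        ds: Dataset[T],
        assignments: Map[String, Column]) {

      private var expression: Option[Expression] = None
      private val writeOptions = mutable.Map.empty[String, String] // hypothetical

      def where(condition: Column): UpdateWriter[T] = {
        this.expression = Some(condition.expr)
        this
      }

      // Illustrative only: mirrors the DataFrameWriter.option(key, value) shape.
      def option(key: String, value: String): UpdateWriter[T] = {
        writeOptions += (key -> value)
        this
      }

      def update(): Unit = {
        // As in the commit: build an UpdateTable plan from assignments and
        // expression; a later change could thread writeOptions into execution.
      }
    }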
7 changes: 5 additions & 2 deletions UpdateDataFrameSuite.scala
@@ -32,7 +32,9 @@ class UpdateDataFrameSuite extends RowLevelOperationSuiteBase {
       |""".stripMargin)

     spark.table(tableNameAsString)
-      .update(Map("salary" -> lit(-1)), $"pk" >= 2)
+      .update(Map("salary" -> lit(-1)))
+      .where($"pk" >= 2)
+      .update()

     checkAnswer(
       sql(s"SELECT * FROM $tableNameAsString"),
@@ -49,8 +51,9 @@ class UpdateDataFrameSuite extends RowLevelOperationSuiteBase {
       |{ "pk": 3, "salary": 120, "dep": 'hr' }
       |""".stripMargin)

-    spark.table(tableNameAsString)
+    spark.read.option("a", value = true).table(tableNameAsString)
       .update(Map("dep" -> lit("software")))
+      .update()

     checkAnswer(
       sql(s"SELECT * FROM $tableNameAsString"),
