diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
index bb50d17b422ff..fa7f8107d5e91 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -4143,49 +4143,25 @@ class Dataset[T] private[sql](
    * {{{
    *   spark.table("source")
    *     .update(
-   *       Map("salary" -> lit(200)),
-   *       $"salary" === 100
-   *     )
-   * }}}
-   * @param assignments A Map of column names to Column expressions representing the updates
-   *                    to be applied.
-   * @param condition the update condition.
-   * @group basic
-   * @since 4.0.0
-   */
-  def update(assignments: Map[String, Column], condition: Column): Unit = {
-    if (isStreaming) {
-      logicalPlan.failAnalysis(
-        errorClass = "CALL_ON_STREAMING_DATASET_UNSUPPORTED",
-        messageParameters = Map("methodName" -> toSQLId("update")))
-    }
-
-    new UpdateWriter(this, assignments, Some(condition.expr)).update()
-  }
-
-  /**
-   * Update all rows in a target table.
+   *       Map("salary" -> lit(200))
+   *     )
+   *     .where($"salary" === 100)
+   *     .update()
    *
-   * Scala Example:
-   * {{{
-   *   spark.table("source")
-   *     .update(Map(
-   *       "salary" -> lit(200)
-   *     ))
    * }}}
    * @param assignments A Map of column names to Column expressions representing the updates
    *                    to be applied.
    * @group basic
    * @since 4.0.0
    */
-  def update(assignments: Map[String, Column]): Unit = {
+  def update(assignments: Map[String, Column]): UpdateWriter[T] = {
     if (isStreaming) {
       logicalPlan.failAnalysis(
         errorClass = "CALL_ON_STREAMING_DATASET_UNSUPPORTED",
         messageParameters = Map("methodName" -> toSQLId("update")))
     }
 
-    new UpdateWriter(this, assignments, None).update()
+    new UpdateWriter(this, assignments)
   }
 
   /**
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/UpdateWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/UpdateWriter.scala
index 8267fac420a92..40d558c3e3ed0 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/UpdateWriter.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/UpdateWriter.scala
@@ -23,29 +23,37 @@ import org.apache.spark.sql.catalyst.plans.logical.{Assignment, UpdateTable}
 import org.apache.spark.sql.functions.expr
 
 /**
- * `UpdateWriter` provides methods to define and execute an update action based
- * on the specified conditions.
+ * `UpdateWriter` provides methods to define and execute an update action on a target Dataset.
  *
  * @tparam T the type of data in the Dataset.
  * @param ds the Dataset to update.
  * @param assignments A Map of column names to Column expressions representing the updates
  *                    to be applied.
- * @param condition the update condition.
  *
  * @since 4.0.0
  */
 @Experimental
 class UpdateWriter[T] private[sql](
     ds: Dataset[T],
-    assignments: Map[String, Column],
-    condition: Option[Expression]) {
+    assignments: Map[String, Column]) {
 
   private val df: DataFrame = ds.toDF()
-  private val sparkSession = ds.sparkSession
-  private val logicalPlan = df.queryExecution.logical
+  private var expression: Option[Expression] = None
+
+  /**
+   * Limits the update to rows matching the specified condition.
+   *
+   * @param condition the update condition
+   * @return this `UpdateWriter` instance, to allow call chaining
+   */
+  def where(condition: Column): UpdateWriter[T] = {
+    this.expression = Some(condition.expr)
+    this
+  }
 
   /**
    * Executes the update operation.
    */
@@ -53,7 +61,7 @@ class UpdateWriter[T] private[sql](
     val update = UpdateTable(
-      logicalPlan,
+      df.queryExecution.logical,
       assignments.map(x => Assignment(expr(x._1).expr, x._2.expr)).toSeq,
-      condition)
-    val qe = sparkSession.sessionState.executePlan(update)
+      expression)
+    val qe = ds.sparkSession.sessionState.executePlan(update)
     qe.assertCommandExecuted()
   }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/UpdateDataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/UpdateDataFrameSuite.scala
index 6484190bc4d24..43734bfe23790 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/connector/UpdateDataFrameSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/UpdateDataFrameSuite.scala
@@ -32,7 +32,9 @@ class UpdateDataFrameSuite extends RowLevelOperationSuiteBase {
         |""".stripMargin)
 
     spark.table(tableNameAsString)
-      .update(Map("salary" -> lit(-1)), $"pk" >= 2)
+      .update(Map("salary" -> lit(-1)))
+      .where($"pk" >= 2)
+      .update()
 
     checkAnswer(
       sql(s"SELECT * FROM $tableNameAsString"),
@@ -49,8 +51,9 @@ class UpdateDataFrameSuite extends RowLevelOperationSuiteBase {
         |{ "pk": 3, "salary": 120, "dep": 'hr' }
         |""".stripMargin)
 
-    spark.table(tableNameAsString)
+    spark.read.option("a", value = true).table(tableNameAsString)
      .update(Map("dep" -> lit("software")))
+      .update()
 
     checkAnswer(
       sql(s"SELECT * FROM $tableNameAsString"),
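For reference, a minimal usage sketch of the builder-style API introduced by this patch. The `employees` table name and its columns are illustrative assumptions, not taken from the patch, and the target must be a table whose data source supports row-level update operations.

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.lit

val spark = SparkSession.builder().getOrCreate()
import spark.implicits._

// Conditional update: stage the assignments, narrow the affected rows
// with where(), then execute with the terminal update() call.
spark.table("employees")
  .update(Map("salary" -> lit(200)))
  .where($"salary" === 100)
  .update()

// Unconditional update: omitting where() leaves the condition empty,
// so every row in the table is updated.
spark.table("employees")
  .update(Map("dep" -> lit("software")))
  .update()
```

Note that `update(assignments)` on the Dataset only constructs an `UpdateWriter` and stages the assignments; nothing runs against the table until the no-argument `update()` executes the plan.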