Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Simplify use of transactions when performing overwrites and when creating new tables #157

Closed
wants to merge 10 commits into from
86 changes: 29 additions & 57 deletions src/main/scala/com/databricks/spark/redshift/RedshiftWriter.scala
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@ import org.apache.spark.TaskContext
import org.slf4j.LoggerFactory

import scala.collection.mutable
import scala.util.Random
import scala.util.control.NonFatal

import com.databricks.spark.redshift.Parameters.MergedParameters
Expand All @@ -48,6 +47,8 @@ import org.apache.spark.sql.types._
* non-empty. After the write operation completes, we use this to construct a list of non-empty
* Avro partition files.
*
* - Using JDBC, start a new transaction.
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

TODO: this entire block comment needs to be updated.

*
* - Use JDBC to issue any CREATE TABLE commands, if required.
*
* - If there is data to be written (i.e. not all partitions were empty), then use the list of
Expand Down Expand Up @@ -102,60 +103,15 @@ private[redshift] class RedshiftWriter(
}

/**
 * Sets up a staging table then runs the given action, passing the temporary table name
 * as a parameter.
 *
 * The caller's `action` writes into the staging table. On success, the staging table is
 * atomically swapped in for the target table (if the target already exists) via a
 * rename-and-drop sequence inside a single BEGIN/END transaction; otherwise the staging
 * table is simply renamed to the target name. The staging table is always dropped in the
 * `finally` block, so a failed `action` leaves the original table untouched.
 *
 * @param conn   open JDBC connection to Redshift
 * @param table  the target table that should ultimately hold the written data
 * @param action callback that loads data into the staging table; receives the staging
 *               table's fully-qualified name as its argument
 */
private def withStagingTable(
conn: Connection,
table: TableName,
action: (String) => Unit) {
// Random suffix keeps the staging/backup names unique so concurrent writes to
// different tables (or retries) don't collide. Math.abs avoids a '-' in the name.
val randomSuffix = Math.abs(Random.nextInt()).toString
val tempTable =
table.copy(unescapedTableName = s"${table.unescapedTableName}_staging_$randomSuffix")
val backupTable =
table.copy(unescapedTableName = s"${table.unescapedTableName}_backup_$randomSuffix")
log.info("Loading new Redshift data to: " + tempTable)
log.info("Existing data will be backed up in: " + backupTable)

try {
// Let the caller populate the staging table first; the swap below only
// happens if this succeeds.
action(tempTable.toString)

if (jdbcWrapper.tableExists(conn, table.toString)) {
// Target exists: swap the new data in atomically. The old table is renamed
// to the backup name, the staging table takes the target name, and only then
// is the backup dropped — all within one transaction.
jdbcWrapper.executeInterruptibly(conn.prepareStatement(
s"""
| BEGIN;
| ALTER TABLE $table RENAME TO ${backupTable.escapedTableName};
| ALTER TABLE $tempTable RENAME TO ${table.escapedTableName};
| DROP TABLE $backupTable;
| END;
""".stripMargin.trim))
} else {
// Target doesn't exist yet: a plain rename suffices.
jdbcWrapper.executeInterruptibly(conn.prepareStatement(
s"ALTER TABLE $tempTable RENAME TO ${table.escapedTableName}"))
}
} finally {
// Always clean up the staging table. After a successful swap it no longer
// exists under this name, hence DROP ... IF EXISTS.
jdbcWrapper.executeInterruptibly(conn.prepareStatement(s"DROP TABLE IF EXISTS $tempTable"))
}
}

/**
* Perform the Redshift load, including deletion of existing data in the case of an overwrite,
* and creating the table if it doesn't already exist.
* Perform the Redshift load by issuing a COPY statement.
*/
private def doRedshiftLoad(
conn: Connection,
data: DataFrame,
saveMode: SaveMode,
params: MergedParameters,
creds: AWSCredentials,
manifestUrl: Option[String]): Unit = {

// Overwrites must drop the table, in case there has been a schema update
if (saveMode == SaveMode.Overwrite) {
jdbcWrapper.executeInterruptibly(
conn.prepareStatement(s"DROP TABLE IF EXISTS ${params.table.get}"))
}

// If the table doesn't exist, we need to create it first, using JDBC to infer column types
val createStatement = createTableSql(data, params)
log.info(createStatement)
Expand Down Expand Up @@ -360,19 +316,35 @@ private[redshift] class RedshiftWriter(

Utils.checkThatBucketHasObjectLifecycleConfiguration(params.rootTempDir, s3ClientFactory(creds))

val conn = jdbcWrapper.getConnector(params.jdbcDriver, params.jdbcUrl)
// Save the table's rows to S3:
val manifestUrl = unloadData(sqlContext, data, params.createPerQueryTempDir())

val conn = jdbcWrapper.getConnector(params.jdbcDriver, params.jdbcUrl)
conn.setAutoCommit(false)
try {
val tempDir = params.createPerQueryTempDir()
val manifestUrl = unloadData(sqlContext, data, tempDir)
if (saveMode == SaveMode.Overwrite && params.useStagingTable) {
withStagingTable(conn, params.table.get, stagingTable => {
val updatedParams = MergedParameters(params.parameters.updated("dbtable", stagingTable))
doRedshiftLoad(conn, data, saveMode, updatedParams, creds, manifestUrl)
})
} else {
doRedshiftLoad(conn, data, saveMode, params, creds, manifestUrl)
val table: TableName = params.table.get
if (saveMode == SaveMode.Overwrite) {
// Overwrites must drop the table in case there has been a schema update
jdbcWrapper.executeInterruptibly(conn.prepareStatement(s"DROP TABLE IF EXISTS $table;"))
if (!params.useStagingTable) {
// If we're not using a staging table, commit now so that Redshift doesn't have to
// maintain a snapshot of the old table during the COPY; this sacrifices atomicity for
// performance.
conn.commit()
}
}
log.info(s"Loading new Redshift data to: $table")
doRedshiftLoad(conn, data, params, creds, manifestUrl)
conn.commit()
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does this need to handle InterruptedException? i.e. can this block for a really long time?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It has been my experience that virtually any Redshift command, COMMIT included, can block for several minutes under certain circumstances. I think the most likely cause is that the cluster has WLM parameters configured to put the connected client on a limited pool of some type, such that all commands will be queued when all slots are taken by other queries.

I have a feeling that means though, if you plan to send a ROLLBACK command in response to the interruption, that will also block for many minutes...

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not sure that it's safe for us to wrap this in order to catch InterruptedException since I don't think that it's safe to call .rollback() while the other thread is in the middle of executing commit(). Therefore, I'm going to leave this as-is for now and will revisit later if this turns out to be a problem in practice.

} catch {
case NonFatal(e) =>
try {
conn.rollback()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Might a log.info or log.warn be good here, to inform anyone testing their application code that the load failed and what they're now waiting for is a rollback to finish?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.

} catch {
case NonFatal(e2) =>
log.error("Exception while rolling back transaction", e2)
}
throw e
} finally {
conn.close()
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,12 @@ class MockRedshift(
}
}

/**
 * Asserts (via Mockito) that `commit()` was never invoked on any of the JDBC
 * connections handed out by this mock.
 */
def verifyThatCommitWasNotCalled(): Unit = {
for (connection <- jdbcConnections) {
verify(connection, never()).commit()
}
}

def verifyThatExpectedQueriesWereIssued(expectedQueries: Seq[Regex]): Unit = {
expectedQueries.zip(queriesIssued).foreach { case (expected, actual) =>
if (expected.findFirstMatchIn(actual).isEmpty) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -266,19 +266,11 @@ class RedshiftSourceSuite
"distkey" -> "testint")

val expectedCommands = Seq(
"DROP TABLE IF EXISTS \"PUBLIC\"\\.\"test_table_staging_.*\"".r,
("CREATE TABLE IF NOT EXISTS \"PUBLIC\"\\.\"test_table_staging.*" +
"DROP TABLE IF EXISTS \"PUBLIC\"\\.\"test_table.*\"".r,
("CREATE TABLE IF NOT EXISTS \"PUBLIC\"\\.\"test_table.*" +
" DISTSTYLE KEY DISTKEY \\(testint\\).*").r,
"COPY \"PUBLIC\"\\.\"test_table_staging_.*\"".r,
"GRANT SELECT ON \"PUBLIC\"\\.\"test_table_staging.+\" TO jeremy".r,
"""
| BEGIN;
| ALTER TABLE "PUBLIC"\."test_table" RENAME TO "test_table_backup_.*";
| ALTER TABLE "PUBLIC"\."test_table_staging_.*" RENAME TO "test_table";
| DROP TABLE "PUBLIC"\."test_table_backup_.*";
| END;
""".stripMargin.trim.r,
"DROP TABLE IF EXISTS \"PUBLIC\"\\.\"test_table_staging_.*\"".r)
"COPY \"PUBLIC\"\\.\"test_table.*\"".r,
"GRANT SELECT ON \"PUBLIC\"\\.\"test_table\" TO jeremy".r)

val mockRedshift = new MockRedshift(
defaultParams("url"),
Expand Down Expand Up @@ -316,6 +308,7 @@ class RedshiftSourceSuite
testSqlContext, df, SaveMode.Append, Parameters.mergeParameters(defaultParams))
}
mockRedshift.verifyThatConnectionsWereClosed()
mockRedshift.verifyThatCommitWasNotCalled()
mockRedshift.verifyThatExpectedQueriesWereIssued(Seq.empty)
}

Expand All @@ -325,23 +318,21 @@ class RedshiftSourceSuite
val mockRedshift = new MockRedshift(
defaultParams("url"),
Map(TableName.parseFromEscaped("test_table").toString -> TestUtils.testSchema),
jdbcQueriesThatShouldFail = Seq("COPY \"PUBLIC\".\"test_table_staging_.*\"".r))
jdbcQueriesThatShouldFail = Seq("COPY \"PUBLIC\".\"test_table.*\"".r))

val expectedCommands = Seq(
"DROP TABLE IF EXISTS \"PUBLIC\".\"test_table_staging_.*\"".r,
"CREATE TABLE IF NOT EXISTS \"PUBLIC\".\"test_table_staging_.*\"".r,
"COPY \"PUBLIC\".\"test_table_staging_.*\"".r,
".*FROM stl_load_errors.*".r,
"DROP TABLE IF EXISTS \"PUBLIC\".\"test_table_staging_.*\"".r
"DROP TABLE IF EXISTS \"PUBLIC\".\"test_table.*\"".r,
"CREATE TABLE IF NOT EXISTS \"PUBLIC\".\"test_table.*\"".r,
"COPY \"PUBLIC\".\"test_table.*\"".r,
".*FROM stl_load_errors.*".r
)

val source = new DefaultSource(mockRedshift.jdbcWrapper, _ => mockS3Client)
intercept[Exception] {
source.createRelation(testSqlContext, SaveMode.Overwrite, params, expectedDataDF)
mockRedshift.verifyThatConnectionsWereClosed()
mockRedshift.verifyThatExpectedQueriesWereIssued(expectedCommands)
}
mockRedshift.verifyThatConnectionsWereClosed()
mockRedshift.verifyThatCommitWasNotCalled()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should this also verify that rollback was called?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.

mockRedshift.verifyThatExpectedQueriesWereIssued(expectedCommands)
}

Expand Down