diff --git a/exposed/src/main/kotlin/org/jetbrains/exposed/sql/transactions/ThreadLocalTransactionManager.kt b/exposed/src/main/kotlin/org/jetbrains/exposed/sql/transactions/ThreadLocalTransactionManager.kt index 1af254ed2a..ee88cf162b 100644 --- a/exposed/src/main/kotlin/org/jetbrains/exposed/sql/transactions/ThreadLocalTransactionManager.kt +++ b/exposed/src/main/kotlin/org/jetbrains/exposed/sql/transactions/ThreadLocalTransactionManager.kt @@ -108,23 +108,24 @@ fun transaction(transactionIsolation: Int, repetitionAttempts: Int, db: Data if (outer != null && (db == null || outer.db == db)) { val outerManager = outer.db.transactionManager - db?.let { db.transactionManager.let { m -> TransactionManager.resetCurrent(m) } } - val transaction = db.transactionManager.newTransaction(transactionIsolation, outer) + val transaction = outerManager.newTransaction(transactionIsolation, outer) try { transaction.statement().also { - transaction.commit() + if(outer.db.useNestedTransactions) + transaction.commit() } } finally { TransactionManager.resetCurrent(outerManager) } } else { - val existingForDb = db?.let { db.transactionManager } + val existingForDb = db?.transactionManager existingForDb?.currentOrNull()?.let { transaction -> val currentManager = outer?.db.transactionManager try { TransactionManager.resetCurrent(existingForDb) transaction.statement().also { - transaction.commit() + if(db.useNestedTransactions) + transaction.commit() } } finally { TransactionManager.resetCurrent(currentManager) @@ -198,7 +199,7 @@ fun inTopLevelTransaction( internal fun keepAndRestoreTransactionRefAfterRun(db: Database? = null, block: () -> T): T { val manager = db.transactionManager as? ThreadLocalTransactionManager - val currentTransaction = manager?.threadLocal?.get() + val currentTransaction = manager?.currentOrNull() return block().also { manager?.threadLocal?.set(currentTransaction) } diff --git a/exposed/src/test/kotlin/org/jetbrains/exposed/sql/tests/shared/ThreadLocalManagerTest.kt b/exposed/src/test/kotlin/org/jetbrains/exposed/sql/tests/shared/ThreadLocalManagerTest.kt index 57f0f20ecf..c519d94fd2 100644 --- a/exposed/src/test/kotlin/org/jetbrains/exposed/sql/tests/shared/ThreadLocalManagerTest.kt +++ b/exposed/src/test/kotlin/org/jetbrains/exposed/sql/tests/shared/ThreadLocalManagerTest.kt @@ -2,13 +2,13 @@ package org.jetbrains.exposed.sql.tests.shared import org.hamcrest.MatcherAssert.assertThat import org.hamcrest.Matchers +import org.jetbrains.exposed.dao.IntIdTable import org.jetbrains.exposed.exceptions.ExposedSQLException -import org.jetbrains.exposed.sql.Database -import org.jetbrains.exposed.sql.SchemaUtils -import org.jetbrains.exposed.sql.selectAll +import org.jetbrains.exposed.sql.* import org.jetbrains.exposed.sql.tests.DatabaseTestsBase import org.jetbrains.exposed.sql.tests.TestDB import org.jetbrains.exposed.sql.transactions.TransactionManager +import org.jetbrains.exposed.sql.transactions.inTopLevelTransaction import org.jetbrains.exposed.sql.transactions.transaction import org.jetbrains.exposed.sql.transactions.transactionManager import org.junit.After @@ -316,4 +316,59 @@ class MultipleDatabaseBugTest { println("Transaction connection url: ${connection.metaData?.url}") } } +} + +object RollbackTable : IntIdTable() { + val value = varchar("value", 20) +} + +class RollbackTransactionTest : DatabaseTestsBase() { + + @Test + fun testRollbackWithoutSavepoints() { + withTables(RollbackTable) { + inTopLevelTransaction(db.transactionManager.defaultIsolationLevel, 1) { + RollbackTable.insert { it[value] = "before-dummy" } + transaction { + assertEquals(1, RollbackTable.select { RollbackTable.value eq "before-dummy" }.count()) + RollbackTable.insert { it[value] = "inner-dummy" } + } + assertEquals(1, RollbackTable.select { RollbackTable.value eq "before-dummy" }.count()) + assertEquals(1, RollbackTable.select { RollbackTable.value eq "inner-dummy" }.count()) + RollbackTable.insert { it[value] = "after-dummy" } + assertEquals(1, RollbackTable.select { RollbackTable.value eq "after-dummy" }.count()) + rollback() + } + assertEquals(0, RollbackTable.select { RollbackTable.value eq "before-dummy" }.count()) + assertEquals(0, RollbackTable.select { RollbackTable.value eq "inner-dummy" }.count()) + assertEquals(0, RollbackTable.select { RollbackTable.value eq "after-dummy" }.count()) + } + } + + @Test + fun testRollbackWithSavepoints() { + withTables(RollbackTable) { + try { + db.useNestedTransactions = true + inTopLevelTransaction(db.transactionManager.defaultIsolationLevel, 1) { + RollbackTable.insert { it[value] = "before-dummy" } + transaction { + assertEquals(1, RollbackTable.select { RollbackTable.value eq "before-dummy" }.count()) + RollbackTable.insert { it[value] = "inner-dummy" } + rollback() + } + assertEquals(1, RollbackTable.select { RollbackTable.value eq "before-dummy" }.count()) + assertEquals(0, RollbackTable.select { RollbackTable.value eq "inner-dummy" }.count()) + RollbackTable.insert { it[value] = "after-dummy" } + assertEquals(1, RollbackTable.select { RollbackTable.value eq "after-dummy" }.count()) + rollback() + } + assertEquals(0, RollbackTable.select { RollbackTable.value eq "before-dummy" }.count()) + assertEquals(0, RollbackTable.select { RollbackTable.value eq "inner-dummy" }.count()) + assertEquals(0, RollbackTable.select { RollbackTable.value eq "after-dummy" }.count()) + } finally { + db.useNestedTransactions = false + } + } + } } \ No newline at end of file