Skip to content

Commit

Permalink
Transaction States Should Be Checked In Queue
Browse files Browse the repository at this point in the history
Motivation:
Currently, `MySqlConnection` checks state between the `Mono.defer` is subscribed and `Exchangeable` is executed. It may cause undefined behavior.

Modification:
Checks transaction state when request queue executes task.

Result:
Resolves #183
  • Loading branch information
jchrys committed Dec 29, 2023
1 parent 7bbe8af commit eb7d65c
Show file tree
Hide file tree
Showing 3 changed files with 100 additions and 37 deletions.
7 changes: 7 additions & 0 deletions src/main/java/io/asyncer/r2dbc/mysql/ConnectionState.java
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,13 @@ interface ConnectionState {
*/
void setIsolationLevel(IsolationLevel level);

/**
* Reutrns session lock wait timeout.
*
* @return Session lock wait timeout.
*/
long getSessionLockWaitTimeout();

/**
* Sets current lock wait timeout.
*
Expand Down
23 changes: 8 additions & 15 deletions src/main/java/io/asyncer/r2dbc/mysql/MySqlConnection.java
Original file line number Diff line number Diff line change
Expand Up @@ -209,7 +209,7 @@ public Mono<Void> close() {
@Override
public Mono<Void> commitTransaction() {
return Mono.defer(() -> {
return QueryFlow.doneTransaction(client, this, true, lockWaitTimeout, batchSupported);
return QueryFlow.doneTransaction(client, this, true, batchSupported);
});
}

Expand All @@ -223,19 +223,7 @@ public MySqlBatch createBatch() {
@Override
public Mono<Void> createSavepoint(String name) {
requireValidName(name, "Savepoint name must not be empty and not contain backticks");

String sql = String.format("SAVEPOINT `%s`", name);

return Mono.defer(() -> {
if (isInTransaction()) {
return QueryFlow.executeVoid(client, sql);
} else if (batchSupported) {
// If connection does not in transaction, then starts transaction.
return QueryFlow.executeVoid(client, "BEGIN;" + sql);
}

return QueryFlow.executeVoid(client, "BEGIN", sql);
});
return QueryFlow.createSavepoint(client, this, name, batchSupported);
}

@Override
Expand Down Expand Up @@ -286,7 +274,7 @@ public Mono<Void> releaseSavepoint(String name) {
@Override
public Mono<Void> rollbackTransaction() {
return Mono.defer(() -> {
return QueryFlow.doneTransaction(client, this, false, lockWaitTimeout, batchSupported);
return QueryFlow.doneTransaction(client, this, false, batchSupported);
});
}

Expand Down Expand Up @@ -371,6 +359,11 @@ public void setIsolationLevel(IsolationLevel level) {
this.currentLevel = level;
}

@Override
public long getSessionLockWaitTimeout() {
return lockWaitTimeout;
}

@Override
public void setCurrentLockWaitTimeout(long timeoutSeconds) {
this.currentLockWaitTimeout = timeoutSeconds;
Expand Down
107 changes: 85 additions & 22 deletions src/main/java/io/asyncer/r2dbc/mysql/QueryFlow.java
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.function.Consumer;
import java.util.function.Predicate;
import java.util.function.Supplier;

/**
* A message flow considers both of parametrized and text queries, such as {@link TextParametrizedStatement},
Expand Down Expand Up @@ -249,13 +250,13 @@ static Mono<Void> executeVoid(Client client, String... statements) {
*/
static Mono<Void> beginTransaction(Client client, ConnectionState state, boolean batchSupported,
TransactionDefinition definition) {
StartTransactionState startState = StartTransactionState.of(state, definition);
Supplier<AbstractTransactionState> startStateSupplier = () -> StartTransactionState.of(state, definition);

if (batchSupported || startState.isSimple()) {
return client.exchange(new TransactionBatchExchangeable(startState)).then();
if (batchSupported) {
return client.exchange(new TransactionBatchExchangeable(startStateSupplier)).then();
}

return client.exchange(new TransactionMultiExchangeable(startState)).then();
return client.exchange(new TransactionMultiExchangeable(startStateSupplier)).then();
}

/**
Expand All @@ -265,19 +266,25 @@ static Mono<Void> beginTransaction(Client client, ConnectionState state, boolean
* @param client the {@link Client} to exchange messages with.
* @param state the connection state for checks and resets transaction statuses.
* @param commit if commit, otherwise rollback.
* @param lockWaitTimeout the lock wait timeout of the initial connection state.
* @param batchSupported if connection supports batch query.
* @return receives complete signal.
*/
static Mono<Void> doneTransaction(Client client, ConnectionState state, boolean commit,
long lockWaitTimeout, boolean batchSupported) {
CommitRollbackState commitState = CommitRollbackState.of(state, commit, lockWaitTimeout);
static Mono<Void> doneTransaction(Client client, ConnectionState state, boolean commit, boolean batchSupported) {
Supplier<AbstractTransactionState> commitStateSupplier = () -> CommitRollbackState.of(state, commit);

if (batchSupported || commitState.isSimple()) {
return client.exchange(new TransactionBatchExchangeable(commitState)).then();
if (batchSupported) {
return client.exchange(new TransactionBatchExchangeable(commitStateSupplier)).then();
}

return client.exchange(new TransactionMultiExchangeable(commitState)).then();
return client.exchange(new TransactionMultiExchangeable(commitStateSupplier)).then();
}

static Mono<Void> createSavepoint(Client client, ConnectionState state, String name, boolean batchSupported) {
Supplier<AbstractTransactionState> commitStateSupplier = () -> CreateSavepointState.of(state, name);
if (batchSupported) {
return client.exchange(new TransactionBatchExchangeable(commitStateSupplier)).then();
}
return client.exchange(new TransactionMultiExchangeable(commitStateSupplier)).then();
}

/**
Expand Down Expand Up @@ -1082,13 +1089,13 @@ protected boolean process(int task, SynchronousSink<Void> sink) {
return false;
}

static CommitRollbackState of(ConnectionState state, boolean commit, long lockWaitTimeout) {
static CommitRollbackState of(ConnectionState state, boolean commit) {
String doneSql = commit ? "COMMIT" : "ROLLBACK";

if (state.isLockWaitTimeoutChanged()) {
List<String> statements = new ArrayList<>(2);

statements.add("SET innodb_lock_wait_timeout=" + lockWaitTimeout);
statements.add("SET innodb_lock_wait_timeout=" + state.getSessionLockWaitTimeout());
statements.add(doneSql);

return new CommitRollbackState(state, LOCK_WAIT_TIMEOUT | COMMIT_OR_ROLLBACK, statements);
Expand Down Expand Up @@ -1227,16 +1234,62 @@ private static String buildStartTransaction(TransactionDefinition definition) {
}
}

final class CreateSavepointState extends AbstractTransactionState {

private static final int START_TRANSACTION = 1;

private static final int CREATE_SAVEPOINT = 2;

private CreateSavepointState(ConnectionState state, int tasks, List<String> statements) {
super(state, tasks, statements);
}

@Override
boolean cancelTasks() {
return false;
}

@Override
protected boolean process(int task, SynchronousSink<Void> sink) {
switch (task) {
case START_TRANSACTION:
return true;
case CREATE_SAVEPOINT:
sink.complete();
return false;
}
return false;
}

static CreateSavepointState of(ConnectionState state, final String name) {
int tasks = 0;
final String doneSql = String.format("SAVEPOINT `%s`", name);

if (!state.isInTransaction()) {
List<String> statements = new ArrayList<>(2);
statements.add("BEGIN");
statements.add(String.format("SAVEPOINT `%s`", name));
return new CreateSavepointState(state, START_TRANSACTION | CREATE_SAVEPOINT, statements);
}
return new CreateSavepointState(state, CREATE_SAVEPOINT, Collections.singletonList(doneSql));
}
}

final class TransactionBatchExchangeable extends FluxExchangeable<Void> {

private final AbstractTransactionState state;

TransactionBatchExchangeable(AbstractTransactionState state) {
this.state = state;
private final Supplier<AbstractTransactionState> stateSupplier;

@Nullable
private AbstractTransactionState state;

TransactionBatchExchangeable(Supplier<AbstractTransactionState> stateSupplier) {
this.stateSupplier = stateSupplier;
}

@Override
public void accept(ServerMessage message, SynchronousSink<Void> sink) {
assert state != null;
state.accept(message, sink);
}

Expand All @@ -1247,9 +1300,11 @@ public void dispose() {

@Override
public void subscribe(CoreSubscriber<? super ClientMessage> s) {
assert state == null;
state = stateSupplier.get();

if (state.cancelTasks()) {
s.onSubscribe(Operators.scalarSubscription(s, PingMessage.INSTANCE));

return;
}

Expand All @@ -1266,17 +1321,21 @@ final class TransactionMultiExchangeable extends FluxExchangeable<Void> {
private final Sinks.Many<ClientMessage> requests = Sinks.many().unicast()
.onBackpressureBuffer(Queues.<ClientMessage>one().get());

private final AbstractTransactionState state;
private final Supplier<AbstractTransactionState> stateSupplier;

private final Iterator<String> statements;
@Nullable
private AbstractTransactionState state;

TransactionMultiExchangeable(AbstractTransactionState state) {
this.state = state;
this.statements = state.statements();
@Nullable
private Iterator<String> statements;

TransactionMultiExchangeable(Supplier<AbstractTransactionState> stateSupplier) {
this.stateSupplier = stateSupplier;
}

@Override
public void accept(ServerMessage message, SynchronousSink<Void> sink) {
assert state != null && statements != null;
if (state.accept(message, sink)) {
String sql = statements.next();

Expand All @@ -1300,6 +1359,10 @@ public void dispose() {

@Override
public void subscribe(CoreSubscriber<? super ClientMessage> s) {
assert state == null && statements == null;
state = stateSupplier.get();
statements = state.statements();

if (state.cancelTasks()) {
s.onSubscribe(Operators.scalarSubscription(s, PingMessage.INSTANCE));

Expand Down

0 comments on commit eb7d65c

Please sign in to comment.