Skip to content

Commit

Permalink
Add RETURNING support for MariaDB 10.5.1+
Browse files Browse the repository at this point in the history
  • Loading branch information
mirromutth committed Jan 17, 2024
1 parent 8752f66 commit f2e3c2d
Show file tree
Hide file tree
Showing 26 changed files with 548 additions and 153 deletions.
4 changes: 2 additions & 2 deletions src/main/java/io/asyncer/r2dbc/mysql/Binding.java
Original file line number Diff line number Diff line change
Expand Up @@ -70,8 +70,8 @@ PreparedExecuteMessage toExecuteMessage(int statementId, boolean immediate) {
return new PreparedExecuteMessage(statementId, immediate, drainValues());
}

PreparedTextQueryMessage toTextMessage(Query query) {
return new PreparedTextQueryMessage(query, drainValues());
PreparedTextQueryMessage toTextMessage(Query query, String returning) {
return new PreparedTextQueryMessage(query, returning, drainValues());
}

/**
Expand Down
17 changes: 8 additions & 9 deletions src/main/java/io/asyncer/r2dbc/mysql/MySqlResult.java
Original file line number Diff line number Diff line change
Expand Up @@ -51,8 +51,7 @@
* An implementation of {@link Result} representing the results of a query against the MySQL database.
* <p>
* A {@link Segment} provided by this implementation may be both {@link UpdateCount} and {@link RowSegment},
* see also {@link MySqlOkSegment}. It's based on a {@link OkMessage}, when the {@code generatedKeyName} is
* not {@code null}.
* see also {@link MySqlOkSegment}.
*/
public final class MySqlResult implements Result {

Expand Down Expand Up @@ -155,14 +154,14 @@ public <T> Flux<T> flatMap(Function<Segment, ? extends Publisher<? extends T>> f
}

static MySqlResult toResult(boolean binary, Codecs codecs, ConnectionContext context,
@Nullable String generatedKeyName, Flux<ServerMessage> messages) {
@Nullable String syntheticKeyName, Flux<ServerMessage> messages) {
requireNonNull(codecs, "codecs must not be null");
requireNonNull(context, "context must not be null");
requireNonNull(messages, "messages must not be null");

return new MySqlResult(OperatorUtils.discardOnCancel(messages)
.doOnDiscard(ReferenceCounted.class, ReferenceCounted::release)
.handle(new MySqlSegments(binary, codecs, context, generatedKeyName)));
.handle(new MySqlSegments(binary, codecs, context, syntheticKeyName)));
}

private static final class MySqlMessage implements Message {
Expand Down Expand Up @@ -268,16 +267,16 @@ private static final class MySqlSegments implements BiConsumer<ServerMessage, Sy
private final ConnectionContext context;

@Nullable
private final String generatedKeyName;
private final String syntheticKeyName;

private MySqlRowMetadata rowMetadata;

private MySqlSegments(boolean binary, Codecs codecs, ConnectionContext context,
@Nullable String generatedKeyName) {
@Nullable String syntheticKeyName) {
this.binary = binary;
this.codecs = codecs;
this.context = context;
this.generatedKeyName = generatedKeyName;
this.syntheticKeyName = syntheticKeyName;
}

@Override
Expand Down Expand Up @@ -309,8 +308,8 @@ public void accept(ServerMessage message, SynchronousSink<Segment> sink) {

this.rowMetadata = MySqlRowMetadata.create(metadataMessages);
} else if (message instanceof OkMessage) {
Segment segment = generatedKeyName == null ? new MySqlUpdateCount((OkMessage) message) :
new MySqlOkSegment((OkMessage) message, codecs, generatedKeyName);
Segment segment = syntheticKeyName == null ? new MySqlUpdateCount((OkMessage) message) :
new MySqlOkSegment((OkMessage) message, codecs, syntheticKeyName);

sink.next(segment);
} else if (message instanceof ErrorMessage) {
Expand Down
80 changes: 70 additions & 10 deletions src/main/java/io/asyncer/r2dbc/mysql/MySqlStatementSupport.java
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,12 @@

package io.asyncer.r2dbc.mysql;

import io.asyncer.r2dbc.mysql.internal.util.InternalArrays;
import io.asyncer.r2dbc.mysql.internal.util.StringUtils;
import org.jetbrains.annotations.Nullable;

import java.util.StringJoiner;

import static io.asyncer.r2dbc.mysql.internal.util.AssertUtils.require;
import static io.asyncer.r2dbc.mysql.internal.util.AssertUtils.requireNonEmpty;
import static io.asyncer.r2dbc.mysql.internal.util.AssertUtils.requireNonNull;
Expand All @@ -27,31 +31,87 @@
*/
abstract class MySqlStatementSupport implements MySqlStatement {

private static final ServerVersion MARIA_10_5_1 = ServerVersion.create(10, 5, 1, true);

private static final String LAST_INSERT_ID = "LAST_INSERT_ID";

protected final ConnectionContext context;

@Nullable
String generatedKeyName = null;
protected String[] generatedColumns = null;

MySqlStatementSupport(ConnectionContext context) {
this.context = requireNonNull(context, "context must not be null");
}

@Override
public final MySqlStatement returnGeneratedValues(String... columns) {
requireNonNull(columns, "columns must not be null");

switch (columns.length) {
case 0:
this.generatedKeyName = LAST_INSERT_ID;
return this;
case 1:
requireNonEmpty(columns[0], "id name must not be empty");
this.generatedKeyName = columns[0];
return this;
int len = columns.length;

if (len == 0) {
this.generatedColumns = InternalArrays.EMPTY_STRINGS;
} else if (len == 1 || supportReturning()) {
String[] result = new String[len];

for (int i = 0; i < len; ++i) {
requireNonEmpty(columns[i], "returning column must not be empty");
result[i] = columns[i];
}

this.generatedColumns = result;
} else {
String db = context.isMariaDb() ? "MariaDB 10.5.0 or below" : "MySQL";
throw new IllegalArgumentException(db + " supports only LAST_INSERT_ID instead of RETURNING");
}

throw new IllegalArgumentException("MySQL only supports single generated value");
return this;
}

@Override
public MySqlStatement fetchSize(int rows) {
require(rows >= 0, "Fetch size must be greater or equal to zero");
return this;
}

@Nullable
final String syntheticKeyName() {
String[] columns = this.generatedColumns;

// MariaDB should use `RETURNING` clause instead.
if (columns == null || supportReturning()) {
return null;
}

if (columns.length == 0) {
return LAST_INSERT_ID;
}

return columns[0];
}

final String returningIdentifiers() {
String[] columns = this.generatedColumns;

if (columns == null || !supportReturning()) {
return "";
}

if (columns.length == 0) {
return "*";
}

StringJoiner joiner = new StringJoiner(",");

for (String column : columns) {
joiner.add(StringUtils.quoteIdentifier(column));
}

return joiner.toString();
}

private boolean supportReturning() {
return context.isMariaDb() && context.getServerVersion().isGreaterThanOrEqualTo(MARIA_10_5_1);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -45,19 +45,18 @@ abstract class ParametrizedStatementSupport extends MySqlStatementSupport {

protected final Query query;

protected final ConnectionContext context;

private final Bindings bindings;

private final AtomicBoolean executed = new AtomicBoolean();

ParametrizedStatementSupport(Client client, Codecs codecs, Query query, ConnectionContext context) {
super(context);

requireNonNull(query, "query must not be null");
require(query.getParameters() > 0, "parameters must be a positive integer");

this.client = requireNonNull(client, "client must not be null");
this.codecs = requireNonNull(codecs, "codecs must not be null");
this.context = requireNonNull(context, "context must not be null");
this.query = query;
this.bindings = new Bindings(query.getParameters());
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import io.asyncer.r2dbc.mysql.cache.PrepareCache;
import io.asyncer.r2dbc.mysql.client.Client;
import io.asyncer.r2dbc.mysql.codec.Codecs;
import io.asyncer.r2dbc.mysql.internal.util.StringUtils;
import reactor.core.publisher.Flux;

import java.util.List;
Expand All @@ -42,8 +43,11 @@ final class PrepareParametrizedStatement extends ParametrizedStatementSupport {

@Override
public Flux<MySqlResult> execute(List<Binding> bindings) {
return QueryFlow.execute(client, query.getFormattedSql(), bindings, fetchSize, prepareCache)
.map(messages -> MySqlResult.toResult(true, codecs, context, generatedKeyName, messages));
return Flux.defer(() -> QueryFlow.execute(client,
StringUtils.extendReturning(query.getFormattedSql(), returningIdentifiers()),
bindings, fetchSize, prepareCache
))
.map(messages -> MySqlResult.toResult(true, codecs, context, syntheticKeyName(), messages));
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import io.asyncer.r2dbc.mysql.cache.PrepareCache;
import io.asyncer.r2dbc.mysql.client.Client;
import io.asyncer.r2dbc.mysql.codec.Codecs;
import io.asyncer.r2dbc.mysql.internal.util.StringUtils;
import reactor.core.publisher.Flux;

import java.util.Collections;
Expand All @@ -45,8 +46,9 @@ final class PrepareSimpleStatement extends SimpleStatementSupport {

@Override
public Flux<MySqlResult> execute() {
return QueryFlow.execute(client, sql, BINDINGS, fetchSize, prepareCache)
.map(messages -> MySqlResult.toResult(true, codecs, context, generatedKeyName, messages));
return Flux.defer(() -> QueryFlow.execute(client,
StringUtils.extendReturning(sql, returningIdentifiers()), BINDINGS, fetchSize, prepareCache))
.map(messages -> MySqlResult.toResult(true, codecs, context, syntheticKeyName(), messages));
}

@Override
Expand Down
46 changes: 27 additions & 19 deletions src/main/java/io/asyncer/r2dbc/mysql/QueryFlow.java
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ final class QueryFlow {
* by {@link CompleteMessage} after receive the last result for the last binding.
*
* @param client the {@link Client} to exchange messages with.
* @param sql the original statement for exception tracing.
* @param sql the statement for exception tracing.
* @param bindings the data of bindings.
* @param fetchSize the size of fetching, if it less than or equal to {@literal 0} means fetch all rows.
* @param cache the cache of server-preparing result.
Expand All @@ -129,18 +129,21 @@ static Flux<Flux<ServerMessage>> execute(Client client, String sql, List<Binding
* will emit an exception and cancel subsequent {@link Binding}s. This exchange will be completed by
* {@link CompleteMessage} after receive the last result for the last binding.
*
* @param client the {@link Client} to exchange messages with.
* @param query the {@link Query} for synthetic client-preparing statement.
* @param bindings the data of bindings.
* @param client the {@link Client} to exchange messages with.
* @param query the {@link Query} for synthetic client-preparing statement.
* @param returning the {@code RETURNING} identifiers.
* @param bindings the data of bindings.
* @return the messages received in response to this exchange.
*/
static Flux<Flux<ServerMessage>> execute(Client client, Query query, List<Binding> bindings) {
static Flux<Flux<ServerMessage>> execute(
Client client, Query query, String returning, List<Binding> bindings
) {
return Flux.defer(() -> {
if (bindings.isEmpty()) {
return Flux.empty();
}

return client.exchange(new TextQueryExchangeable(query, bindings.iterator()))
return client.exchange(new TextQueryExchangeable(query, returning, bindings.iterator()))
.windowUntil(RESULT_DONE);
});
}
Expand Down Expand Up @@ -195,7 +198,7 @@ static Flux<Flux<ServerMessage>> execute(Client client, List<String> statements)
* @return the messages received in response to the login exchange.
*/
static Mono<Client> login(Client client, SslMode sslMode, String database, String user,
@Nullable CharSequence password, ConnectionContext context) {
@Nullable CharSequence password, ConnectionContext context) {
return client.exchange(new LoginExchangeable(client, sslMode, database, user, password, context))
.onErrorResume(e -> client.forceClose().then(Mono.error(e)))
.then(Mono.just(client));
Expand Down Expand Up @@ -263,13 +266,14 @@ static Mono<Void> beginTransaction(Client client, ConnectionState state, boolean
* Commits or rollbacks current transaction. It will recover statuses of the {@link ConnectionState} in
* the initial connection state.
*
* @param client the {@link Client} to exchange messages with.
* @param state the connection state for checks and resets transaction statuses.
* @param commit if commit, otherwise rollback.
* @param batchSupported if connection supports batch query.
* @param client the {@link Client} to exchange messages with.
* @param state the connection state for checks and resets transaction statuses.
* @param commit if commit, otherwise rollback.
* @param batchSupported if connection supports batch query.
* @return receives complete signal.
*/
static Mono<Void> doneTransaction(Client client, ConnectionState state, boolean commit, boolean batchSupported) {
static Mono<Void> doneTransaction(Client client, ConnectionState state, boolean commit,
boolean batchSupported) {
final CommitRollbackState commitState = new CommitRollbackState(state, commit);

if (batchSupported) {
Expand All @@ -279,7 +283,8 @@ static Mono<Void> doneTransaction(Client client, ConnectionState state, boolean
return client.exchange(new TransactionMultiExchangeable(commitState)).then();
}

static Mono<Void> createSavepoint(Client client, ConnectionState state, String name, boolean batchSupported) {
static Mono<Void> createSavepoint(Client client, ConnectionState state, String name,
boolean batchSupported) {
final CreateSavepointState savepointState = new CreateSavepointState(state, name);
if (batchSupported) {
return client.exchange(new TransactionBatchExchangeable(savepointState)).then();
Expand Down Expand Up @@ -357,10 +362,13 @@ final class TextQueryExchangeable extends BaseFluxExchangeable {

private final Query query;

private final String returning;

private final Iterator<Binding> bindings;

TextQueryExchangeable(Query query, Iterator<Binding> bindings) {
TextQueryExchangeable(Query query, String returning, Iterator<Binding> bindings) {
this.query = query;
this.returning = returning;
this.bindings = bindings;
}

Expand All @@ -384,9 +392,9 @@ public boolean isDisposed() {
@Override
protected void tryNextOrComplete(@Nullable SynchronousSink<ServerMessage> sink) {
if (this.bindings.hasNext()) {
QueryLogger.log(this.query);
QueryLogger.log(this.query, this.returning);

PreparedTextQueryMessage message = this.bindings.next().toTextMessage(this.query);
PreparedTextQueryMessage message = this.bindings.next().toTextMessage(this.query, this.returning);
Sinks.EmitResult result = this.requests.tryEmitNext(message);

if (result == Sinks.EmitResult.OK) {
Expand All @@ -404,7 +412,7 @@ protected void tryNextOrComplete(@Nullable SynchronousSink<ServerMessage> sink)

@Override
protected String offendingSql() {
return query.getFormattedSql();
return StringUtils.extendReturning(query.getFormattedSql(), returning);
}
}

Expand Down Expand Up @@ -1153,7 +1161,8 @@ protected boolean process(int task, SynchronousSink<Void> sink) {
}
return true;
case ISOLATION_LEVEL:
final IsolationLevel isolationLevel = definition.getAttribute(TransactionDefinition.ISOLATION_LEVEL);
final IsolationLevel isolationLevel =
definition.getAttribute(TransactionDefinition.ISOLATION_LEVEL);
if (isolationLevel != null) {
state.setIsolationLevel(isolationLevel);
}
Expand Down Expand Up @@ -1254,7 +1263,6 @@ protected boolean process(int task, SynchronousSink<Void> sink) {

final class TransactionBatchExchangeable extends FluxExchangeable<Void> {


private final AbstractTransactionState state;

TransactionBatchExchangeable(AbstractTransactionState state) {
Expand Down
6 changes: 4 additions & 2 deletions src/main/java/io/asyncer/r2dbc/mysql/QueryLogger.java
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
package io.asyncer.r2dbc.mysql;


import io.asyncer.r2dbc.mysql.internal.util.StringUtils;
import io.netty.util.internal.logging.InternalLogger;
import io.netty.util.internal.logging.InternalLoggerFactory;

Expand All @@ -31,9 +32,10 @@ static void log(String query) {
logger.debug("Executing direct query: {}", query);
}

static void log(Query query) {
static void log(Query query, String returning) {
if (logger.isDebugEnabled()) {
logger.debug("Executing format query: {}", query.getFormattedSql());
logger.debug("Executing format query: {}",
StringUtils.extendReturning(query.getFormattedSql(), returning));
}
}

Expand Down
Loading

0 comments on commit f2e3c2d

Please sign in to comment.