diff --git a/src/main/java/io/asyncer/r2dbc/mysql/Binding.java b/src/main/java/io/asyncer/r2dbc/mysql/Binding.java
index 20fafe5ea..8dd64d7ec 100644
--- a/src/main/java/io/asyncer/r2dbc/mysql/Binding.java
+++ b/src/main/java/io/asyncer/r2dbc/mysql/Binding.java
@@ -70,8 +70,8 @@ PreparedExecuteMessage toExecuteMessage(int statementId, boolean immediate) {
return new PreparedExecuteMessage(statementId, immediate, drainValues());
}
- PreparedTextQueryMessage toTextMessage(Query query) {
- return new PreparedTextQueryMessage(query, drainValues());
+ PreparedTextQueryMessage toTextMessage(Query query, String returning) {
+ return new PreparedTextQueryMessage(query, returning, drainValues());
}
/**
diff --git a/src/main/java/io/asyncer/r2dbc/mysql/MySqlResult.java b/src/main/java/io/asyncer/r2dbc/mysql/MySqlResult.java
index 2b09277f5..943987d1e 100644
--- a/src/main/java/io/asyncer/r2dbc/mysql/MySqlResult.java
+++ b/src/main/java/io/asyncer/r2dbc/mysql/MySqlResult.java
@@ -51,8 +51,7 @@
* An implementation of {@link Result} representing the results of a query against the MySQL database.
*
* A {@link Segment} provided by this implementation may be both {@link UpdateCount} and {@link RowSegment},
- * see also {@link MySqlOkSegment}. It's based on a {@link OkMessage}, when the {@code generatedKeyName} is
- * not {@code null}.
+ * see also {@link MySqlOkSegment}.
*/
public final class MySqlResult implements Result {
@@ -155,14 +154,14 @@ public Flux flatMap(Function> f
}
static MySqlResult toResult(boolean binary, Codecs codecs, ConnectionContext context,
- @Nullable String generatedKeyName, Flux messages) {
+ @Nullable String syntheticKeyName, Flux messages) {
requireNonNull(codecs, "codecs must not be null");
requireNonNull(context, "context must not be null");
requireNonNull(messages, "messages must not be null");
return new MySqlResult(OperatorUtils.discardOnCancel(messages)
.doOnDiscard(ReferenceCounted.class, ReferenceCounted::release)
- .handle(new MySqlSegments(binary, codecs, context, generatedKeyName)));
+ .handle(new MySqlSegments(binary, codecs, context, syntheticKeyName)));
}
private static final class MySqlMessage implements Message {
@@ -268,16 +267,16 @@ private static final class MySqlSegments implements BiConsumer sink) {
this.rowMetadata = MySqlRowMetadata.create(metadataMessages);
} else if (message instanceof OkMessage) {
- Segment segment = generatedKeyName == null ? new MySqlUpdateCount((OkMessage) message) :
- new MySqlOkSegment((OkMessage) message, codecs, generatedKeyName);
+ Segment segment = syntheticKeyName == null ? new MySqlUpdateCount((OkMessage) message) :
+ new MySqlOkSegment((OkMessage) message, codecs, syntheticKeyName);
sink.next(segment);
} else if (message instanceof ErrorMessage) {
diff --git a/src/main/java/io/asyncer/r2dbc/mysql/MySqlStatementSupport.java b/src/main/java/io/asyncer/r2dbc/mysql/MySqlStatementSupport.java
index 890914bde..0ad86b4db 100644
--- a/src/main/java/io/asyncer/r2dbc/mysql/MySqlStatementSupport.java
+++ b/src/main/java/io/asyncer/r2dbc/mysql/MySqlStatementSupport.java
@@ -16,8 +16,12 @@
package io.asyncer.r2dbc.mysql;
+import io.asyncer.r2dbc.mysql.internal.util.InternalArrays;
+import io.asyncer.r2dbc.mysql.internal.util.StringUtils;
import org.jetbrains.annotations.Nullable;
+import java.util.StringJoiner;
+
import static io.asyncer.r2dbc.mysql.internal.util.AssertUtils.require;
import static io.asyncer.r2dbc.mysql.internal.util.AssertUtils.requireNonEmpty;
import static io.asyncer.r2dbc.mysql.internal.util.AssertUtils.requireNonNull;
@@ -27,26 +31,42 @@
*/
abstract class MySqlStatementSupport implements MySqlStatement {
+ private static final ServerVersion MARIA_10_5_1 = ServerVersion.create(10, 5, 1, true);
+
private static final String LAST_INSERT_ID = "LAST_INSERT_ID";
+ protected final ConnectionContext context;
+
@Nullable
- String generatedKeyName = null;
+ protected String[] generatedColumns = null;
+
+ MySqlStatementSupport(ConnectionContext context) {
+ this.context = requireNonNull(context, "context must not be null");
+ }
@Override
public final MySqlStatement returnGeneratedValues(String... columns) {
requireNonNull(columns, "columns must not be null");
- switch (columns.length) {
- case 0:
- this.generatedKeyName = LAST_INSERT_ID;
- return this;
- case 1:
- requireNonEmpty(columns[0], "id name must not be empty");
- this.generatedKeyName = columns[0];
- return this;
+ int len = columns.length;
+
+ if (len == 0) {
+ this.generatedColumns = InternalArrays.EMPTY_STRINGS;
+ } else if (len == 1 || supportReturning()) {
+ String[] result = new String[len];
+
+ for (int i = 0; i < len; ++i) {
+ requireNonEmpty(columns[i], "returning column must not be empty");
+ result[i] = columns[i];
+ }
+
+ this.generatedColumns = result;
+ } else {
+ String db = context.isMariaDb() ? "MariaDB 10.5.0 or below" : "MySQL";
+ throw new IllegalArgumentException(db + " supports only LAST_INSERT_ID instead of RETURNING");
}
- throw new IllegalArgumentException("MySQL only supports single generated value");
+ return this;
}
@Override
@@ -54,4 +74,44 @@ public MySqlStatement fetchSize(int rows) {
require(rows >= 0, "Fetch size must be greater or equal to zero");
return this;
}
+
+ @Nullable
+ final String syntheticKeyName() {
+ String[] columns = this.generatedColumns;
+
+ // MariaDB should use `RETURNING` clause instead.
+ if (columns == null || supportReturning()) {
+ return null;
+ }
+
+ if (columns.length == 0) {
+ return LAST_INSERT_ID;
+ }
+
+ return columns[0];
+ }
+
+ final String returningIdentifiers() {
+ String[] columns = this.generatedColumns;
+
+ if (columns == null || !supportReturning()) {
+ return "";
+ }
+
+ if (columns.length == 0) {
+ return "*";
+ }
+
+ StringJoiner joiner = new StringJoiner(",");
+
+ for (String column : columns) {
+ joiner.add(StringUtils.quoteIdentifier(column));
+ }
+
+ return joiner.toString();
+ }
+
+ private boolean supportReturning() {
+ return context.isMariaDb() && context.getServerVersion().isGreaterThanOrEqualTo(MARIA_10_5_1);
+ }
}
diff --git a/src/main/java/io/asyncer/r2dbc/mysql/ParametrizedStatementSupport.java b/src/main/java/io/asyncer/r2dbc/mysql/ParametrizedStatementSupport.java
index 1fb38ab2d..ffe203077 100644
--- a/src/main/java/io/asyncer/r2dbc/mysql/ParametrizedStatementSupport.java
+++ b/src/main/java/io/asyncer/r2dbc/mysql/ParametrizedStatementSupport.java
@@ -45,19 +45,18 @@ abstract class ParametrizedStatementSupport extends MySqlStatementSupport {
protected final Query query;
- protected final ConnectionContext context;
-
private final Bindings bindings;
private final AtomicBoolean executed = new AtomicBoolean();
ParametrizedStatementSupport(Client client, Codecs codecs, Query query, ConnectionContext context) {
+ super(context);
+
requireNonNull(query, "query must not be null");
require(query.getParameters() > 0, "parameters must be a positive integer");
this.client = requireNonNull(client, "client must not be null");
this.codecs = requireNonNull(codecs, "codecs must not be null");
- this.context = requireNonNull(context, "context must not be null");
this.query = query;
this.bindings = new Bindings(query.getParameters());
}
diff --git a/src/main/java/io/asyncer/r2dbc/mysql/PrepareParametrizedStatement.java b/src/main/java/io/asyncer/r2dbc/mysql/PrepareParametrizedStatement.java
index 775c535d1..3a946f3ea 100644
--- a/src/main/java/io/asyncer/r2dbc/mysql/PrepareParametrizedStatement.java
+++ b/src/main/java/io/asyncer/r2dbc/mysql/PrepareParametrizedStatement.java
@@ -19,6 +19,7 @@
import io.asyncer.r2dbc.mysql.cache.PrepareCache;
import io.asyncer.r2dbc.mysql.client.Client;
import io.asyncer.r2dbc.mysql.codec.Codecs;
+import io.asyncer.r2dbc.mysql.internal.util.StringUtils;
import reactor.core.publisher.Flux;
import java.util.List;
@@ -42,8 +43,11 @@ final class PrepareParametrizedStatement extends ParametrizedStatementSupport {
@Override
public Flux execute(List bindings) {
- return QueryFlow.execute(client, query.getFormattedSql(), bindings, fetchSize, prepareCache)
- .map(messages -> MySqlResult.toResult(true, codecs, context, generatedKeyName, messages));
+ return Flux.defer(() -> QueryFlow.execute(client,
+ StringUtils.extendReturning(query.getFormattedSql(), returningIdentifiers()),
+ bindings, fetchSize, prepareCache
+ ))
+ .map(messages -> MySqlResult.toResult(true, codecs, context, syntheticKeyName(), messages));
}
@Override
diff --git a/src/main/java/io/asyncer/r2dbc/mysql/PrepareSimpleStatement.java b/src/main/java/io/asyncer/r2dbc/mysql/PrepareSimpleStatement.java
index 07800081b..2284b991e 100644
--- a/src/main/java/io/asyncer/r2dbc/mysql/PrepareSimpleStatement.java
+++ b/src/main/java/io/asyncer/r2dbc/mysql/PrepareSimpleStatement.java
@@ -19,6 +19,7 @@
import io.asyncer.r2dbc.mysql.cache.PrepareCache;
import io.asyncer.r2dbc.mysql.client.Client;
import io.asyncer.r2dbc.mysql.codec.Codecs;
+import io.asyncer.r2dbc.mysql.internal.util.StringUtils;
import reactor.core.publisher.Flux;
import java.util.Collections;
@@ -45,8 +46,9 @@ final class PrepareSimpleStatement extends SimpleStatementSupport {
@Override
public Flux execute() {
- return QueryFlow.execute(client, sql, BINDINGS, fetchSize, prepareCache)
- .map(messages -> MySqlResult.toResult(true, codecs, context, generatedKeyName, messages));
+ return Flux.defer(() -> QueryFlow.execute(client,
+ StringUtils.extendReturning(sql, returningIdentifiers()), BINDINGS, fetchSize, prepareCache))
+ .map(messages -> MySqlResult.toResult(true, codecs, context, syntheticKeyName(), messages));
}
@Override
diff --git a/src/main/java/io/asyncer/r2dbc/mysql/QueryFlow.java b/src/main/java/io/asyncer/r2dbc/mysql/QueryFlow.java
index 203c8ed5c..7109fd544 100644
--- a/src/main/java/io/asyncer/r2dbc/mysql/QueryFlow.java
+++ b/src/main/java/io/asyncer/r2dbc/mysql/QueryFlow.java
@@ -104,7 +104,7 @@ final class QueryFlow {
* by {@link CompleteMessage} after receive the last result for the last binding.
*
* @param client the {@link Client} to exchange messages with.
- * @param sql the original statement for exception tracing.
+ * @param sql the statement for exception tracing.
* @param bindings the data of bindings.
* @param fetchSize the size of fetching, if it less than or equal to {@literal 0} means fetch all rows.
* @param cache the cache of server-preparing result.
@@ -129,18 +129,21 @@ static Flux> execute(Client client, String sql, List> execute(Client client, Query query, List bindings) {
+ static Flux> execute(
+ Client client, Query query, String returning, List bindings
+ ) {
return Flux.defer(() -> {
if (bindings.isEmpty()) {
return Flux.empty();
}
- return client.exchange(new TextQueryExchangeable(query, bindings.iterator()))
+ return client.exchange(new TextQueryExchangeable(query, returning, bindings.iterator()))
.windowUntil(RESULT_DONE);
});
}
@@ -195,7 +198,7 @@ static Flux> execute(Client client, List statements)
* @return the messages received in response to the login exchange.
*/
static Mono login(Client client, SslMode sslMode, String database, String user,
- @Nullable CharSequence password, ConnectionContext context) {
+ @Nullable CharSequence password, ConnectionContext context) {
return client.exchange(new LoginExchangeable(client, sslMode, database, user, password, context))
.onErrorResume(e -> client.forceClose().then(Mono.error(e)))
.then(Mono.just(client));
@@ -263,13 +266,14 @@ static Mono beginTransaction(Client client, ConnectionState state, boolean
* Commits or rollbacks current transaction. It will recover statuses of the {@link ConnectionState} in
* the initial connection state.
*
- * @param client the {@link Client} to exchange messages with.
- * @param state the connection state for checks and resets transaction statuses.
- * @param commit if commit, otherwise rollback.
- * @param batchSupported if connection supports batch query.
+ * @param client the {@link Client} to exchange messages with.
+ * @param state the connection state for checks and resets transaction statuses.
+ * @param commit if commit, otherwise rollback.
+ * @param batchSupported if connection supports batch query.
* @return receives complete signal.
*/
- static Mono doneTransaction(Client client, ConnectionState state, boolean commit, boolean batchSupported) {
+ static Mono doneTransaction(Client client, ConnectionState state, boolean commit,
+ boolean batchSupported) {
final CommitRollbackState commitState = new CommitRollbackState(state, commit);
if (batchSupported) {
@@ -279,7 +283,8 @@ static Mono doneTransaction(Client client, ConnectionState state, boolean
return client.exchange(new TransactionMultiExchangeable(commitState)).then();
}
- static Mono createSavepoint(Client client, ConnectionState state, String name, boolean batchSupported) {
+ static Mono createSavepoint(Client client, ConnectionState state, String name,
+ boolean batchSupported) {
final CreateSavepointState savepointState = new CreateSavepointState(state, name);
if (batchSupported) {
return client.exchange(new TransactionBatchExchangeable(savepointState)).then();
@@ -357,10 +362,13 @@ final class TextQueryExchangeable extends BaseFluxExchangeable {
private final Query query;
+ private final String returning;
+
private final Iterator bindings;
- TextQueryExchangeable(Query query, Iterator bindings) {
+ TextQueryExchangeable(Query query, String returning, Iterator bindings) {
this.query = query;
+ this.returning = returning;
this.bindings = bindings;
}
@@ -384,9 +392,9 @@ public boolean isDisposed() {
@Override
protected void tryNextOrComplete(@Nullable SynchronousSink sink) {
if (this.bindings.hasNext()) {
- QueryLogger.log(this.query);
+ QueryLogger.log(this.query, this.returning);
- PreparedTextQueryMessage message = this.bindings.next().toTextMessage(this.query);
+ PreparedTextQueryMessage message = this.bindings.next().toTextMessage(this.query, this.returning);
Sinks.EmitResult result = this.requests.tryEmitNext(message);
if (result == Sinks.EmitResult.OK) {
@@ -404,7 +412,7 @@ protected void tryNextOrComplete(@Nullable SynchronousSink sink)
@Override
protected String offendingSql() {
- return query.getFormattedSql();
+ return StringUtils.extendReturning(query.getFormattedSql(), returning);
}
}
@@ -1153,7 +1161,8 @@ protected boolean process(int task, SynchronousSink sink) {
}
return true;
case ISOLATION_LEVEL:
- final IsolationLevel isolationLevel = definition.getAttribute(TransactionDefinition.ISOLATION_LEVEL);
+ final IsolationLevel isolationLevel =
+ definition.getAttribute(TransactionDefinition.ISOLATION_LEVEL);
if (isolationLevel != null) {
state.setIsolationLevel(isolationLevel);
}
@@ -1254,7 +1263,6 @@ protected boolean process(int task, SynchronousSink sink) {
final class TransactionBatchExchangeable extends FluxExchangeable {
-
private final AbstractTransactionState state;
TransactionBatchExchangeable(AbstractTransactionState state) {
diff --git a/src/main/java/io/asyncer/r2dbc/mysql/QueryLogger.java b/src/main/java/io/asyncer/r2dbc/mysql/QueryLogger.java
index 98f9ad0ab..20c2ff1a9 100644
--- a/src/main/java/io/asyncer/r2dbc/mysql/QueryLogger.java
+++ b/src/main/java/io/asyncer/r2dbc/mysql/QueryLogger.java
@@ -17,6 +17,7 @@
package io.asyncer.r2dbc.mysql;
+import io.asyncer.r2dbc.mysql.internal.util.StringUtils;
import io.netty.util.internal.logging.InternalLogger;
import io.netty.util.internal.logging.InternalLoggerFactory;
@@ -31,9 +32,10 @@ static void log(String query) {
logger.debug("Executing direct query: {}", query);
}
- static void log(Query query) {
+ static void log(Query query, String returning) {
if (logger.isDebugEnabled()) {
- logger.debug("Executing format query: {}", query.getFormattedSql());
+ logger.debug("Executing format query: {}",
+ StringUtils.extendReturning(query.getFormattedSql(), returning));
}
}
diff --git a/src/main/java/io/asyncer/r2dbc/mysql/SimpleStatementSupport.java b/src/main/java/io/asyncer/r2dbc/mysql/SimpleStatementSupport.java
index daebe5bff..56b34a926 100644
--- a/src/main/java/io/asyncer/r2dbc/mysql/SimpleStatementSupport.java
+++ b/src/main/java/io/asyncer/r2dbc/mysql/SimpleStatementSupport.java
@@ -30,14 +30,13 @@ abstract class SimpleStatementSupport extends MySqlStatementSupport {
protected final Codecs codecs;
- protected final ConnectionContext context;
-
protected final String sql;
SimpleStatementSupport(Client client, Codecs codecs, ConnectionContext context, String sql) {
+ super(context);
+
this.client = requireNonNull(client, "client must not be null");
this.codecs = requireNonNull(codecs, "codecs must not be null");
- this.context = requireNonNull(context, "context must not be null");
this.sql = requireNonNull(sql, "sql must not be null");
}
diff --git a/src/main/java/io/asyncer/r2dbc/mysql/TextParametrizedStatement.java b/src/main/java/io/asyncer/r2dbc/mysql/TextParametrizedStatement.java
index c9e4bedfb..9d2c2c55b 100644
--- a/src/main/java/io/asyncer/r2dbc/mysql/TextParametrizedStatement.java
+++ b/src/main/java/io/asyncer/r2dbc/mysql/TextParametrizedStatement.java
@@ -33,7 +33,7 @@ final class TextParametrizedStatement extends ParametrizedStatementSupport {
@Override
protected Flux execute(List bindings) {
- return QueryFlow.execute(client, query, bindings)
- .map(messages -> MySqlResult.toResult(false, codecs, context, generatedKeyName, messages));
+ return Flux.defer(() -> QueryFlow.execute(client, query, returningIdentifiers(), bindings))
+ .map(messages -> MySqlResult.toResult(false, codecs, context, syntheticKeyName(), messages));
}
}
diff --git a/src/main/java/io/asyncer/r2dbc/mysql/TextSimpleStatement.java b/src/main/java/io/asyncer/r2dbc/mysql/TextSimpleStatement.java
index 96c9e5ff1..e23873c07 100644
--- a/src/main/java/io/asyncer/r2dbc/mysql/TextSimpleStatement.java
+++ b/src/main/java/io/asyncer/r2dbc/mysql/TextSimpleStatement.java
@@ -18,6 +18,7 @@
import io.asyncer.r2dbc.mysql.client.Client;
import io.asyncer.r2dbc.mysql.codec.Codecs;
+import io.asyncer.r2dbc.mysql.internal.util.StringUtils;
import reactor.core.publisher.Flux;
/**
@@ -31,7 +32,9 @@ final class TextSimpleStatement extends SimpleStatementSupport {
@Override
public Flux execute() {
- return QueryFlow.execute(client, sql)
- .map(messages -> MySqlResult.toResult(false, codecs, context, generatedKeyName, messages));
+ return Flux.defer(() -> QueryFlow.execute(
+ client,
+ StringUtils.extendReturning(sql, returningIdentifiers()))
+ ).map(messages -> MySqlResult.toResult(false, codecs, context, syntheticKeyName(), messages));
}
}
diff --git a/src/main/java/io/asyncer/r2dbc/mysql/internal/util/StringUtils.java b/src/main/java/io/asyncer/r2dbc/mysql/internal/util/StringUtils.java
index 2ccbb55a5..feb5ba3aa 100644
--- a/src/main/java/io/asyncer/r2dbc/mysql/internal/util/StringUtils.java
+++ b/src/main/java/io/asyncer/r2dbc/mysql/internal/util/StringUtils.java
@@ -25,6 +25,12 @@ public final class StringUtils {
private static final char QUOTE = '`';
+ /**
+ * Quotes identifier with backticks, it will escape backticks in the identifier.
+ *
+ * @param identifier the identifier
+ * @return quoted identifier
+ */
public static String quoteIdentifier(String identifier) {
requireNonEmpty(identifier, "identifier must not be empty");
@@ -53,6 +59,17 @@ public static String quoteIdentifier(String identifier) {
return builder.append(QUOTE).toString();
}
+ /**
+ * Extends a SQL statement with {@code RETURNING} clause.
+ *
+ * @param sql the original SQL statement.
+ * @param returning quoted column identifiers.
+ * @return the SQL statement with {@code RETURNING} clause.
+ */
+ public static String extendReturning(String sql, String returning) {
+ return returning.isEmpty() ? sql : sql + " RETURNING " + returning;
+ }
+
private StringUtils() {
}
}
diff --git a/src/main/java/io/asyncer/r2dbc/mysql/message/client/PreparedTextQueryMessage.java b/src/main/java/io/asyncer/r2dbc/mysql/message/client/PreparedTextQueryMessage.java
index df5b29df5..d6ea6d783 100644
--- a/src/main/java/io/asyncer/r2dbc/mysql/message/client/PreparedTextQueryMessage.java
+++ b/src/main/java/io/asyncer/r2dbc/mysql/message/client/PreparedTextQueryMessage.java
@@ -38,17 +38,21 @@ public final class PreparedTextQueryMessage extends AtomicReference encode(ByteBufAllocator allocator, ConnectionContext contex
return Flux.fromArray(values);
});
- return ParamWriter.publish(query, parameters).map(it -> {
+ return ParamWriter.publish(query, parameters).handle((it, sink) -> {
ByteBuf buf = allocator.buffer();
try {
buf.writeByte(TextQueryMessage.QUERY_FLAG).writeCharSequence(it, charset);
- return buf;
+
+ if (!returning.isEmpty()) {
+ buf.writeCharSequence(" RETURNING ", charset);
+ buf.writeCharSequence(returning, charset);
+ }
+
+ sink.next(buf);
} catch (Throwable e) {
// Maybe IndexOutOfBounds or OOM (too large sql)
buf.release();
- throw e;
+ sink.error(e);
}
});
}
diff --git a/src/test/java/io/asyncer/r2dbc/mysql/ConnectionContextTest.java b/src/test/java/io/asyncer/r2dbc/mysql/ConnectionContextTest.java
index 2f2bd3186..2c8c907bd 100644
--- a/src/test/java/io/asyncer/r2dbc/mysql/ConnectionContextTest.java
+++ b/src/test/java/io/asyncer/r2dbc/mysql/ConnectionContextTest.java
@@ -68,13 +68,18 @@ void badSetServerZoneId() {
}
public static ConnectionContext mock() {
- return mock(ZoneId.systemDefault());
+ return mock(false, ZoneId.systemDefault());
}
- public static ConnectionContext mock(ZoneId zoneId) {
+ public static ConnectionContext mock(boolean isMariaDB) {
+ return mock(isMariaDB, ZoneId.systemDefault());
+ }
+
+ public static ConnectionContext mock(boolean isMariaDB, ZoneId zoneId) {
ConnectionContext context = new ConnectionContext(ZeroDateOption.USE_NULL, zoneId);
- context.init(1, ServerVersion.parse("8.0.11.MOCKED"), Capability.of(-1));
+ context.init(1, ServerVersion.parse(isMariaDB ? "11.2.22.MOCKED" : "8.0.11.MOCKED"),
+ Capability.of(~(isMariaDB ? 1 : 0)));
context.setServerStatuses(ServerStatuses.AUTO_COMMIT);
return context;
diff --git a/src/test/java/io/asyncer/r2dbc/mysql/IntegrationTestSupport.java b/src/test/java/io/asyncer/r2dbc/mysql/IntegrationTestSupport.java
index 11e754e99..c8d887f64 100644
--- a/src/test/java/io/asyncer/r2dbc/mysql/IntegrationTestSupport.java
+++ b/src/test/java/io/asyncer/r2dbc/mysql/IntegrationTestSupport.java
@@ -136,4 +136,16 @@ boolean envIsLessThanMySql57OrMariaDb102() {
return ver.isLessThan(ServerVersion.create(5, 7, 0));
}
+
+ static boolean envIsMariaDb10_5_1() {
+ String type = System.getProperty("test.db.type");
+
+ if (!"mariadb".equalsIgnoreCase(type)) {
+ return false;
+ }
+
+ ServerVersion ver = ServerVersion.parse(System.getProperty("test.mysql.version"));
+
+ return ver.isGreaterThanOrEqualTo(ServerVersion.create(10, 5, 1));
+ }
}
diff --git a/src/test/java/io/asyncer/r2dbc/mysql/MariaDbIntegrationTestSupport.java b/src/test/java/io/asyncer/r2dbc/mysql/MariaDbIntegrationTestSupport.java
new file mode 100644
index 000000000..610b4ecd2
--- /dev/null
+++ b/src/test/java/io/asyncer/r2dbc/mysql/MariaDbIntegrationTestSupport.java
@@ -0,0 +1,169 @@
+/*
+ * Copyright 2024 asyncer.io projects
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package io.asyncer.r2dbc.mysql;
+
+import io.r2dbc.spi.Readable;
+import org.jetbrains.annotations.Nullable;
+import org.junit.jupiter.api.Test;
+
+import java.time.ZonedDateTime;
+import java.util.function.Predicate;
+
+import static java.util.Objects.requireNonNull;
+import static org.assertj.core.api.Assertions.assertThat;
+
+/**
+ * Base class considers integration tests for MariaDB.
+ */
+abstract class MariaDbIntegrationTestSupport extends IntegrationTestSupport {
+
+ MariaDbIntegrationTestSupport(@Nullable Predicate preferPrepared) {
+ super(configuration("r2dbc", false, false, null, preferPrepared));
+ }
+
+ @Test
+ void allReturning() {
+ complete(conn -> conn.createStatement("CREATE TEMPORARY TABLE test (" +
+ "id INT NOT NULL AUTO_INCREMENT PRIMARY KEY," +
+ "value INT NOT NULL," +
+ "created_at DATETIME(3) NOT NULL DEFAULT CURRENT_TIMESTAMP(3))")
+ .execute()
+ .flatMap(IntegrationTestSupport::extractRowsUpdated)
+ .thenMany(conn.createStatement("INSERT INTO test(value) VALUES (?),(?),(?),(?),(?)")
+ .bind(0, 2)
+ .bind(1, 4)
+ .bind(2, 6)
+ .bind(3, 8)
+ .bind(4, 10)
+ .returnGeneratedValues()
+ .execute())
+ .flatMap(result -> result.map(DataEntity::read))
+ .collectList()
+ .doOnNext(list -> assertThat(list).hasSize(5)
+ .map(DataEntity::getValue)
+ .containsExactly(2, 4, 6, 8, 10))
+ .doOnNext(list -> assertThat(list.stream().map(DataEntity::getId).distinct()).hasSize(5))
+ .doOnNext(list -> assertThat(list.stream().map(DataEntity::getCreatedAt))
+ .noneMatch(it -> it.isBefore(ZonedDateTime.now().minusMinutes(1))))
+ .thenMany(conn.createStatement("REPLACE test(id, value) VALUES (1,?),(2,?),(3,?),(4,?),(5,?)")
+ .bind(0, 3)
+ .bind(1, 5)
+ .bind(2, 7)
+ .bind(3, 9)
+ .bind(4, 11)
+ .returnGeneratedValues()
+ .execute())
+ .flatMap(result -> result.map(DataEntity::read))
+ .collectList()
+ .doOnNext(list -> assertThat(list).hasSize(5)
+ .map(DataEntity::getValue)
+ .containsExactly(3, 5, 7, 9, 11))
+ .doOnNext(list -> assertThat(list.stream().map(DataEntity::getCreatedAt))
+ .noneMatch(it -> it.isBefore(ZonedDateTime.now().minusMinutes(1)))));
+ }
+
+ @Test
+ void partialReturning() {
+ complete(conn -> conn.createStatement("CREATE TEMPORARY TABLE test (" +
+ "id INT NOT NULL AUTO_INCREMENT PRIMARY KEY," +
+ "value INT NOT NULL," +
+ "created_at DATETIME(3) NOT NULL DEFAULT CURRENT_TIMESTAMP(3))")
+ .execute()
+ .flatMap(IntegrationTestSupport::extractRowsUpdated)
+ .thenMany(conn.createStatement("INSERT INTO test(value) VALUES (?),(?),(?),(?),(?)")
+ .bind(0, 2)
+ .bind(1, 4)
+ .bind(2, 6)
+ .bind(3, 8)
+ .bind(4, 10)
+ .returnGeneratedValues("id", "created_at")
+ .execute())
+ .flatMap(result -> result.map(DataEntity::withoutValue))
+ .collectList()
+ .doOnNext(list -> assertThat(list).hasSize(5)
+ .map(DataEntity::getValue)
+ .containsOnly(0))
+ .doOnNext(list -> assertThat(list.stream().map(DataEntity::getId).distinct()).hasSize(5))
+ .doOnNext(list -> assertThat(list.stream().map(DataEntity::getCreatedAt))
+ .noneMatch(it -> it.isBefore(ZonedDateTime.now().minusMinutes(1))))
+ .thenMany(conn.createStatement("REPLACE test(id, value) VALUES (1,?),(2,?),(3,?),(4,?),(5,?)")
+ .bind(0, 3)
+ .bind(1, 5)
+ .bind(2, 7)
+ .bind(3, 9)
+ .bind(4, 11)
+ .returnGeneratedValues("id", "created_at")
+ .execute())
+ .flatMap(result -> result.map(DataEntity::withoutValue))
+ .collectList()
+ .doOnNext(list -> assertThat(list).hasSize(5)
+ .map(DataEntity::getValue)
+ .containsOnly(0))
+ .doOnNext(list -> assertThat(list.stream().map(DataEntity::getCreatedAt))
+ .noneMatch(it -> it.isBefore(ZonedDateTime.now().minusMinutes(1))))
+ );
+ }
+
+ private static final class DataEntity {
+
+ private final int id;
+
+ private final int value;
+
+ private final ZonedDateTime createdAt;
+
+ private DataEntity(int id, int value, ZonedDateTime createdAt) {
+ this.id = id;
+ this.value = value;
+ this.createdAt = createdAt;
+ }
+
+ int getId() {
+ return id;
+ }
+
+ int getValue() {
+ return value;
+ }
+
+ ZonedDateTime getCreatedAt() {
+ return createdAt;
+ }
+
+ static DataEntity read(Readable readable) {
+ Integer id = readable.get("id", Integer.TYPE);
+ Integer value = readable.get("value", Integer.class);
+ ZonedDateTime createdAt = readable.get("created_at", ZonedDateTime.class);
+
+ requireNonNull(id, "id must not be null");
+ requireNonNull(value, "value must not be null");
+ requireNonNull(createdAt, "createdAt must not be null");
+
+ return new DataEntity(id, value, createdAt);
+ }
+
+ static DataEntity withoutValue(Readable readable) {
+ Integer id = readable.get("id", Integer.TYPE);
+ ZonedDateTime createdAt = readable.get("created_at", ZonedDateTime.class);
+
+ requireNonNull(id, "id must not be null");
+ requireNonNull(createdAt, "createdAt must not be null");
+
+ return new DataEntity(id, 0, createdAt);
+ }
+ }
+}
diff --git a/src/test/java/io/asyncer/r2dbc/mysql/MariaDbPrepareIntegrationTest.java b/src/test/java/io/asyncer/r2dbc/mysql/MariaDbPrepareIntegrationTest.java
new file mode 100644
index 000000000..b7ac81a8b
--- /dev/null
+++ b/src/test/java/io/asyncer/r2dbc/mysql/MariaDbPrepareIntegrationTest.java
@@ -0,0 +1,30 @@
+/*
+ * Copyright 2024 asyncer.io projects
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package io.asyncer.r2dbc.mysql;
+
+import org.junit.jupiter.api.condition.EnabledIf;
+
+/**
+ * Integration tests for MariaDB with server-preparing statements.
+ */
+@EnabledIf("envIsMariaDb10_5_1")
+class MariaDbPrepareIntegrationTest extends MariaDbIntegrationTestSupport {
+
+ MariaDbPrepareIntegrationTest() {
+ super(sql -> true);
+ }
+}
diff --git a/src/test/java/io/asyncer/r2dbc/mysql/MariaDbTextIntegrationTest.java b/src/test/java/io/asyncer/r2dbc/mysql/MariaDbTextIntegrationTest.java
new file mode 100644
index 000000000..0ab886c5f
--- /dev/null
+++ b/src/test/java/io/asyncer/r2dbc/mysql/MariaDbTextIntegrationTest.java
@@ -0,0 +1,30 @@
+/*
+ * Copyright 2024 asyncer.io projects
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package io.asyncer.r2dbc.mysql;
+
+import org.junit.jupiter.api.condition.EnabledIf;
+
+/**
+ * Integration tests for MariaDB with client-preparing statements.
+ */
+@EnabledIf("envIsMariaDb10_5_1")
+class MariaDbTextIntegrationTest extends MariaDbIntegrationTestSupport {
+
+ MariaDbTextIntegrationTest() {
+ super(null);
+ }
+}
diff --git a/src/test/java/io/asyncer/r2dbc/mysql/PrepareParametrizedStatementTest.java b/src/test/java/io/asyncer/r2dbc/mysql/PrepareParametrizedStatementTest.java
index e638fa38b..74bb3ee92 100644
--- a/src/test/java/io/asyncer/r2dbc/mysql/PrepareParametrizedStatementTest.java
+++ b/src/test/java/io/asyncer/r2dbc/mysql/PrepareParametrizedStatementTest.java
@@ -32,8 +32,6 @@ class PrepareParametrizedStatementTest implements StatementTestSupport>() {}.getType();
- Type enumSet = new TypeReference>() {}.getType();
+ Type stringSet = new TypeReference>() { }.getType();
+ Type enumSet = new TypeReference>() { }.getType();
testType(String.class, true, "SET('ONE','TWO','THREE')", null, "ONE,TWO,THREE", "ONE", "",
"ONE,THREE");
@@ -268,7 +268,6 @@ void json() {
testType(String.class, false, "JSON", null, "{\"data\": 1}", "[\"data\", 1]", "1", "null",
"\"R2DBC\"", "2.56");
-
}
@Test
@@ -407,7 +406,7 @@ void selectOne() {
@Test
void selectFromOtherDatabase() {
complete(conn -> Flux.from(conn.createStatement("SELECT * FROM `information_schema`.`innodb_trx`")
- .execute())
+ .execute())
.flatMap(result -> result.map((row, metadata) -> row.get(0))));
}
@@ -422,7 +421,7 @@ void multiQueries() {
.flatMap(IntegrationTestSupport::extractRowsUpdated)
.thenMany(Flux.range(0, 10))
.flatMap(it -> Flux.from(connection.createStatement("INSERT INTO test VALUES" +
- "(DEFAULT,?,?,NOW(),NOW())")
+ "(DEFAULT,?,?,NOW(),NOW())")
.bind(0, String.format("integration-test%d@mail.com", it))
.bind(1, "******")
.execute()))
@@ -443,7 +442,7 @@ void multiQueries() {
@Test
void consumePortion() {
complete(connection -> Mono.from(connection.createStatement("CREATE TEMPORARY TABLE test" +
- "(id INT PRIMARY KEY AUTO_INCREMENT,value INT)").execute())
+ "(id INT PRIMARY KEY AUTO_INCREMENT,value INT)").execute())
.flatMap(IntegrationTestSupport::extractRowsUpdated)
.then(Mono.from(connection.createStatement("INSERT INTO test(`value`) VALUES (1),(2),(3),(4),(5)")
.execute()))
@@ -453,8 +452,8 @@ void consumePortion() {
.execute()))
.flatMapMany(r -> r.map((row, metadata) -> row.get(0, Integer.TYPE))).take(3)
.concatWith(Mono.from(connection.createStatement("SELECT value FROM test WHERE id > ?")
- .bind(0, 0)
- .execute())
+ .bind(0, 0)
+ .execute())
.flatMapMany(r -> r.map((row, metadata) -> row.get(0, Integer.TYPE))).take(2))
.collectList()
.doOnNext(it -> assertThat(it).isEqualTo(Arrays.asList(1, 2, 3, 1, 2))));
@@ -474,7 +473,7 @@ void ignoreResult() {
.flatMap(IntegrationTestSupport::extractRowsUpdated)
.thenMany(Flux.merge(
Flux.from(connection.createStatement("SELECT value FROM test WHERE id > ?")
- .bind(0, 0).execute())
+ .bind(0, 0).execute())
.flatMap(r -> r.map((row, meta) -> row.get(0, Integer.class)))
.doOnNext(values::add),
connection.createStatement("BAD GRAMMAR").execute()
@@ -496,8 +495,8 @@ void ignoreResult() {
void foundRows() {
int value = 10;
complete(connection -> Flux.from(connection.createStatement("CREATE TEMPORARY TABLE test" +
- "(id INT PRIMARY KEY AUTO_INCREMENT,value INT)")
- .execute())
+ "(id INT PRIMARY KEY AUTO_INCREMENT,value INT)")
+ .execute())
.flatMap(IntegrationTestSupport::extractRowsUpdated)
.thenMany(connection.createStatement("INSERT INTO test VALUES(DEFAULT,?)")
.bind(0, value).execute())
@@ -522,11 +521,11 @@ void foundRows() {
@Test
void insertOnDuplicate() {
complete(connection -> Flux.from(connection.createStatement("CREATE TEMPORARY TABLE test" +
- "(id INT PRIMARY KEY,value INT)")
- .execute())
+ "(id INT PRIMARY KEY,value INT)")
+ .execute())
.flatMap(IntegrationTestSupport::extractRowsUpdated)
.thenMany(connection.createStatement("INSERT INTO test VALUES(?,?) " +
- "ON DUPLICATE KEY UPDATE value=?")
+ "ON DUPLICATE KEY UPDATE value=?")
.bind(0, 1)
.bind(1, 10)
.bind(2, 20)
@@ -540,7 +539,7 @@ void insertOnDuplicate() {
.collectList()
.doOnNext(it -> assertThat(it).isEqualTo(Collections.singletonList(10)))
.thenMany(connection.createStatement("INSERT INTO test VALUES(?,?) " +
- "ON DUPLICATE KEY UPDATE value=?")
+ "ON DUPLICATE KEY UPDATE value=?")
.bind(0, 1)
.bind(1, 10)
.bind(2, 20)
@@ -554,7 +553,7 @@ void insertOnDuplicate() {
.collectList()
.doOnNext(it -> assertThat(it).isEqualTo(Collections.singletonList(20)))
.thenMany(connection.createStatement("INSERT INTO test VALUES(?,?) " +
- "ON DUPLICATE KEY UPDATE value=?")
+ "ON DUPLICATE KEY UPDATE value=?")
.bind(0, 1)
.bind(1, 10)
.bind(2, 20)
@@ -625,22 +624,6 @@ private static Flux extractFirstInteger(Result result) {
return Flux.from(result.map((row, metadata) -> row.get(0, Integer.class)));
}
- private static Flux> extractOk(Result result, Class type) {
- return Flux.from(result.flatMap(segment -> {
- try {
- if (segment instanceof Result.UpdateCount && segment instanceof Result.RowSegment) {
- long affected = ((Result.UpdateCount) segment).value();
- T t = Objects.requireNonNull(((Result.RowSegment) segment).row().get(0, type));
- return Mono.just(Tuples.of(affected, t));
- } else {
- return Mono.empty();
- }
- } finally {
- ReferenceCountUtil.release(segment);
- }
- }));
- }
-
@SuppressWarnings("unchecked")
private static Flux> extractOptionalField(Result result, Type type) {
if (type instanceof Class>) {
@@ -652,9 +635,9 @@ private static Flux> extractOptionalField(Result result, Type ty
private static Mono testTimeDuration(Connection connection, Duration origin, LocalTime time) {
return Mono.from(connection.createStatement("INSERT INTO test VALUES(DEFAULT,?)")
- .bind(0, origin)
- .returnGeneratedValues("id")
- .execute())
+ .bind(0, origin)
+ .returnGeneratedValues("id")
+ .execute())
.flatMapMany(QueryIntegrationTestSupport::extractFirstInteger)
.concatMap(id -> connection.createStatement("SELECT value FROM test WHERE id=?")
.bind(0, id)
@@ -690,14 +673,13 @@ private static Mono testOne(MySqlConnection connection, Type type, boo
}
}
- return Mono.from(insert.returnGeneratedValues("id")
- .execute())
- .flatMap(result -> extractOk(result, Integer.class)
- .collectList()
- .map(ids -> {
- assertThat(ids).hasSize(1).first().extracting(Tuple2::getT1).isEqualTo(1L);
- return ids.get(0).getT2();
- }))
+ return Mono.from(insert.returnGeneratedValues("id").execute())
+ .flatMapMany(QueryIntegrationTestSupport::extractFirstInteger)
+ .collectList()
+ .map(ids -> {
+ assertThat(ids).hasSize(1).first().isNotNull();
+ return ids.get(0);
+ })
.flatMap(id -> Mono.from(connection.createStatement("SELECT value FROM test WHERE id=?")
.bind(0, id)
.execute()))
diff --git a/src/test/java/io/asyncer/r2dbc/mysql/StatementTestSupport.java b/src/test/java/io/asyncer/r2dbc/mysql/StatementTestSupport.java
index 65300af67..28f159c37 100644
--- a/src/test/java/io/asyncer/r2dbc/mysql/StatementTestSupport.java
+++ b/src/test/java/io/asyncer/r2dbc/mysql/StatementTestSupport.java
@@ -17,9 +17,12 @@
package io.asyncer.r2dbc.mysql;
import org.junit.jupiter.api.Test;
+import org.junit.jupiter.params.provider.ValueSource;
import java.util.NoSuchElementException;
+import static org.assertj.core.api.Assertions.assertThat;
+import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.junit.jupiter.api.Assertions.assertTrue;
@@ -33,7 +36,7 @@ interface StatementTestSupport {
String SIMPLE = "SELECT * FROM test WHERE id = 1 AND name = 'Mirrors'";
- T makeInstance(String parametrizedSql, String simpleSql);
+ T makeInstance(boolean isMariaDB, String parametrizedSql, String simpleSql);
boolean supportsBinding();
@@ -45,7 +48,7 @@ default int getFetchSize(T statement) throws IllegalAccessException {
default void bind() {
assertTrue(supportsBinding(), "Must skip test case #bind() for simple statements");
- T statement = makeInstance(PARAMETRIZED, SIMPLE);
+ T statement = makeInstance(false, PARAMETRIZED, SIMPLE);
statement.bind(0, 1);
statement.bind("id", 1);
statement.bind(1, 1);
@@ -54,7 +57,7 @@ default void bind() {
@SuppressWarnings("ConstantConditions")
@Test
default void badBind() {
- T statement = makeInstance(PARAMETRIZED, SIMPLE);
+ T statement = makeInstance(false, PARAMETRIZED, SIMPLE);
if (supportsBinding()) {
assertThrows(IllegalArgumentException.class, () -> statement.bind(0, null));
@@ -86,7 +89,7 @@ default void badBind() {
default void bindNull() {
assertTrue(supportsBinding(), "Must skip test case #bindNull() for simple statements");
- T statement = makeInstance(PARAMETRIZED, SIMPLE);
+ T statement = makeInstance(false, PARAMETRIZED, SIMPLE);
statement.bindNull(0, Integer.class);
statement.bindNull("id", Integer.class);
statement.bindNull(1, Integer.class);
@@ -95,7 +98,7 @@ default void bindNull() {
@SuppressWarnings("ConstantConditions")
@Test
default void badBindNull() {
- T statement = makeInstance(PARAMETRIZED, SIMPLE);
+ T statement = makeInstance(false, PARAMETRIZED, SIMPLE);
if (supportsBinding()) {
assertThrows(IllegalArgumentException.class, () -> statement.bindNull(0, null));
@@ -125,7 +128,7 @@ default void badBindNull() {
@Test
default void add() {
- T statement = makeInstance(PARAMETRIZED, SIMPLE);
+ T statement = makeInstance(false, PARAMETRIZED, SIMPLE);
if (!supportsBinding()) {
statement.add();
@@ -143,38 +146,91 @@ default void add() {
default void badAdd() {
assertTrue(supportsBinding(), "Must skip test case #badAdd() for simple statements");
- T statement = makeInstance(PARAMETRIZED, SIMPLE);
+ T statement = makeInstance(false, PARAMETRIZED, SIMPLE);
statement.bind(0, 1);
assertThrows(IllegalStateException.class, statement::add);
}
@Test
- default void returnGeneratedValues() {
- T statement = makeInstance(PARAMETRIZED, SIMPLE);
-
- statement.returnGeneratedValues();
- assertEquals(statement.generatedKeyName, "LAST_INSERT_ID");
- statement.returnGeneratedValues("generated");
- assertEquals(statement.generatedKeyName, "generated");
- statement.returnGeneratedValues("generate`d");
- assertEquals(statement.generatedKeyName, "generate`d");
+ default void mySqlReturnGeneratedValues() {
+ T s = makeInstance(false, PARAMETRIZED, SIMPLE);
+
+ s.returnGeneratedValues();
+
+ assertThat(s.syntheticKeyName()).isEqualTo("LAST_INSERT_ID");
+ assertThat(s.returningIdentifiers()).isEqualTo("");
+
+ s.returnGeneratedValues("generated");
+
+ assertThat(s.syntheticKeyName()).isEqualTo("generated");
+ assertThat(s.returningIdentifiers()).isEqualTo("");
+
+ s.returnGeneratedValues("generate`d");
+
+ assertThat(s.syntheticKeyName()).isEqualTo("generate`d");
+ assertThat(s.returningIdentifiers()).isEqualTo("");
+ }
+
+ @Test
+ default void mariaDbReturnGeneratedValues() {
+ T s = makeInstance(true, PARAMETRIZED, SIMPLE);
+
+ s.returnGeneratedValues();
+
+ assertThat(s.syntheticKeyName()).isNull();
+ assertThat(s.returningIdentifiers()).isEqualTo("*");
+
+ s.returnGeneratedValues("generated");
+
+ assertThat(s.syntheticKeyName()).isNull();
+ assertThat(s.returningIdentifiers()).isEqualTo("`generated`");
+
+ s.returnGeneratedValues("generate`d");
+
+ assertThat(s.syntheticKeyName()).isNull();
+ assertThat(s.returningIdentifiers()).isEqualTo("`generate``d`");
+
+ s.returnGeneratedValues("id", "name");
+
+ assertThat(s.syntheticKeyName()).isNull();
+ assertThat(s.returningIdentifiers()).isEqualTo("`id`,`name`");
+
+ s.returnGeneratedValues("id", "name", "desc", "created_at");
+
+ assertThat(s.syntheticKeyName()).isNull();
+ assertThat(s.returningIdentifiers()).isEqualTo("`id`,`name`,`desc`,`created_at`");
}
@SuppressWarnings("ConstantConditions")
@Test
default void badReturnGeneratedValues() {
- T statement = makeInstance(PARAMETRIZED, SIMPLE);
+ T s = makeInstance(false, PARAMETRIZED, SIMPLE);
+
+ assertThatIllegalArgumentException().isThrownBy(() -> s.returnGeneratedValues((String) null));
+ assertThatIllegalArgumentException().isThrownBy(() -> s.returnGeneratedValues((String[]) null));
+ assertThatIllegalArgumentException().isThrownBy(() -> s.returnGeneratedValues(""));
+ assertThatIllegalArgumentException().isThrownBy(() -> s.returnGeneratedValues("", ""));
+ assertThatIllegalArgumentException().isThrownBy(() -> s.returnGeneratedValues("id", "name"));
+ }
- assertThrows(IllegalArgumentException.class, () -> statement.returnGeneratedValues((String) null));
- assertThrows(IllegalArgumentException.class, () -> statement.returnGeneratedValues((String[]) null));
- assertThrows(IllegalArgumentException.class, () -> statement.returnGeneratedValues(""));
- assertThrows(IllegalArgumentException.class, () ->
- statement.returnGeneratedValues("generated", "names"));
+ @SuppressWarnings("ConstantConditions")
+ @Test
+ default void mariaDbBadReturnGeneratedValues() {
+ T s = makeInstance(true, PARAMETRIZED, SIMPLE);
+
+ assertThatIllegalArgumentException().isThrownBy(() -> s.returnGeneratedValues((String) null));
+ assertThatIllegalArgumentException().isThrownBy(() -> s.returnGeneratedValues((String[]) null));
+ assertThatIllegalArgumentException().isThrownBy(() -> s.returnGeneratedValues(""));
+ assertThatIllegalArgumentException().isThrownBy(() -> s.returnGeneratedValues("", ""));
+ assertThatIllegalArgumentException().isThrownBy(() -> s.returnGeneratedValues("id", ""));
+ assertThatIllegalArgumentException().isThrownBy(() -> s.returnGeneratedValues("id", null));
+ assertThatIllegalArgumentException().isThrownBy(() -> s.returnGeneratedValues("id", "", "name"));
+ assertThatIllegalArgumentException().isThrownBy(() -> s.returnGeneratedValues("id", null, "name"));
}
@Test
default void fetchSize() throws IllegalAccessException {
- T statement = makeInstance(PARAMETRIZED, SIMPLE);
+ T statement = makeInstance(false, PARAMETRIZED, SIMPLE);
assertEquals(0, getFetchSize(statement), "Must skip test case #fetchSize() for text-based queries");
for (int i = 1; i <= 10; ++i) {
@@ -192,7 +248,7 @@ default void fetchSize() throws IllegalAccessException {
@Test
default void badFetchSize() {
- T statement = makeInstance(PARAMETRIZED, SIMPLE);
+ T statement = makeInstance(false, PARAMETRIZED, SIMPLE);
assertThrows(IllegalArgumentException.class, () -> statement.fetchSize(-1));
assertThrows(IllegalArgumentException.class, () -> statement.fetchSize(-10));
diff --git a/src/test/java/io/asyncer/r2dbc/mysql/TextParametrizedStatementTest.java b/src/test/java/io/asyncer/r2dbc/mysql/TextParametrizedStatementTest.java
index 5ad9ab2c3..38646ec3f 100644
--- a/src/test/java/io/asyncer/r2dbc/mysql/TextParametrizedStatementTest.java
+++ b/src/test/java/io/asyncer/r2dbc/mysql/TextParametrizedStatementTest.java
@@ -29,8 +29,6 @@ class TextParametrizedStatementTest implements StatementTestSupport implements CodecTest
@Override
public CodecContext context() {
- return ConnectionContextTest.mock(ENCODE_SERVER_ZONE);
+ return ConnectionContextTest.mock(false, ENCODE_SERVER_ZONE);
}
protected final String toText(Temporal dateTime) {
diff --git a/src/test/java/io/asyncer/r2dbc/mysql/codec/TimeCodecTestSupport.java b/src/test/java/io/asyncer/r2dbc/mysql/codec/TimeCodecTestSupport.java
index ff0a572b3..88ec88244 100644
--- a/src/test/java/io/asyncer/r2dbc/mysql/codec/TimeCodecTestSupport.java
+++ b/src/test/java/io/asyncer/r2dbc/mysql/codec/TimeCodecTestSupport.java
@@ -55,7 +55,7 @@ abstract class TimeCodecTestSupport implements CodecTestSupp
@Override
public CodecContext context() {
- return ConnectionContextTest.mock(ENCODE_SERVER_ZONE);
+ return ConnectionContextTest.mock(false, ENCODE_SERVER_ZONE);
}
protected final String toText(Temporal time) {