diff --git a/src/main/java/io/asyncer/r2dbc/mysql/Binding.java b/src/main/java/io/asyncer/r2dbc/mysql/Binding.java index 20fafe5ea..8dd64d7ec 100644 --- a/src/main/java/io/asyncer/r2dbc/mysql/Binding.java +++ b/src/main/java/io/asyncer/r2dbc/mysql/Binding.java @@ -70,8 +70,8 @@ PreparedExecuteMessage toExecuteMessage(int statementId, boolean immediate) { return new PreparedExecuteMessage(statementId, immediate, drainValues()); } - PreparedTextQueryMessage toTextMessage(Query query) { - return new PreparedTextQueryMessage(query, drainValues()); + PreparedTextQueryMessage toTextMessage(Query query, String returning) { + return new PreparedTextQueryMessage(query, returning, drainValues()); } /** diff --git a/src/main/java/io/asyncer/r2dbc/mysql/MySqlResult.java b/src/main/java/io/asyncer/r2dbc/mysql/MySqlResult.java index 2b09277f5..943987d1e 100644 --- a/src/main/java/io/asyncer/r2dbc/mysql/MySqlResult.java +++ b/src/main/java/io/asyncer/r2dbc/mysql/MySqlResult.java @@ -51,8 +51,7 @@ * An implementation of {@link Result} representing the results of a query against the MySQL database. *

* A {@link Segment} provided by this implementation may be both {@link UpdateCount} and {@link RowSegment}, - * see also {@link MySqlOkSegment}. It's based on a {@link OkMessage}, when the {@code generatedKeyName} is - * not {@code null}. + * see also {@link MySqlOkSegment}. */ public final class MySqlResult implements Result { @@ -155,14 +154,14 @@ public Flux flatMap(Function> f } static MySqlResult toResult(boolean binary, Codecs codecs, ConnectionContext context, - @Nullable String generatedKeyName, Flux messages) { + @Nullable String syntheticKeyName, Flux messages) { requireNonNull(codecs, "codecs must not be null"); requireNonNull(context, "context must not be null"); requireNonNull(messages, "messages must not be null"); return new MySqlResult(OperatorUtils.discardOnCancel(messages) .doOnDiscard(ReferenceCounted.class, ReferenceCounted::release) - .handle(new MySqlSegments(binary, codecs, context, generatedKeyName))); + .handle(new MySqlSegments(binary, codecs, context, syntheticKeyName))); } private static final class MySqlMessage implements Message { @@ -268,16 +267,16 @@ private static final class MySqlSegments implements BiConsumer sink) { this.rowMetadata = MySqlRowMetadata.create(metadataMessages); } else if (message instanceof OkMessage) { - Segment segment = generatedKeyName == null ? new MySqlUpdateCount((OkMessage) message) : - new MySqlOkSegment((OkMessage) message, codecs, generatedKeyName); + Segment segment = syntheticKeyName == null ? new MySqlUpdateCount((OkMessage) message) : + new MySqlOkSegment((OkMessage) message, codecs, syntheticKeyName); sink.next(segment); } else if (message instanceof ErrorMessage) { diff --git a/src/main/java/io/asyncer/r2dbc/mysql/MySqlStatementSupport.java b/src/main/java/io/asyncer/r2dbc/mysql/MySqlStatementSupport.java index 890914bde..0ad86b4db 100644 --- a/src/main/java/io/asyncer/r2dbc/mysql/MySqlStatementSupport.java +++ b/src/main/java/io/asyncer/r2dbc/mysql/MySqlStatementSupport.java @@ -16,8 +16,12 @@ package io.asyncer.r2dbc.mysql; +import io.asyncer.r2dbc.mysql.internal.util.InternalArrays; +import io.asyncer.r2dbc.mysql.internal.util.StringUtils; import org.jetbrains.annotations.Nullable; +import java.util.StringJoiner; + import static io.asyncer.r2dbc.mysql.internal.util.AssertUtils.require; import static io.asyncer.r2dbc.mysql.internal.util.AssertUtils.requireNonEmpty; import static io.asyncer.r2dbc.mysql.internal.util.AssertUtils.requireNonNull; @@ -27,26 +31,42 @@ */ abstract class MySqlStatementSupport implements MySqlStatement { + private static final ServerVersion MARIA_10_5_1 = ServerVersion.create(10, 5, 1, true); + private static final String LAST_INSERT_ID = "LAST_INSERT_ID"; + protected final ConnectionContext context; + @Nullable - String generatedKeyName = null; + protected String[] generatedColumns = null; + + MySqlStatementSupport(ConnectionContext context) { + this.context = requireNonNull(context, "context must not be null"); + } @Override public final MySqlStatement returnGeneratedValues(String... columns) { requireNonNull(columns, "columns must not be null"); - switch (columns.length) { - case 0: - this.generatedKeyName = LAST_INSERT_ID; - return this; - case 1: - requireNonEmpty(columns[0], "id name must not be empty"); - this.generatedKeyName = columns[0]; - return this; + int len = columns.length; + + if (len == 0) { + this.generatedColumns = InternalArrays.EMPTY_STRINGS; + } else if (len == 1 || supportReturning()) { + String[] result = new String[len]; + + for (int i = 0; i < len; ++i) { + requireNonEmpty(columns[i], "returning column must not be empty"); + result[i] = columns[i]; + } + + this.generatedColumns = result; + } else { + String db = context.isMariaDb() ? "MariaDB 10.5.0 or below" : "MySQL"; + throw new IllegalArgumentException(db + " supports only LAST_INSERT_ID instead of RETURNING"); } - throw new IllegalArgumentException("MySQL only supports single generated value"); + return this; } @Override @@ -54,4 +74,44 @@ public MySqlStatement fetchSize(int rows) { require(rows >= 0, "Fetch size must be greater or equal to zero"); return this; } + + @Nullable + final String syntheticKeyName() { + String[] columns = this.generatedColumns; + + // MariaDB should use `RETURNING` clause instead. + if (columns == null || supportReturning()) { + return null; + } + + if (columns.length == 0) { + return LAST_INSERT_ID; + } + + return columns[0]; + } + + final String returningIdentifiers() { + String[] columns = this.generatedColumns; + + if (columns == null || !supportReturning()) { + return ""; + } + + if (columns.length == 0) { + return "*"; + } + + StringJoiner joiner = new StringJoiner(","); + + for (String column : columns) { + joiner.add(StringUtils.quoteIdentifier(column)); + } + + return joiner.toString(); + } + + private boolean supportReturning() { + return context.isMariaDb() && context.getServerVersion().isGreaterThanOrEqualTo(MARIA_10_5_1); + } } diff --git a/src/main/java/io/asyncer/r2dbc/mysql/ParametrizedStatementSupport.java b/src/main/java/io/asyncer/r2dbc/mysql/ParametrizedStatementSupport.java index 1fb38ab2d..ffe203077 100644 --- a/src/main/java/io/asyncer/r2dbc/mysql/ParametrizedStatementSupport.java +++ b/src/main/java/io/asyncer/r2dbc/mysql/ParametrizedStatementSupport.java @@ -45,19 +45,18 @@ abstract class ParametrizedStatementSupport extends MySqlStatementSupport { protected final Query query; - protected final ConnectionContext context; - private final Bindings bindings; private final AtomicBoolean executed = new AtomicBoolean(); ParametrizedStatementSupport(Client client, Codecs codecs, Query query, ConnectionContext context) { + super(context); + requireNonNull(query, "query must not be null"); require(query.getParameters() > 0, "parameters must be a positive integer"); this.client = requireNonNull(client, "client must not be null"); this.codecs = requireNonNull(codecs, "codecs must not be null"); - this.context = requireNonNull(context, "context must not be null"); this.query = query; this.bindings = new Bindings(query.getParameters()); } diff --git a/src/main/java/io/asyncer/r2dbc/mysql/PrepareParametrizedStatement.java b/src/main/java/io/asyncer/r2dbc/mysql/PrepareParametrizedStatement.java index 775c535d1..3a946f3ea 100644 --- a/src/main/java/io/asyncer/r2dbc/mysql/PrepareParametrizedStatement.java +++ b/src/main/java/io/asyncer/r2dbc/mysql/PrepareParametrizedStatement.java @@ -19,6 +19,7 @@ import io.asyncer.r2dbc.mysql.cache.PrepareCache; import io.asyncer.r2dbc.mysql.client.Client; import io.asyncer.r2dbc.mysql.codec.Codecs; +import io.asyncer.r2dbc.mysql.internal.util.StringUtils; import reactor.core.publisher.Flux; import java.util.List; @@ -42,8 +43,11 @@ final class PrepareParametrizedStatement extends ParametrizedStatementSupport { @Override public Flux execute(List bindings) { - return QueryFlow.execute(client, query.getFormattedSql(), bindings, fetchSize, prepareCache) - .map(messages -> MySqlResult.toResult(true, codecs, context, generatedKeyName, messages)); + return Flux.defer(() -> QueryFlow.execute(client, + StringUtils.extendReturning(query.getFormattedSql(), returningIdentifiers()), + bindings, fetchSize, prepareCache + )) + .map(messages -> MySqlResult.toResult(true, codecs, context, syntheticKeyName(), messages)); } @Override diff --git a/src/main/java/io/asyncer/r2dbc/mysql/PrepareSimpleStatement.java b/src/main/java/io/asyncer/r2dbc/mysql/PrepareSimpleStatement.java index 07800081b..2284b991e 100644 --- a/src/main/java/io/asyncer/r2dbc/mysql/PrepareSimpleStatement.java +++ b/src/main/java/io/asyncer/r2dbc/mysql/PrepareSimpleStatement.java @@ -19,6 +19,7 @@ import io.asyncer.r2dbc.mysql.cache.PrepareCache; import io.asyncer.r2dbc.mysql.client.Client; import io.asyncer.r2dbc.mysql.codec.Codecs; +import io.asyncer.r2dbc.mysql.internal.util.StringUtils; import reactor.core.publisher.Flux; import java.util.Collections; @@ -45,8 +46,9 @@ final class PrepareSimpleStatement extends SimpleStatementSupport { @Override public Flux execute() { - return QueryFlow.execute(client, sql, BINDINGS, fetchSize, prepareCache) - .map(messages -> MySqlResult.toResult(true, codecs, context, generatedKeyName, messages)); + return Flux.defer(() -> QueryFlow.execute(client, + StringUtils.extendReturning(sql, returningIdentifiers()), BINDINGS, fetchSize, prepareCache)) + .map(messages -> MySqlResult.toResult(true, codecs, context, syntheticKeyName(), messages)); } @Override diff --git a/src/main/java/io/asyncer/r2dbc/mysql/QueryFlow.java b/src/main/java/io/asyncer/r2dbc/mysql/QueryFlow.java index 203c8ed5c..7109fd544 100644 --- a/src/main/java/io/asyncer/r2dbc/mysql/QueryFlow.java +++ b/src/main/java/io/asyncer/r2dbc/mysql/QueryFlow.java @@ -104,7 +104,7 @@ final class QueryFlow { * by {@link CompleteMessage} after receive the last result for the last binding. * * @param client the {@link Client} to exchange messages with. - * @param sql the original statement for exception tracing. + * @param sql the statement for exception tracing. * @param bindings the data of bindings. * @param fetchSize the size of fetching, if it less than or equal to {@literal 0} means fetch all rows. * @param cache the cache of server-preparing result. @@ -129,18 +129,21 @@ static Flux> execute(Client client, String sql, List> execute(Client client, Query query, List bindings) { + static Flux> execute( + Client client, Query query, String returning, List bindings + ) { return Flux.defer(() -> { if (bindings.isEmpty()) { return Flux.empty(); } - return client.exchange(new TextQueryExchangeable(query, bindings.iterator())) + return client.exchange(new TextQueryExchangeable(query, returning, bindings.iterator())) .windowUntil(RESULT_DONE); }); } @@ -195,7 +198,7 @@ static Flux> execute(Client client, List statements) * @return the messages received in response to the login exchange. */ static Mono login(Client client, SslMode sslMode, String database, String user, - @Nullable CharSequence password, ConnectionContext context) { + @Nullable CharSequence password, ConnectionContext context) { return client.exchange(new LoginExchangeable(client, sslMode, database, user, password, context)) .onErrorResume(e -> client.forceClose().then(Mono.error(e))) .then(Mono.just(client)); @@ -263,13 +266,14 @@ static Mono beginTransaction(Client client, ConnectionState state, boolean * Commits or rollbacks current transaction. It will recover statuses of the {@link ConnectionState} in * the initial connection state. * - * @param client the {@link Client} to exchange messages with. - * @param state the connection state for checks and resets transaction statuses. - * @param commit if commit, otherwise rollback. - * @param batchSupported if connection supports batch query. + * @param client the {@link Client} to exchange messages with. + * @param state the connection state for checks and resets transaction statuses. + * @param commit if commit, otherwise rollback. + * @param batchSupported if connection supports batch query. * @return receives complete signal. */ - static Mono doneTransaction(Client client, ConnectionState state, boolean commit, boolean batchSupported) { + static Mono doneTransaction(Client client, ConnectionState state, boolean commit, + boolean batchSupported) { final CommitRollbackState commitState = new CommitRollbackState(state, commit); if (batchSupported) { @@ -279,7 +283,8 @@ static Mono doneTransaction(Client client, ConnectionState state, boolean return client.exchange(new TransactionMultiExchangeable(commitState)).then(); } - static Mono createSavepoint(Client client, ConnectionState state, String name, boolean batchSupported) { + static Mono createSavepoint(Client client, ConnectionState state, String name, + boolean batchSupported) { final CreateSavepointState savepointState = new CreateSavepointState(state, name); if (batchSupported) { return client.exchange(new TransactionBatchExchangeable(savepointState)).then(); @@ -357,10 +362,13 @@ final class TextQueryExchangeable extends BaseFluxExchangeable { private final Query query; + private final String returning; + private final Iterator bindings; - TextQueryExchangeable(Query query, Iterator bindings) { + TextQueryExchangeable(Query query, String returning, Iterator bindings) { this.query = query; + this.returning = returning; this.bindings = bindings; } @@ -384,9 +392,9 @@ public boolean isDisposed() { @Override protected void tryNextOrComplete(@Nullable SynchronousSink sink) { if (this.bindings.hasNext()) { - QueryLogger.log(this.query); + QueryLogger.log(this.query, this.returning); - PreparedTextQueryMessage message = this.bindings.next().toTextMessage(this.query); + PreparedTextQueryMessage message = this.bindings.next().toTextMessage(this.query, this.returning); Sinks.EmitResult result = this.requests.tryEmitNext(message); if (result == Sinks.EmitResult.OK) { @@ -404,7 +412,7 @@ protected void tryNextOrComplete(@Nullable SynchronousSink sink) @Override protected String offendingSql() { - return query.getFormattedSql(); + return StringUtils.extendReturning(query.getFormattedSql(), returning); } } @@ -1153,7 +1161,8 @@ protected boolean process(int task, SynchronousSink sink) { } return true; case ISOLATION_LEVEL: - final IsolationLevel isolationLevel = definition.getAttribute(TransactionDefinition.ISOLATION_LEVEL); + final IsolationLevel isolationLevel = + definition.getAttribute(TransactionDefinition.ISOLATION_LEVEL); if (isolationLevel != null) { state.setIsolationLevel(isolationLevel); } @@ -1254,7 +1263,6 @@ protected boolean process(int task, SynchronousSink sink) { final class TransactionBatchExchangeable extends FluxExchangeable { - private final AbstractTransactionState state; TransactionBatchExchangeable(AbstractTransactionState state) { diff --git a/src/main/java/io/asyncer/r2dbc/mysql/QueryLogger.java b/src/main/java/io/asyncer/r2dbc/mysql/QueryLogger.java index 98f9ad0ab..20c2ff1a9 100644 --- a/src/main/java/io/asyncer/r2dbc/mysql/QueryLogger.java +++ b/src/main/java/io/asyncer/r2dbc/mysql/QueryLogger.java @@ -17,6 +17,7 @@ package io.asyncer.r2dbc.mysql; +import io.asyncer.r2dbc.mysql.internal.util.StringUtils; import io.netty.util.internal.logging.InternalLogger; import io.netty.util.internal.logging.InternalLoggerFactory; @@ -31,9 +32,10 @@ static void log(String query) { logger.debug("Executing direct query: {}", query); } - static void log(Query query) { + static void log(Query query, String returning) { if (logger.isDebugEnabled()) { - logger.debug("Executing format query: {}", query.getFormattedSql()); + logger.debug("Executing format query: {}", + StringUtils.extendReturning(query.getFormattedSql(), returning)); } } diff --git a/src/main/java/io/asyncer/r2dbc/mysql/SimpleStatementSupport.java b/src/main/java/io/asyncer/r2dbc/mysql/SimpleStatementSupport.java index daebe5bff..56b34a926 100644 --- a/src/main/java/io/asyncer/r2dbc/mysql/SimpleStatementSupport.java +++ b/src/main/java/io/asyncer/r2dbc/mysql/SimpleStatementSupport.java @@ -30,14 +30,13 @@ abstract class SimpleStatementSupport extends MySqlStatementSupport { protected final Codecs codecs; - protected final ConnectionContext context; - protected final String sql; SimpleStatementSupport(Client client, Codecs codecs, ConnectionContext context, String sql) { + super(context); + this.client = requireNonNull(client, "client must not be null"); this.codecs = requireNonNull(codecs, "codecs must not be null"); - this.context = requireNonNull(context, "context must not be null"); this.sql = requireNonNull(sql, "sql must not be null"); } diff --git a/src/main/java/io/asyncer/r2dbc/mysql/TextParametrizedStatement.java b/src/main/java/io/asyncer/r2dbc/mysql/TextParametrizedStatement.java index c9e4bedfb..9d2c2c55b 100644 --- a/src/main/java/io/asyncer/r2dbc/mysql/TextParametrizedStatement.java +++ b/src/main/java/io/asyncer/r2dbc/mysql/TextParametrizedStatement.java @@ -33,7 +33,7 @@ final class TextParametrizedStatement extends ParametrizedStatementSupport { @Override protected Flux execute(List bindings) { - return QueryFlow.execute(client, query, bindings) - .map(messages -> MySqlResult.toResult(false, codecs, context, generatedKeyName, messages)); + return Flux.defer(() -> QueryFlow.execute(client, query, returningIdentifiers(), bindings)) + .map(messages -> MySqlResult.toResult(false, codecs, context, syntheticKeyName(), messages)); } } diff --git a/src/main/java/io/asyncer/r2dbc/mysql/TextSimpleStatement.java b/src/main/java/io/asyncer/r2dbc/mysql/TextSimpleStatement.java index 96c9e5ff1..e23873c07 100644 --- a/src/main/java/io/asyncer/r2dbc/mysql/TextSimpleStatement.java +++ b/src/main/java/io/asyncer/r2dbc/mysql/TextSimpleStatement.java @@ -18,6 +18,7 @@ import io.asyncer.r2dbc.mysql.client.Client; import io.asyncer.r2dbc.mysql.codec.Codecs; +import io.asyncer.r2dbc.mysql.internal.util.StringUtils; import reactor.core.publisher.Flux; /** @@ -31,7 +32,9 @@ final class TextSimpleStatement extends SimpleStatementSupport { @Override public Flux execute() { - return QueryFlow.execute(client, sql) - .map(messages -> MySqlResult.toResult(false, codecs, context, generatedKeyName, messages)); + return Flux.defer(() -> QueryFlow.execute( + client, + StringUtils.extendReturning(sql, returningIdentifiers())) + ).map(messages -> MySqlResult.toResult(false, codecs, context, syntheticKeyName(), messages)); } } diff --git a/src/main/java/io/asyncer/r2dbc/mysql/internal/util/StringUtils.java b/src/main/java/io/asyncer/r2dbc/mysql/internal/util/StringUtils.java index 2ccbb55a5..feb5ba3aa 100644 --- a/src/main/java/io/asyncer/r2dbc/mysql/internal/util/StringUtils.java +++ b/src/main/java/io/asyncer/r2dbc/mysql/internal/util/StringUtils.java @@ -25,6 +25,12 @@ public final class StringUtils { private static final char QUOTE = '`'; + /** + * Quotes identifier with backticks, it will escape backticks in the identifier. + * + * @param identifier the identifier + * @return quoted identifier + */ public static String quoteIdentifier(String identifier) { requireNonEmpty(identifier, "identifier must not be empty"); @@ -53,6 +59,17 @@ public static String quoteIdentifier(String identifier) { return builder.append(QUOTE).toString(); } + /** + * Extends a SQL statement with {@code RETURNING} clause. + * + * @param sql the original SQL statement. + * @param returning quoted column identifiers. + * @return the SQL statement with {@code RETURNING} clause. + */ + public static String extendReturning(String sql, String returning) { + return returning.isEmpty() ? sql : sql + " RETURNING " + returning; + } + private StringUtils() { } } diff --git a/src/main/java/io/asyncer/r2dbc/mysql/message/client/PreparedTextQueryMessage.java b/src/main/java/io/asyncer/r2dbc/mysql/message/client/PreparedTextQueryMessage.java index df5b29df5..d6ea6d783 100644 --- a/src/main/java/io/asyncer/r2dbc/mysql/message/client/PreparedTextQueryMessage.java +++ b/src/main/java/io/asyncer/r2dbc/mysql/message/client/PreparedTextQueryMessage.java @@ -38,17 +38,21 @@ public final class PreparedTextQueryMessage extends AtomicReference encode(ByteBufAllocator allocator, ConnectionContext contex return Flux.fromArray(values); }); - return ParamWriter.publish(query, parameters).map(it -> { + return ParamWriter.publish(query, parameters).handle((it, sink) -> { ByteBuf buf = allocator.buffer(); try { buf.writeByte(TextQueryMessage.QUERY_FLAG).writeCharSequence(it, charset); - return buf; + + if (!returning.isEmpty()) { + buf.writeCharSequence(" RETURNING ", charset); + buf.writeCharSequence(returning, charset); + } + + sink.next(buf); } catch (Throwable e) { // Maybe IndexOutOfBounds or OOM (too large sql) buf.release(); - throw e; + sink.error(e); } }); } diff --git a/src/test/java/io/asyncer/r2dbc/mysql/ConnectionContextTest.java b/src/test/java/io/asyncer/r2dbc/mysql/ConnectionContextTest.java index 2f2bd3186..2c8c907bd 100644 --- a/src/test/java/io/asyncer/r2dbc/mysql/ConnectionContextTest.java +++ b/src/test/java/io/asyncer/r2dbc/mysql/ConnectionContextTest.java @@ -68,13 +68,18 @@ void badSetServerZoneId() { } public static ConnectionContext mock() { - return mock(ZoneId.systemDefault()); + return mock(false, ZoneId.systemDefault()); } - public static ConnectionContext mock(ZoneId zoneId) { + public static ConnectionContext mock(boolean isMariaDB) { + return mock(isMariaDB, ZoneId.systemDefault()); + } + + public static ConnectionContext mock(boolean isMariaDB, ZoneId zoneId) { ConnectionContext context = new ConnectionContext(ZeroDateOption.USE_NULL, zoneId); - context.init(1, ServerVersion.parse("8.0.11.MOCKED"), Capability.of(-1)); + context.init(1, ServerVersion.parse(isMariaDB ? "11.2.22.MOCKED" : "8.0.11.MOCKED"), + Capability.of(~(isMariaDB ? 1 : 0))); context.setServerStatuses(ServerStatuses.AUTO_COMMIT); return context; diff --git a/src/test/java/io/asyncer/r2dbc/mysql/IntegrationTestSupport.java b/src/test/java/io/asyncer/r2dbc/mysql/IntegrationTestSupport.java index 11e754e99..c8d887f64 100644 --- a/src/test/java/io/asyncer/r2dbc/mysql/IntegrationTestSupport.java +++ b/src/test/java/io/asyncer/r2dbc/mysql/IntegrationTestSupport.java @@ -136,4 +136,16 @@ boolean envIsLessThanMySql57OrMariaDb102() { return ver.isLessThan(ServerVersion.create(5, 7, 0)); } + + static boolean envIsMariaDb10_5_1() { + String type = System.getProperty("test.db.type"); + + if (!"mariadb".equalsIgnoreCase(type)) { + return false; + } + + ServerVersion ver = ServerVersion.parse(System.getProperty("test.mysql.version")); + + return ver.isGreaterThanOrEqualTo(ServerVersion.create(10, 5, 1)); + } } diff --git a/src/test/java/io/asyncer/r2dbc/mysql/MariaDbIntegrationTestSupport.java b/src/test/java/io/asyncer/r2dbc/mysql/MariaDbIntegrationTestSupport.java new file mode 100644 index 000000000..610b4ecd2 --- /dev/null +++ b/src/test/java/io/asyncer/r2dbc/mysql/MariaDbIntegrationTestSupport.java @@ -0,0 +1,169 @@ +/* + * Copyright 2024 asyncer.io projects + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.asyncer.r2dbc.mysql; + +import io.r2dbc.spi.Readable; +import org.jetbrains.annotations.Nullable; +import org.junit.jupiter.api.Test; + +import java.time.ZonedDateTime; +import java.util.function.Predicate; + +import static java.util.Objects.requireNonNull; +import static org.assertj.core.api.Assertions.assertThat; + +/** + * Base class considers integration tests for MariaDB. + */ +abstract class MariaDbIntegrationTestSupport extends IntegrationTestSupport { + + MariaDbIntegrationTestSupport(@Nullable Predicate preferPrepared) { + super(configuration("r2dbc", false, false, null, preferPrepared)); + } + + @Test + void allReturning() { + complete(conn -> conn.createStatement("CREATE TEMPORARY TABLE test (" + + "id INT NOT NULL AUTO_INCREMENT PRIMARY KEY," + + "value INT NOT NULL," + + "created_at DATETIME(3) NOT NULL DEFAULT CURRENT_TIMESTAMP(3))") + .execute() + .flatMap(IntegrationTestSupport::extractRowsUpdated) + .thenMany(conn.createStatement("INSERT INTO test(value) VALUES (?),(?),(?),(?),(?)") + .bind(0, 2) + .bind(1, 4) + .bind(2, 6) + .bind(3, 8) + .bind(4, 10) + .returnGeneratedValues() + .execute()) + .flatMap(result -> result.map(DataEntity::read)) + .collectList() + .doOnNext(list -> assertThat(list).hasSize(5) + .map(DataEntity::getValue) + .containsExactly(2, 4, 6, 8, 10)) + .doOnNext(list -> assertThat(list.stream().map(DataEntity::getId).distinct()).hasSize(5)) + .doOnNext(list -> assertThat(list.stream().map(DataEntity::getCreatedAt)) + .noneMatch(it -> it.isBefore(ZonedDateTime.now().minusMinutes(1)))) + .thenMany(conn.createStatement("REPLACE test(id, value) VALUES (1,?),(2,?),(3,?),(4,?),(5,?)") + .bind(0, 3) + .bind(1, 5) + .bind(2, 7) + .bind(3, 9) + .bind(4, 11) + .returnGeneratedValues() + .execute()) + .flatMap(result -> result.map(DataEntity::read)) + .collectList() + .doOnNext(list -> assertThat(list).hasSize(5) + .map(DataEntity::getValue) + .containsExactly(3, 5, 7, 9, 11)) + .doOnNext(list -> assertThat(list.stream().map(DataEntity::getCreatedAt)) + .noneMatch(it -> it.isBefore(ZonedDateTime.now().minusMinutes(1))))); + } + + @Test + void partialReturning() { + complete(conn -> conn.createStatement("CREATE TEMPORARY TABLE test (" + + "id INT NOT NULL AUTO_INCREMENT PRIMARY KEY," + + "value INT NOT NULL," + + "created_at DATETIME(3) NOT NULL DEFAULT CURRENT_TIMESTAMP(3))") + .execute() + .flatMap(IntegrationTestSupport::extractRowsUpdated) + .thenMany(conn.createStatement("INSERT INTO test(value) VALUES (?),(?),(?),(?),(?)") + .bind(0, 2) + .bind(1, 4) + .bind(2, 6) + .bind(3, 8) + .bind(4, 10) + .returnGeneratedValues("id", "created_at") + .execute()) + .flatMap(result -> result.map(DataEntity::withoutValue)) + .collectList() + .doOnNext(list -> assertThat(list).hasSize(5) + .map(DataEntity::getValue) + .containsOnly(0)) + .doOnNext(list -> assertThat(list.stream().map(DataEntity::getId).distinct()).hasSize(5)) + .doOnNext(list -> assertThat(list.stream().map(DataEntity::getCreatedAt)) + .noneMatch(it -> it.isBefore(ZonedDateTime.now().minusMinutes(1)))) + .thenMany(conn.createStatement("REPLACE test(id, value) VALUES (1,?),(2,?),(3,?),(4,?),(5,?)") + .bind(0, 3) + .bind(1, 5) + .bind(2, 7) + .bind(3, 9) + .bind(4, 11) + .returnGeneratedValues("id", "created_at") + .execute()) + .flatMap(result -> result.map(DataEntity::withoutValue)) + .collectList() + .doOnNext(list -> assertThat(list).hasSize(5) + .map(DataEntity::getValue) + .containsOnly(0)) + .doOnNext(list -> assertThat(list.stream().map(DataEntity::getCreatedAt)) + .noneMatch(it -> it.isBefore(ZonedDateTime.now().minusMinutes(1)))) + ); + } + + private static final class DataEntity { + + private final int id; + + private final int value; + + private final ZonedDateTime createdAt; + + private DataEntity(int id, int value, ZonedDateTime createdAt) { + this.id = id; + this.value = value; + this.createdAt = createdAt; + } + + int getId() { + return id; + } + + int getValue() { + return value; + } + + ZonedDateTime getCreatedAt() { + return createdAt; + } + + static DataEntity read(Readable readable) { + Integer id = readable.get("id", Integer.TYPE); + Integer value = readable.get("value", Integer.class); + ZonedDateTime createdAt = readable.get("created_at", ZonedDateTime.class); + + requireNonNull(id, "id must not be null"); + requireNonNull(value, "value must not be null"); + requireNonNull(createdAt, "createdAt must not be null"); + + return new DataEntity(id, value, createdAt); + } + + static DataEntity withoutValue(Readable readable) { + Integer id = readable.get("id", Integer.TYPE); + ZonedDateTime createdAt = readable.get("created_at", ZonedDateTime.class); + + requireNonNull(id, "id must not be null"); + requireNonNull(createdAt, "createdAt must not be null"); + + return new DataEntity(id, 0, createdAt); + } + } +} diff --git a/src/test/java/io/asyncer/r2dbc/mysql/MariaDbPrepareIntegrationTest.java b/src/test/java/io/asyncer/r2dbc/mysql/MariaDbPrepareIntegrationTest.java new file mode 100644 index 000000000..b7ac81a8b --- /dev/null +++ b/src/test/java/io/asyncer/r2dbc/mysql/MariaDbPrepareIntegrationTest.java @@ -0,0 +1,30 @@ +/* + * Copyright 2024 asyncer.io projects + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.asyncer.r2dbc.mysql; + +import org.junit.jupiter.api.condition.EnabledIf; + +/** + * Integration tests for MariaDB with server-preparing statements. + */ +@EnabledIf("envIsMariaDb10_5_1") +class MariaDbPrepareIntegrationTest extends MariaDbIntegrationTestSupport { + + MariaDbPrepareIntegrationTest() { + super(sql -> true); + } +} diff --git a/src/test/java/io/asyncer/r2dbc/mysql/MariaDbTextIntegrationTest.java b/src/test/java/io/asyncer/r2dbc/mysql/MariaDbTextIntegrationTest.java new file mode 100644 index 000000000..0ab886c5f --- /dev/null +++ b/src/test/java/io/asyncer/r2dbc/mysql/MariaDbTextIntegrationTest.java @@ -0,0 +1,30 @@ +/* + * Copyright 2024 asyncer.io projects + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.asyncer.r2dbc.mysql; + +import org.junit.jupiter.api.condition.EnabledIf; + +/** + * Integration tests for MariaDB with client-preparing statements. + */ +@EnabledIf("envIsMariaDb10_5_1") +class MariaDbTextIntegrationTest extends MariaDbIntegrationTestSupport { + + MariaDbTextIntegrationTest() { + super(null); + } +} diff --git a/src/test/java/io/asyncer/r2dbc/mysql/PrepareParametrizedStatementTest.java b/src/test/java/io/asyncer/r2dbc/mysql/PrepareParametrizedStatementTest.java index e638fa38b..74bb3ee92 100644 --- a/src/test/java/io/asyncer/r2dbc/mysql/PrepareParametrizedStatementTest.java +++ b/src/test/java/io/asyncer/r2dbc/mysql/PrepareParametrizedStatementTest.java @@ -32,8 +32,6 @@ class PrepareParametrizedStatementTest implements StatementTestSupport>() {}.getType(); - Type enumSet = new TypeReference>() {}.getType(); + Type stringSet = new TypeReference>() { }.getType(); + Type enumSet = new TypeReference>() { }.getType(); testType(String.class, true, "SET('ONE','TWO','THREE')", null, "ONE,TWO,THREE", "ONE", "", "ONE,THREE"); @@ -268,7 +268,6 @@ void json() { testType(String.class, false, "JSON", null, "{\"data\": 1}", "[\"data\", 1]", "1", "null", "\"R2DBC\"", "2.56"); - } @Test @@ -407,7 +406,7 @@ void selectOne() { @Test void selectFromOtherDatabase() { complete(conn -> Flux.from(conn.createStatement("SELECT * FROM `information_schema`.`innodb_trx`") - .execute()) + .execute()) .flatMap(result -> result.map((row, metadata) -> row.get(0)))); } @@ -422,7 +421,7 @@ void multiQueries() { .flatMap(IntegrationTestSupport::extractRowsUpdated) .thenMany(Flux.range(0, 10)) .flatMap(it -> Flux.from(connection.createStatement("INSERT INTO test VALUES" + - "(DEFAULT,?,?,NOW(),NOW())") + "(DEFAULT,?,?,NOW(),NOW())") .bind(0, String.format("integration-test%d@mail.com", it)) .bind(1, "******") .execute())) @@ -443,7 +442,7 @@ void multiQueries() { @Test void consumePortion() { complete(connection -> Mono.from(connection.createStatement("CREATE TEMPORARY TABLE test" + - "(id INT PRIMARY KEY AUTO_INCREMENT,value INT)").execute()) + "(id INT PRIMARY KEY AUTO_INCREMENT,value INT)").execute()) .flatMap(IntegrationTestSupport::extractRowsUpdated) .then(Mono.from(connection.createStatement("INSERT INTO test(`value`) VALUES (1),(2),(3),(4),(5)") .execute())) @@ -453,8 +452,8 @@ void consumePortion() { .execute())) .flatMapMany(r -> r.map((row, metadata) -> row.get(0, Integer.TYPE))).take(3) .concatWith(Mono.from(connection.createStatement("SELECT value FROM test WHERE id > ?") - .bind(0, 0) - .execute()) + .bind(0, 0) + .execute()) .flatMapMany(r -> r.map((row, metadata) -> row.get(0, Integer.TYPE))).take(2)) .collectList() .doOnNext(it -> assertThat(it).isEqualTo(Arrays.asList(1, 2, 3, 1, 2)))); @@ -474,7 +473,7 @@ void ignoreResult() { .flatMap(IntegrationTestSupport::extractRowsUpdated) .thenMany(Flux.merge( Flux.from(connection.createStatement("SELECT value FROM test WHERE id > ?") - .bind(0, 0).execute()) + .bind(0, 0).execute()) .flatMap(r -> r.map((row, meta) -> row.get(0, Integer.class))) .doOnNext(values::add), connection.createStatement("BAD GRAMMAR").execute() @@ -496,8 +495,8 @@ void ignoreResult() { void foundRows() { int value = 10; complete(connection -> Flux.from(connection.createStatement("CREATE TEMPORARY TABLE test" + - "(id INT PRIMARY KEY AUTO_INCREMENT,value INT)") - .execute()) + "(id INT PRIMARY KEY AUTO_INCREMENT,value INT)") + .execute()) .flatMap(IntegrationTestSupport::extractRowsUpdated) .thenMany(connection.createStatement("INSERT INTO test VALUES(DEFAULT,?)") .bind(0, value).execute()) @@ -522,11 +521,11 @@ void foundRows() { @Test void insertOnDuplicate() { complete(connection -> Flux.from(connection.createStatement("CREATE TEMPORARY TABLE test" + - "(id INT PRIMARY KEY,value INT)") - .execute()) + "(id INT PRIMARY KEY,value INT)") + .execute()) .flatMap(IntegrationTestSupport::extractRowsUpdated) .thenMany(connection.createStatement("INSERT INTO test VALUES(?,?) " + - "ON DUPLICATE KEY UPDATE value=?") + "ON DUPLICATE KEY UPDATE value=?") .bind(0, 1) .bind(1, 10) .bind(2, 20) @@ -540,7 +539,7 @@ void insertOnDuplicate() { .collectList() .doOnNext(it -> assertThat(it).isEqualTo(Collections.singletonList(10))) .thenMany(connection.createStatement("INSERT INTO test VALUES(?,?) " + - "ON DUPLICATE KEY UPDATE value=?") + "ON DUPLICATE KEY UPDATE value=?") .bind(0, 1) .bind(1, 10) .bind(2, 20) @@ -554,7 +553,7 @@ void insertOnDuplicate() { .collectList() .doOnNext(it -> assertThat(it).isEqualTo(Collections.singletonList(20))) .thenMany(connection.createStatement("INSERT INTO test VALUES(?,?) " + - "ON DUPLICATE KEY UPDATE value=?") + "ON DUPLICATE KEY UPDATE value=?") .bind(0, 1) .bind(1, 10) .bind(2, 20) @@ -625,22 +624,6 @@ private static Flux extractFirstInteger(Result result) { return Flux.from(result.map((row, metadata) -> row.get(0, Integer.class))); } - private static Flux> extractOk(Result result, Class type) { - return Flux.from(result.flatMap(segment -> { - try { - if (segment instanceof Result.UpdateCount && segment instanceof Result.RowSegment) { - long affected = ((Result.UpdateCount) segment).value(); - T t = Objects.requireNonNull(((Result.RowSegment) segment).row().get(0, type)); - return Mono.just(Tuples.of(affected, t)); - } else { - return Mono.empty(); - } - } finally { - ReferenceCountUtil.release(segment); - } - })); - } - @SuppressWarnings("unchecked") private static Flux> extractOptionalField(Result result, Type type) { if (type instanceof Class) { @@ -652,9 +635,9 @@ private static Flux> extractOptionalField(Result result, Type ty private static Mono testTimeDuration(Connection connection, Duration origin, LocalTime time) { return Mono.from(connection.createStatement("INSERT INTO test VALUES(DEFAULT,?)") - .bind(0, origin) - .returnGeneratedValues("id") - .execute()) + .bind(0, origin) + .returnGeneratedValues("id") + .execute()) .flatMapMany(QueryIntegrationTestSupport::extractFirstInteger) .concatMap(id -> connection.createStatement("SELECT value FROM test WHERE id=?") .bind(0, id) @@ -690,14 +673,13 @@ private static Mono testOne(MySqlConnection connection, Type type, boo } } - return Mono.from(insert.returnGeneratedValues("id") - .execute()) - .flatMap(result -> extractOk(result, Integer.class) - .collectList() - .map(ids -> { - assertThat(ids).hasSize(1).first().extracting(Tuple2::getT1).isEqualTo(1L); - return ids.get(0).getT2(); - })) + return Mono.from(insert.returnGeneratedValues("id").execute()) + .flatMapMany(QueryIntegrationTestSupport::extractFirstInteger) + .collectList() + .map(ids -> { + assertThat(ids).hasSize(1).first().isNotNull(); + return ids.get(0); + }) .flatMap(id -> Mono.from(connection.createStatement("SELECT value FROM test WHERE id=?") .bind(0, id) .execute())) diff --git a/src/test/java/io/asyncer/r2dbc/mysql/StatementTestSupport.java b/src/test/java/io/asyncer/r2dbc/mysql/StatementTestSupport.java index 65300af67..28f159c37 100644 --- a/src/test/java/io/asyncer/r2dbc/mysql/StatementTestSupport.java +++ b/src/test/java/io/asyncer/r2dbc/mysql/StatementTestSupport.java @@ -17,9 +17,12 @@ package io.asyncer.r2dbc.mysql; import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.provider.ValueSource; import java.util.NoSuchElementException; +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException; import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertThrows; import static org.junit.jupiter.api.Assertions.assertTrue; @@ -33,7 +36,7 @@ interface StatementTestSupport { String SIMPLE = "SELECT * FROM test WHERE id = 1 AND name = 'Mirrors'"; - T makeInstance(String parametrizedSql, String simpleSql); + T makeInstance(boolean isMariaDB, String parametrizedSql, String simpleSql); boolean supportsBinding(); @@ -45,7 +48,7 @@ default int getFetchSize(T statement) throws IllegalAccessException { default void bind() { assertTrue(supportsBinding(), "Must skip test case #bind() for simple statements"); - T statement = makeInstance(PARAMETRIZED, SIMPLE); + T statement = makeInstance(false, PARAMETRIZED, SIMPLE); statement.bind(0, 1); statement.bind("id", 1); statement.bind(1, 1); @@ -54,7 +57,7 @@ default void bind() { @SuppressWarnings("ConstantConditions") @Test default void badBind() { - T statement = makeInstance(PARAMETRIZED, SIMPLE); + T statement = makeInstance(false, PARAMETRIZED, SIMPLE); if (supportsBinding()) { assertThrows(IllegalArgumentException.class, () -> statement.bind(0, null)); @@ -86,7 +89,7 @@ default void badBind() { default void bindNull() { assertTrue(supportsBinding(), "Must skip test case #bindNull() for simple statements"); - T statement = makeInstance(PARAMETRIZED, SIMPLE); + T statement = makeInstance(false, PARAMETRIZED, SIMPLE); statement.bindNull(0, Integer.class); statement.bindNull("id", Integer.class); statement.bindNull(1, Integer.class); @@ -95,7 +98,7 @@ default void bindNull() { @SuppressWarnings("ConstantConditions") @Test default void badBindNull() { - T statement = makeInstance(PARAMETRIZED, SIMPLE); + T statement = makeInstance(false, PARAMETRIZED, SIMPLE); if (supportsBinding()) { assertThrows(IllegalArgumentException.class, () -> statement.bindNull(0, null)); @@ -125,7 +128,7 @@ default void badBindNull() { @Test default void add() { - T statement = makeInstance(PARAMETRIZED, SIMPLE); + T statement = makeInstance(false, PARAMETRIZED, SIMPLE); if (!supportsBinding()) { statement.add(); @@ -143,38 +146,91 @@ default void add() { default void badAdd() { assertTrue(supportsBinding(), "Must skip test case #badAdd() for simple statements"); - T statement = makeInstance(PARAMETRIZED, SIMPLE); + T statement = makeInstance(false, PARAMETRIZED, SIMPLE); statement.bind(0, 1); assertThrows(IllegalStateException.class, statement::add); } @Test - default void returnGeneratedValues() { - T statement = makeInstance(PARAMETRIZED, SIMPLE); - - statement.returnGeneratedValues(); - assertEquals(statement.generatedKeyName, "LAST_INSERT_ID"); - statement.returnGeneratedValues("generated"); - assertEquals(statement.generatedKeyName, "generated"); - statement.returnGeneratedValues("generate`d"); - assertEquals(statement.generatedKeyName, "generate`d"); + default void mySqlReturnGeneratedValues() { + T s = makeInstance(false, PARAMETRIZED, SIMPLE); + + s.returnGeneratedValues(); + + assertThat(s.syntheticKeyName()).isEqualTo("LAST_INSERT_ID"); + assertThat(s.returningIdentifiers()).isEqualTo(""); + + s.returnGeneratedValues("generated"); + + assertThat(s.syntheticKeyName()).isEqualTo("generated"); + assertThat(s.returningIdentifiers()).isEqualTo(""); + + s.returnGeneratedValues("generate`d"); + + assertThat(s.syntheticKeyName()).isEqualTo("generate`d"); + assertThat(s.returningIdentifiers()).isEqualTo(""); + } + + @Test + default void mariaDbReturnGeneratedValues() { + T s = makeInstance(true, PARAMETRIZED, SIMPLE); + + s.returnGeneratedValues(); + + assertThat(s.syntheticKeyName()).isNull(); + assertThat(s.returningIdentifiers()).isEqualTo("*"); + + s.returnGeneratedValues("generated"); + + assertThat(s.syntheticKeyName()).isNull(); + assertThat(s.returningIdentifiers()).isEqualTo("`generated`"); + + s.returnGeneratedValues("generate`d"); + + assertThat(s.syntheticKeyName()).isNull(); + assertThat(s.returningIdentifiers()).isEqualTo("`generate``d`"); + + s.returnGeneratedValues("id", "name"); + + assertThat(s.syntheticKeyName()).isNull(); + assertThat(s.returningIdentifiers()).isEqualTo("`id`,`name`"); + + s.returnGeneratedValues("id", "name", "desc", "created_at"); + + assertThat(s.syntheticKeyName()).isNull(); + assertThat(s.returningIdentifiers()).isEqualTo("`id`,`name`,`desc`,`created_at`"); } @SuppressWarnings("ConstantConditions") @Test default void badReturnGeneratedValues() { - T statement = makeInstance(PARAMETRIZED, SIMPLE); + T s = makeInstance(false, PARAMETRIZED, SIMPLE); + + assertThatIllegalArgumentException().isThrownBy(() -> s.returnGeneratedValues((String) null)); + assertThatIllegalArgumentException().isThrownBy(() -> s.returnGeneratedValues((String[]) null)); + assertThatIllegalArgumentException().isThrownBy(() -> s.returnGeneratedValues("")); + assertThatIllegalArgumentException().isThrownBy(() -> s.returnGeneratedValues("", "")); + assertThatIllegalArgumentException().isThrownBy(() -> s.returnGeneratedValues("id", "name")); + } - assertThrows(IllegalArgumentException.class, () -> statement.returnGeneratedValues((String) null)); - assertThrows(IllegalArgumentException.class, () -> statement.returnGeneratedValues((String[]) null)); - assertThrows(IllegalArgumentException.class, () -> statement.returnGeneratedValues("")); - assertThrows(IllegalArgumentException.class, () -> - statement.returnGeneratedValues("generated", "names")); + @SuppressWarnings("ConstantConditions") + @Test + default void mariaDbBadReturnGeneratedValues() { + T s = makeInstance(true, PARAMETRIZED, SIMPLE); + + assertThatIllegalArgumentException().isThrownBy(() -> s.returnGeneratedValues((String) null)); + assertThatIllegalArgumentException().isThrownBy(() -> s.returnGeneratedValues((String[]) null)); + assertThatIllegalArgumentException().isThrownBy(() -> s.returnGeneratedValues("")); + assertThatIllegalArgumentException().isThrownBy(() -> s.returnGeneratedValues("", "")); + assertThatIllegalArgumentException().isThrownBy(() -> s.returnGeneratedValues("id", "")); + assertThatIllegalArgumentException().isThrownBy(() -> s.returnGeneratedValues("id", null)); + assertThatIllegalArgumentException().isThrownBy(() -> s.returnGeneratedValues("id", "", "name")); + assertThatIllegalArgumentException().isThrownBy(() -> s.returnGeneratedValues("id", null, "name")); } @Test default void fetchSize() throws IllegalAccessException { - T statement = makeInstance(PARAMETRIZED, SIMPLE); + T statement = makeInstance(false, PARAMETRIZED, SIMPLE); assertEquals(0, getFetchSize(statement), "Must skip test case #fetchSize() for text-based queries"); for (int i = 1; i <= 10; ++i) { @@ -192,7 +248,7 @@ default void fetchSize() throws IllegalAccessException { @Test default void badFetchSize() { - T statement = makeInstance(PARAMETRIZED, SIMPLE); + T statement = makeInstance(false, PARAMETRIZED, SIMPLE); assertThrows(IllegalArgumentException.class, () -> statement.fetchSize(-1)); assertThrows(IllegalArgumentException.class, () -> statement.fetchSize(-10)); diff --git a/src/test/java/io/asyncer/r2dbc/mysql/TextParametrizedStatementTest.java b/src/test/java/io/asyncer/r2dbc/mysql/TextParametrizedStatementTest.java index 5ad9ab2c3..38646ec3f 100644 --- a/src/test/java/io/asyncer/r2dbc/mysql/TextParametrizedStatementTest.java +++ b/src/test/java/io/asyncer/r2dbc/mysql/TextParametrizedStatementTest.java @@ -29,8 +29,6 @@ class TextParametrizedStatementTest implements StatementTestSupport implements CodecTest @Override public CodecContext context() { - return ConnectionContextTest.mock(ENCODE_SERVER_ZONE); + return ConnectionContextTest.mock(false, ENCODE_SERVER_ZONE); } protected final String toText(Temporal dateTime) { diff --git a/src/test/java/io/asyncer/r2dbc/mysql/codec/TimeCodecTestSupport.java b/src/test/java/io/asyncer/r2dbc/mysql/codec/TimeCodecTestSupport.java index ff0a572b3..88ec88244 100644 --- a/src/test/java/io/asyncer/r2dbc/mysql/codec/TimeCodecTestSupport.java +++ b/src/test/java/io/asyncer/r2dbc/mysql/codec/TimeCodecTestSupport.java @@ -55,7 +55,7 @@ abstract class TimeCodecTestSupport implements CodecTestSupp @Override public CodecContext context() { - return ConnectionContextTest.mock(ENCODE_SERVER_ZONE); + return ConnectionContextTest.mock(false, ENCODE_SERVER_ZONE); } protected final String toText(Temporal time) {