diff --git a/README.md b/README.md index 830432359..ab9fb6b5b 100644 --- a/README.md +++ b/README.md @@ -31,6 +31,7 @@ This driver provides the following features: - [x] Transactions with savepoint. - [x] Native ping command that can be verifying when argument is `ValidationDepth.REMOTE` - [x] Extensible, e.g. extend built-in `Codec`(s). +- [x] MariaDB `RETURNING` clause. ## Maintainer @@ -543,10 +544,12 @@ If you want to raise an issue, please follow the recommendations below: ## Before use - The MySQL data fields encoded by index-based natively, get fields by an index will have **better** performance than get by column name. -- Each `Result` should be used (call `getRowsUpdated` or `map`, even table definition), can **NOT** just ignore any `Result`, otherwise inbound stream is unable to align. (like `ResultSet.close` in jdbc, `Result` auto-close after used by once) +- Each `Result` should be used (call `getRowsUpdated` or `map`/`flatMap`, even table definition), can **NOT** just ignore any `Result`, otherwise inbound stream is unable to align. (like `ResultSet.close` in jdbc, `Result` auto-close after used by once) - The MySQL server does not **actively** return time zone when query `DATETIME` or `TIMESTAMP`, this driver does not attempt time zone conversion. That means should always use `LocalDateTime` for SQL type `DATETIME` or `TIMESTAMP`. Execute `SHOW VARIABLES LIKE '%time_zone%'` to get more information. - Should not turn-on the `trace` log level unless debugging. Otherwise, the security information may be exposed through `ByteBuf` dump. -- If `Statement` bound `returnGeneratedValues`, the `Result` of the `Statement` can be called both: `getRowsUpdated` to get affected rows, and `map` to get last inserted ID. +- If `Statement` bound `returnGeneratedValues`, the `Result` of the `Statement` can be called both of `getRowsUpdated` and `map`/`flatMap`. + - If server is not MariaDB: `returnGeneratedValues` can only be called with one or zero arguments, and the `Result` will contain the last inserted id. + - If server is MariaDB 10.5.1 and above: the statement will attempt to use `RETURNING` clause, zero arguments will make the statement like `... RETURNING *`. - The MySQL may be not support well for searching rows by a binary field, like `BIT` and `JSON` - `BIT`: cannot select 'BIT(64)' with value greater than 'Long.MAX_VALUE' (or equivalent in binary) - `JSON`: different MySQL may have different serialization formats, e.g. MariaDB and MySQL diff --git a/src/main/java/io/asyncer/r2dbc/mysql/MySqlResult.java b/src/main/java/io/asyncer/r2dbc/mysql/MySqlResult.java index 943987d1e..749086572 100644 --- a/src/main/java/io/asyncer/r2dbc/mysql/MySqlResult.java +++ b/src/main/java/io/asyncer/r2dbc/mysql/MySqlResult.java @@ -40,6 +40,7 @@ import reactor.core.publisher.Mono; import reactor.core.publisher.SynchronousSink; +import java.util.concurrent.atomic.AtomicLong; import java.util.function.BiConsumer; import java.util.function.BiFunction; import java.util.function.Function; @@ -227,34 +228,35 @@ protected void deallocate() { private static class MySqlUpdateCount implements UpdateCount { - protected final OkMessage message; + protected final long rows; - private MySqlUpdateCount(OkMessage message) { - this.message = message; - } + private MySqlUpdateCount(long rows) { this.rows = rows; } @Override public long value() { - return message.getAffectedRows(); + return rows; } } private static final class MySqlOkSegment extends MySqlUpdateCount implements RowSegment { + private final long lastInsertId; + private final Codecs codecs; private final String keyName; - private MySqlOkSegment(OkMessage message, Codecs codecs, String keyName) { - super(message); + private MySqlOkSegment(long rows, long lastInsertId, Codecs codecs, String keyName) { + super(rows); + this.lastInsertId = lastInsertId; this.codecs = codecs; this.keyName = keyName; } @Override public Row row() { - return new InsertSyntheticRow(codecs, keyName, message.getLastInsertId()); + return new InsertSyntheticRow(codecs, keyName, lastInsertId); } } @@ -269,6 +271,8 @@ private static final class MySqlSegments implements BiConsumer sink) { if (message instanceof RowMessage) { + // Updated rows can be identified either by OK or rows in case of RETURNING + rowCount.getAndIncrement(); + MySqlRowMetadata metadata = this.rowMetadata; if (metadata == null) { @@ -308,10 +315,17 @@ public void accept(ServerMessage message, SynchronousSink sink) { this.rowMetadata = MySqlRowMetadata.create(metadataMessages); } else if (message instanceof OkMessage) { - Segment segment = syntheticKeyName == null ? new MySqlUpdateCount((OkMessage) message) : - new MySqlOkSegment((OkMessage) message, codecs, syntheticKeyName); + OkMessage msg = (OkMessage) message; + + if (MySqlStatementSupport.supportReturning(context) && msg.isEndOfRows()) { + sink.next(new MySqlUpdateCount(rowCount.getAndSet(0))); + } else { + long rows = msg.getAffectedRows(); + Segment segment = syntheticKeyName == null ? new MySqlUpdateCount(rows) : + new MySqlOkSegment(rows, msg.getLastInsertId(), codecs, syntheticKeyName); - sink.next(segment); + sink.next(segment); + } } else if (message instanceof ErrorMessage) { sink.next(new MySqlMessage((ErrorMessage) message)); } else { diff --git a/src/main/java/io/asyncer/r2dbc/mysql/MySqlStatementSupport.java b/src/main/java/io/asyncer/r2dbc/mysql/MySqlStatementSupport.java index 0ad86b4db..16520bbe5 100644 --- a/src/main/java/io/asyncer/r2dbc/mysql/MySqlStatementSupport.java +++ b/src/main/java/io/asyncer/r2dbc/mysql/MySqlStatementSupport.java @@ -52,7 +52,7 @@ public final MySqlStatement returnGeneratedValues(String... columns) { if (len == 0) { this.generatedColumns = InternalArrays.EMPTY_STRINGS; - } else if (len == 1 || supportReturning()) { + } else if (len == 1 || supportReturning(context)) { String[] result = new String[len]; for (int i = 0; i < len; ++i) { @@ -63,7 +63,7 @@ public final MySqlStatement returnGeneratedValues(String... columns) { this.generatedColumns = result; } else { String db = context.isMariaDb() ? "MariaDB 10.5.0 or below" : "MySQL"; - throw new IllegalArgumentException(db + " supports only LAST_INSERT_ID instead of RETURNING"); + throw new IllegalArgumentException(db + " can have only one column"); } return this; @@ -80,7 +80,7 @@ final String syntheticKeyName() { String[] columns = this.generatedColumns; // MariaDB should use `RETURNING` clause instead. - if (columns == null || supportReturning()) { + if (columns == null || supportReturning(this.context)) { return null; } @@ -94,7 +94,7 @@ final String syntheticKeyName() { final String returningIdentifiers() { String[] columns = this.generatedColumns; - if (columns == null || !supportReturning()) { + if (columns == null || !supportReturning(context)) { return ""; } @@ -111,7 +111,7 @@ final String returningIdentifiers() { return joiner.toString(); } - private boolean supportReturning() { + static boolean supportReturning(ConnectionContext context) { return context.isMariaDb() && context.getServerVersion().isGreaterThanOrEqualTo(MARIA_10_5_1); } } diff --git a/src/main/java/io/asyncer/r2dbc/mysql/codec/Codecs.java b/src/main/java/io/asyncer/r2dbc/mysql/codec/Codecs.java index 7af324e13..aff4799b0 100644 --- a/src/main/java/io/asyncer/r2dbc/mysql/codec/Codecs.java +++ b/src/main/java/io/asyncer/r2dbc/mysql/codec/Codecs.java @@ -62,7 +62,7 @@ T decode(FieldValue value, MySqlColumnMetadata metadata, ParameterizedType t CodecContext context); /** - * Decode the last inserted ID from {@code OkMessage} as a specified {@link ParameterizedType type}. + * Decode the last inserted ID from {@code OkMessage} as a specified {@link Class type}. * * @param the generic result type. * @param value the last inserted ID. diff --git a/src/main/java/io/asyncer/r2dbc/mysql/message/server/OkMessage.java b/src/main/java/io/asyncer/r2dbc/mysql/message/server/OkMessage.java index 52e3bce2f..da937fe17 100644 --- a/src/main/java/io/asyncer/r2dbc/mysql/message/server/OkMessage.java +++ b/src/main/java/io/asyncer/r2dbc/mysql/message/server/OkMessage.java @@ -36,6 +36,8 @@ public final class OkMessage implements WarningMessage, ServerStatusMessage, Com private static final int MIN_SIZE = 7; + private final boolean isEndOfRows; + private final long affectedRows; /** @@ -49,8 +51,9 @@ public final class OkMessage implements WarningMessage, ServerStatusMessage, Com private final String information; - private OkMessage(long affectedRows, long lastInsertId, short serverStatuses, int warnings, - String information) { + private OkMessage(boolean isEndOfRows, long affectedRows, long lastInsertId, short serverStatuses, + int warnings, String information) { + this.isEndOfRows = isEndOfRows; this.affectedRows = affectedRows; this.lastInsertId = lastInsertId; this.serverStatuses = serverStatuses; @@ -58,6 +61,10 @@ private OkMessage(long affectedRows, long lastInsertId, short serverStatuses, in this.information = requireNonNull(information, "information must not be null"); } + public boolean isEndOfRows() { + return isEndOfRows; + } + public long getAffectedRows() { return affectedRows; } @@ -92,7 +99,8 @@ public boolean equals(Object o) { OkMessage okMessage = (OkMessage) o; - return affectedRows == okMessage.affectedRows && + return isEndOfRows == okMessage.isEndOfRows && + affectedRows == okMessage.affectedRows && lastInsertId == okMessage.lastInsertId && serverStatuses == okMessage.serverStatuses && warnings == okMessage.warnings && @@ -101,7 +109,8 @@ public boolean equals(Object o) { @Override public int hashCode() { - int result = (int) (affectedRows ^ (affectedRows >>> 32)); + int result = (isEndOfRows ? 1 : 0); + result = 31 * result + (int) (affectedRows ^ (affectedRows >>> 32)); result = 31 * result + (int) (lastInsertId ^ (lastInsertId >>> 32)); result = 31 * result + serverStatuses; result = 31 * result + warnings; @@ -111,21 +120,26 @@ public int hashCode() { @Override public String toString() { if (warnings == 0) { - return "OkMessage{affectedRows=" + Long.toUnsignedString(affectedRows) + ", lastInsertId=" + - Long.toUnsignedString(lastInsertId) + ", serverStatuses=" + Integer.toHexString(serverStatuses) + + return "OkMessage{isEndOfRows=" + isEndOfRows + + ", affectedRows=" + Long.toUnsignedString(affectedRows) + + ", lastInsertId=" + Long.toUnsignedString(lastInsertId) + + ", serverStatuses=" + Integer.toHexString(serverStatuses) + ", information='" + information + "'}"; } - return "OkMessage{affectedRows=" + Long.toUnsignedString(affectedRows) + ", lastInsertId=" + - Long.toUnsignedString(lastInsertId) + ", serverStatuses=" + Integer.toHexString(serverStatuses) + - ", warnings=" + warnings + ", information='" + information + "'}"; + return "OkMessage{isEndOfRows=" + isEndOfRows + + ", affectedRows=" + Long.toUnsignedString(affectedRows) + + ", lastInsertId=" + Long.toUnsignedString(lastInsertId) + + ", serverStatuses=" + Integer.toHexString(serverStatuses) + + ", warnings=" + warnings + + ", information='" + information + "'}"; } static boolean isValidSize(int bytes) { return bytes >= MIN_SIZE; } - static OkMessage decode(ByteBuf buf, ConnectionContext context) { + static OkMessage decode(boolean isEndOfRows, ByteBuf buf, ConnectionContext context) { buf.skipBytes(1); // OK message header, 0x00 or 0xFE Capability capability = context.getCapability(); @@ -149,8 +163,8 @@ static OkMessage decode(ByteBuf buf, ConnectionContext context) { int sizeAfterVarInt = VarIntUtils.checkNextVarInt(buf); if (sizeAfterVarInt < 0) { - return new OkMessage(affectedRows, lastInsertId, serverStatuses, warnings, - buf.toString(charset)); + return new OkMessage(isEndOfRows, affectedRows, lastInsertId, serverStatuses, + warnings, buf.toString(charset)); } int readerIndex = buf.readerIndex(); @@ -165,10 +179,11 @@ static OkMessage decode(ByteBuf buf, ConnectionContext context) { } // Ignore session track, it is not human-readable and useless for R2DBC client. - return new OkMessage(affectedRows, lastInsertId, serverStatuses, warnings, information); + return new OkMessage(isEndOfRows, affectedRows, lastInsertId, serverStatuses, warnings, + information); } // Maybe have no human-readable message - return new OkMessage(affectedRows, lastInsertId, serverStatuses, warnings, ""); + return new OkMessage(isEndOfRows, affectedRows, lastInsertId, serverStatuses, warnings, ""); } } diff --git a/src/main/java/io/asyncer/r2dbc/mysql/message/server/ServerMessageDecoder.java b/src/main/java/io/asyncer/r2dbc/mysql/message/server/ServerMessageDecoder.java index 917d3485f..1f7408e7b 100644 --- a/src/main/java/io/asyncer/r2dbc/mysql/message/server/ServerMessageDecoder.java +++ b/src/main/java/io/asyncer/r2dbc/mysql/message/server/ServerMessageDecoder.java @@ -196,7 +196,7 @@ private static ServerMessage decodeCommandMessage(ByteBuf buf, ConnectionContext return ErrorMessage.decode(buf); case OK: if (OkMessage.isValidSize(buf.readableBytes())) { - return OkMessage.decode(buf, context); + return OkMessage.decode(false, buf, context); } break; @@ -209,7 +209,7 @@ private static ServerMessage decodeCommandMessage(ByteBuf buf, ConnectionContext // so if readable bytes upper than 7, it means if it is column count, // column count is already upper than (1 << 24) - 1 = 16777215, it is impossible. // So it must be OK message, not be column count. - return OkMessage.decode(buf, context); + return OkMessage.decode(false, buf, context); } else if (EofMessage.isValidSize(byteSize)) { return EofMessage.decode(buf); } @@ -231,7 +231,7 @@ private static ServerMessage decodeLogin(int envelopeId, ByteBuf buf, Connection switch (header) { case OK: if (OkMessage.isValidSize(buf.readableBytes())) { - return OkMessage.decode(buf, context); + return OkMessage.decode(false, buf, context); } break; @@ -333,7 +333,7 @@ private static ServerMessage decodeRow(List buffers, ByteBuf firstBuf, ByteBuf combined = NettyBufferUtils.composite(buffers); try { - return OkMessage.decode(combined, context); + return OkMessage.decode(true, combined, context); } finally { combined.release(); } diff --git a/src/test/java/io/asyncer/r2dbc/mysql/MariaDbIntegrationTestSupport.java b/src/test/java/io/asyncer/r2dbc/mysql/MariaDbIntegrationTestSupport.java index 610b4ecd2..f2291fcb5 100644 --- a/src/test/java/io/asyncer/r2dbc/mysql/MariaDbIntegrationTestSupport.java +++ b/src/test/java/io/asyncer/r2dbc/mysql/MariaDbIntegrationTestSupport.java @@ -118,6 +118,21 @@ void partialReturning() { ); } + @Test + void returningGetRowUpdated() { + complete(conn -> conn.createStatement("CREATE TEMPORARY TABLE test(" + + "id INT NOT NULL AUTO_INCREMENT PRIMARY KEY,value INT NOT NULL)") + .execute() + .flatMap(IntegrationTestSupport::extractRowsUpdated) + .thenMany(conn.createStatement("INSERT INTO test(value) VALUES (?),(?)") + .bind(0, 2) + .bind(1, 4) + .returnGeneratedValues() + .execute()) + .flatMap(IntegrationTestSupport::extractRowsUpdated) + .doOnNext(it -> assertThat(it).isEqualTo(2))); + } + private static final class DataEntity { private final int id;