Skip to content

Commit

Permalink
Add row count and README for RETURNING clause
Browse files Browse the repository at this point in the history
  • Loading branch information
mirromutth committed Jan 18, 2024
1 parent f2e3c2d commit b7e2f4a
Show file tree
Hide file tree
Showing 7 changed files with 84 additions and 37 deletions.
7 changes: 5 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ This driver provides the following features:
- [x] Transactions with savepoint.
- [x] Native ping command that can be verifying when argument is `ValidationDepth.REMOTE`
- [x] Extensible, e.g. extend built-in `Codec`(s).
- [x] MariaDB `RETURNING` clause.

## Maintainer

Expand Down Expand Up @@ -543,10 +544,12 @@ If you want to raise an issue, please follow the recommendations below:
## Before use

- The MySQL data fields encoded by index-based natively, get fields by an index will have **better** performance than get by column name.
- Each `Result` should be used (call `getRowsUpdated` or `map`, even table definition), can **NOT** just ignore any `Result`, otherwise inbound stream is unable to align. (like `ResultSet.close` in jdbc, `Result` auto-close after used by once)
- Each `Result` should be used (call `getRowsUpdated` or `map`/`flatMap`, even table definition), can **NOT** just ignore any `Result`, otherwise inbound stream is unable to align. (like `ResultSet.close` in jdbc, `Result` auto-close after used by once)
- The MySQL server does not **actively** return time zone when query `DATETIME` or `TIMESTAMP`, this driver does not attempt time zone conversion. That means should always use `LocalDateTime` for SQL type `DATETIME` or `TIMESTAMP`. Execute `SHOW VARIABLES LIKE '%time_zone%'` to get more information.
- Should not turn-on the `trace` log level unless debugging. Otherwise, the security information may be exposed through `ByteBuf` dump.
- If `Statement` bound `returnGeneratedValues`, the `Result` of the `Statement` can be called both: `getRowsUpdated` to get affected rows, and `map` to get last inserted ID.
- If `Statement` bound `returnGeneratedValues`, the `Result` of the `Statement` can be called both of `getRowsUpdated` and `map`/`flatMap`.
- If server is not MariaDB: `returnGeneratedValues` can only be called with one or zero arguments, and the `Result` will contain the last inserted id.
- If server is MariaDB 10.5.1 and above: the statement will attempt to use `RETURNING` clause, zero arguments will make the statement like `... RETURNING *`.
- The MySQL may be not support well for searching rows by a binary field, like `BIT` and `JSON`
- `BIT`: cannot select 'BIT(64)' with value greater than 'Long.MAX_VALUE' (or equivalent in binary)
- `JSON`: different MySQL may have different serialization formats, e.g. MariaDB and MySQL
Expand Down
36 changes: 25 additions & 11 deletions src/main/java/io/asyncer/r2dbc/mysql/MySqlResult.java
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
import reactor.core.publisher.Mono;
import reactor.core.publisher.SynchronousSink;

import java.util.concurrent.atomic.AtomicLong;
import java.util.function.BiConsumer;
import java.util.function.BiFunction;
import java.util.function.Function;
Expand Down Expand Up @@ -227,34 +228,35 @@ protected void deallocate() {

private static class MySqlUpdateCount implements UpdateCount {

protected final OkMessage message;
protected final long rows;

private MySqlUpdateCount(OkMessage message) {
this.message = message;
}
private MySqlUpdateCount(long rows) { this.rows = rows; }

@Override
public long value() {
return message.getAffectedRows();
return rows;
}
}

private static final class MySqlOkSegment extends MySqlUpdateCount implements RowSegment {

private final long lastInsertId;

private final Codecs codecs;

private final String keyName;

private MySqlOkSegment(OkMessage message, Codecs codecs, String keyName) {
super(message);
private MySqlOkSegment(long rows, long lastInsertId, Codecs codecs, String keyName) {
super(rows);

this.lastInsertId = lastInsertId;
this.codecs = codecs;
this.keyName = keyName;
}

@Override
public Row row() {
return new InsertSyntheticRow(codecs, keyName, message.getLastInsertId());
return new InsertSyntheticRow(codecs, keyName, lastInsertId);
}
}

Expand All @@ -269,6 +271,8 @@ private static final class MySqlSegments implements BiConsumer<ServerMessage, Sy
@Nullable
private final String syntheticKeyName;

private final AtomicLong rowCount = new AtomicLong(0);

private MySqlRowMetadata rowMetadata;

private MySqlSegments(boolean binary, Codecs codecs, ConnectionContext context,
Expand All @@ -282,6 +286,9 @@ private MySqlSegments(boolean binary, Codecs codecs, ConnectionContext context,
@Override
public void accept(ServerMessage message, SynchronousSink<Segment> sink) {
if (message instanceof RowMessage) {
// Updated rows can be identified either by OK or rows in case of RETURNING
rowCount.getAndIncrement();

MySqlRowMetadata metadata = this.rowMetadata;

if (metadata == null) {
Expand All @@ -308,10 +315,17 @@ public void accept(ServerMessage message, SynchronousSink<Segment> sink) {

this.rowMetadata = MySqlRowMetadata.create(metadataMessages);
} else if (message instanceof OkMessage) {
Segment segment = syntheticKeyName == null ? new MySqlUpdateCount((OkMessage) message) :
new MySqlOkSegment((OkMessage) message, codecs, syntheticKeyName);
OkMessage msg = (OkMessage) message;

if (MySqlStatementSupport.supportReturning(context) && msg.isEndOfRows()) {
sink.next(new MySqlUpdateCount(rowCount.getAndSet(0)));
} else {
long rows = msg.getAffectedRows();
Segment segment = syntheticKeyName == null ? new MySqlUpdateCount(rows) :
new MySqlOkSegment(rows, msg.getLastInsertId(), codecs, syntheticKeyName);

sink.next(segment);
sink.next(segment);
}
} else if (message instanceof ErrorMessage) {
sink.next(new MySqlMessage((ErrorMessage) message));
} else {
Expand Down
10 changes: 5 additions & 5 deletions src/main/java/io/asyncer/r2dbc/mysql/MySqlStatementSupport.java
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ public final MySqlStatement returnGeneratedValues(String... columns) {

if (len == 0) {
this.generatedColumns = InternalArrays.EMPTY_STRINGS;
} else if (len == 1 || supportReturning()) {
} else if (len == 1 || supportReturning(context)) {
String[] result = new String[len];

for (int i = 0; i < len; ++i) {
Expand All @@ -63,7 +63,7 @@ public final MySqlStatement returnGeneratedValues(String... columns) {
this.generatedColumns = result;
} else {
String db = context.isMariaDb() ? "MariaDB 10.5.0 or below" : "MySQL";
throw new IllegalArgumentException(db + " supports only LAST_INSERT_ID instead of RETURNING");
throw new IllegalArgumentException(db + " can have only one column");
}

return this;
Expand All @@ -80,7 +80,7 @@ final String syntheticKeyName() {
String[] columns = this.generatedColumns;

// MariaDB should use `RETURNING` clause instead.
if (columns == null || supportReturning()) {
if (columns == null || supportReturning(this.context)) {
return null;
}

Expand All @@ -94,7 +94,7 @@ final String syntheticKeyName() {
final String returningIdentifiers() {
String[] columns = this.generatedColumns;

if (columns == null || !supportReturning()) {
if (columns == null || !supportReturning(context)) {
return "";
}

Expand All @@ -111,7 +111,7 @@ final String returningIdentifiers() {
return joiner.toString();
}

private boolean supportReturning() {
static boolean supportReturning(ConnectionContext context) {
return context.isMariaDb() && context.getServerVersion().isGreaterThanOrEqualTo(MARIA_10_5_1);
}
}
2 changes: 1 addition & 1 deletion src/main/java/io/asyncer/r2dbc/mysql/codec/Codecs.java
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ <T> T decode(FieldValue value, MySqlColumnMetadata metadata, ParameterizedType t
CodecContext context);

/**
* Decode the last inserted ID from {@code OkMessage} as a specified {@link ParameterizedType type}.
* Decode the last inserted ID from {@code OkMessage} as a specified {@link Class type}.
*
* @param <T> the generic result type.
* @param value the last inserted ID.
Expand Down
43 changes: 29 additions & 14 deletions src/main/java/io/asyncer/r2dbc/mysql/message/server/OkMessage.java
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,8 @@ public final class OkMessage implements WarningMessage, ServerStatusMessage, Com

private static final int MIN_SIZE = 7;

private final boolean isEndOfRows;

private final long affectedRows;

/**
Expand All @@ -49,15 +51,20 @@ public final class OkMessage implements WarningMessage, ServerStatusMessage, Com

private final String information;

private OkMessage(long affectedRows, long lastInsertId, short serverStatuses, int warnings,
String information) {
private OkMessage(boolean isEndOfRows, long affectedRows, long lastInsertId, short serverStatuses,
int warnings, String information) {
this.isEndOfRows = isEndOfRows;
this.affectedRows = affectedRows;
this.lastInsertId = lastInsertId;
this.serverStatuses = serverStatuses;
this.warnings = warnings;
this.information = requireNonNull(information, "information must not be null");
}

public boolean isEndOfRows() {
return isEndOfRows;
}

public long getAffectedRows() {
return affectedRows;
}
Expand Down Expand Up @@ -92,7 +99,8 @@ public boolean equals(Object o) {

OkMessage okMessage = (OkMessage) o;

return affectedRows == okMessage.affectedRows &&
return isEndOfRows == okMessage.isEndOfRows &&
affectedRows == okMessage.affectedRows &&
lastInsertId == okMessage.lastInsertId &&
serverStatuses == okMessage.serverStatuses &&
warnings == okMessage.warnings &&
Expand All @@ -101,7 +109,8 @@ public boolean equals(Object o) {

@Override
public int hashCode() {
int result = (int) (affectedRows ^ (affectedRows >>> 32));
int result = (isEndOfRows ? 1 : 0);
result = 31 * result + (int) (affectedRows ^ (affectedRows >>> 32));
result = 31 * result + (int) (lastInsertId ^ (lastInsertId >>> 32));
result = 31 * result + serverStatuses;
result = 31 * result + warnings;
Expand All @@ -111,21 +120,26 @@ public int hashCode() {
@Override
public String toString() {
if (warnings == 0) {
return "OkMessage{affectedRows=" + Long.toUnsignedString(affectedRows) + ", lastInsertId=" +
Long.toUnsignedString(lastInsertId) + ", serverStatuses=" + Integer.toHexString(serverStatuses) +
return "OkMessage{isEndOfRows=" + isEndOfRows +
", affectedRows=" + Long.toUnsignedString(affectedRows) +
", lastInsertId=" + Long.toUnsignedString(lastInsertId) +
", serverStatuses=" + Integer.toHexString(serverStatuses) +
", information='" + information + "'}";
}

return "OkMessage{affectedRows=" + Long.toUnsignedString(affectedRows) + ", lastInsertId=" +
Long.toUnsignedString(lastInsertId) + ", serverStatuses=" + Integer.toHexString(serverStatuses) +
", warnings=" + warnings + ", information='" + information + "'}";
return "OkMessage{isEndOfRows=" + isEndOfRows +
", affectedRows=" + Long.toUnsignedString(affectedRows) +
", lastInsertId=" + Long.toUnsignedString(lastInsertId) +
", serverStatuses=" + Integer.toHexString(serverStatuses) +
", warnings=" + warnings +
", information='" + information + "'}";
}

static boolean isValidSize(int bytes) {
return bytes >= MIN_SIZE;
}

static OkMessage decode(ByteBuf buf, ConnectionContext context) {
static OkMessage decode(boolean isEndOfRows, ByteBuf buf, ConnectionContext context) {
buf.skipBytes(1); // OK message header, 0x00 or 0xFE

Capability capability = context.getCapability();
Expand All @@ -149,8 +163,8 @@ static OkMessage decode(ByteBuf buf, ConnectionContext context) {
int sizeAfterVarInt = VarIntUtils.checkNextVarInt(buf);

if (sizeAfterVarInt < 0) {
return new OkMessage(affectedRows, lastInsertId, serverStatuses, warnings,
buf.toString(charset));
return new OkMessage(isEndOfRows, affectedRows, lastInsertId, serverStatuses,
warnings, buf.toString(charset));
}

int readerIndex = buf.readerIndex();
Expand All @@ -165,10 +179,11 @@ static OkMessage decode(ByteBuf buf, ConnectionContext context) {
}

// Ignore session track, it is not human-readable and useless for R2DBC client.
return new OkMessage(affectedRows, lastInsertId, serverStatuses, warnings, information);
return new OkMessage(isEndOfRows, affectedRows, lastInsertId, serverStatuses, warnings,
information);
}

// Maybe have no human-readable message
return new OkMessage(affectedRows, lastInsertId, serverStatuses, warnings, "");
return new OkMessage(isEndOfRows, affectedRows, lastInsertId, serverStatuses, warnings, "");
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -196,7 +196,7 @@ private static ServerMessage decodeCommandMessage(ByteBuf buf, ConnectionContext
return ErrorMessage.decode(buf);
case OK:
if (OkMessage.isValidSize(buf.readableBytes())) {
return OkMessage.decode(buf, context);
return OkMessage.decode(false, buf, context);
}

break;
Expand All @@ -209,7 +209,7 @@ private static ServerMessage decodeCommandMessage(ByteBuf buf, ConnectionContext
// so if readable bytes upper than 7, it means if it is column count,
// column count is already upper than (1 << 24) - 1 = 16777215, it is impossible.
// So it must be OK message, not be column count.
return OkMessage.decode(buf, context);
return OkMessage.decode(false, buf, context);
} else if (EofMessage.isValidSize(byteSize)) {
return EofMessage.decode(buf);
}
Expand All @@ -231,7 +231,7 @@ private static ServerMessage decodeLogin(int envelopeId, ByteBuf buf, Connection
switch (header) {
case OK:
if (OkMessage.isValidSize(buf.readableBytes())) {
return OkMessage.decode(buf, context);
return OkMessage.decode(false, buf, context);
}

break;
Expand Down Expand Up @@ -333,7 +333,7 @@ private static ServerMessage decodeRow(List<ByteBuf> buffers, ByteBuf firstBuf,
ByteBuf combined = NettyBufferUtils.composite(buffers);

try {
return OkMessage.decode(combined, context);
return OkMessage.decode(true, combined, context);
} finally {
combined.release();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,21 @@ void partialReturning() {
);
}

@Test
void returningGetRowUpdated() {
complete(conn -> conn.createStatement("CREATE TEMPORARY TABLE test(" +
"id INT NOT NULL AUTO_INCREMENT PRIMARY KEY,value INT NOT NULL)")
.execute()
.flatMap(IntegrationTestSupport::extractRowsUpdated)
.thenMany(conn.createStatement("INSERT INTO test(value) VALUES (?),(?)")
.bind(0, 2)
.bind(1, 4)
.returnGeneratedValues()
.execute())
.flatMap(IntegrationTestSupport::extractRowsUpdated)
.doOnNext(it -> assertThat(it).isEqualTo(2)));
}

private static final class DataEntity {

private final int id;
Expand Down

0 comments on commit b7e2f4a

Please sign in to comment.