Skip to content

Commit

Permalink
Add support for decoding session state info
Browse files Browse the repository at this point in the history
  • Loading branch information
mirromutth committed Feb 13, 2024
1 parent b0d00a0 commit fbdd900
Show file tree
Hide file tree
Showing 4 changed files with 152 additions and 18 deletions.
12 changes: 10 additions & 2 deletions src/main/java/io/asyncer/r2dbc/mysql/Capability.java
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,11 @@ public final class Capability {
private static final long VAR_INT_SIZED_AUTH = 1L << 21;

// private static final long HANDLE_EXPIRED_PASSWORD = 1L << 22; // Client can handle expired passwords.
// private static final long SESSION_TRACK = 1L << 23;

/**
* Server can send session state information in the OK packet.
*/
private static final long SESSION_TRACK = 1L << 23;

/**
* The MySQL server marks the EOF message as deprecated and use OK message instead.
Expand All @@ -171,7 +175,7 @@ public final class Capability {
private static final long ALL_SUPPORTED = CLIENT_MYSQL | FOUND_ROWS | LONG_FLAG | CONNECT_WITH_DB |
NO_SCHEMA | COMPRESS | LOCAL_FILES | IGNORE_SPACE | PROTOCOL_41 | INTERACTIVE | SSL |
TRANSACTIONS | SECURE_SALT | MULTI_STATEMENTS | MULTI_RESULTS | PS_MULTI_RESULTS |
PLUGIN_AUTH | CONNECT_ATTRS | VAR_INT_SIZED_AUTH | DEPRECATE_EOF;
PLUGIN_AUTH | CONNECT_ATTRS | VAR_INT_SIZED_AUTH | SESSION_TRACK | DEPRECATE_EOF;

private final long bitmap;

Expand Down Expand Up @@ -377,6 +381,10 @@ void disableSsl() {
this.bitmap &= ~SSL;
}

void disableSessionTrack() {
this.bitmap &= ~SESSION_TRACK;
}

void disableConnectAttributes() {
this.bitmap &= ~CONNECT_ATTRS;
}
Expand Down
1 change: 1 addition & 0 deletions src/main/java/io/asyncer/r2dbc/mysql/QueryFlow.java
Original file line number Diff line number Diff line change
Expand Up @@ -945,6 +945,7 @@ private AuthResponse createAuthResponse(int envelopeId, String phase) {
private Capability clientCapability(Capability serverCapability) {
Capability.Builder builder = serverCapability.mutate();

builder.disableSessionTrack();
builder.disableDatabasePinned();
builder.disableCompression();
builder.disableIgnoreAmbiguitySpace();
Expand Down
105 changes: 89 additions & 16 deletions src/main/java/io/asyncer/r2dbc/mysql/message/server/OkMessage.java
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,12 @@
import io.asyncer.r2dbc.mysql.constant.ServerStatuses;
import io.asyncer.r2dbc.mysql.internal.util.VarIntUtils;
import io.netty.buffer.ByteBuf;
import org.jetbrains.annotations.Nullable;

import java.nio.charset.Charset;
import java.util.Collections;
import java.util.HashMap;
import java.util.Map;

import static io.asyncer.r2dbc.mysql.internal.util.AssertUtils.requireNonNull;

Expand All @@ -34,6 +38,8 @@
*/
public final class OkMessage implements WarningMessage, ServerStatusMessage, CompleteMessage {

private static final int SESSION_TRACK_SYSTEM_VARIABLES = 0;

private static final int MIN_SIZE = 7;

private final boolean isEndOfRows;
Expand All @@ -51,14 +57,17 @@ public final class OkMessage implements WarningMessage, ServerStatusMessage, Com

private final String information;

private final Map<String, String> systemVariables;

private OkMessage(boolean isEndOfRows, long affectedRows, long lastInsertId, short serverStatuses,
int warnings, String information) {
int warnings, String information, Map<String, String> systemVariables) {
this.isEndOfRows = isEndOfRows;
this.affectedRows = affectedRows;
this.lastInsertId = lastInsertId;
this.serverStatuses = serverStatuses;
this.warnings = warnings;
this.information = requireNonNull(information, "information must not be null");
this.systemVariables = requireNonNull(systemVariables, "systemVariables must not be null");
}

public boolean isEndOfRows() {
Expand All @@ -83,6 +92,11 @@ public int getWarnings() {
return warnings;
}

@Nullable
public String getSystemVariable(String key) {
return systemVariables.get(key);
}

@Override
public boolean isDone() {
return (serverStatuses & ServerStatuses.MORE_RESULTS_EXISTS) == 0;
Expand All @@ -104,7 +118,8 @@ public boolean equals(Object o) {
lastInsertId == okMessage.lastInsertId &&
serverStatuses == okMessage.serverStatuses &&
warnings == okMessage.warnings &&
information.equals(okMessage.information);
information.equals(okMessage.information) &&
systemVariables.equals(okMessage.systemVariables);
}

@Override
Expand All @@ -114,7 +129,8 @@ public int hashCode() {
result = 31 * result + (int) (lastInsertId ^ (lastInsertId >>> 32));
result = 31 * result + serverStatuses;
result = 31 * result + warnings;
return 31 * result + information.hashCode();
result = 31 * result + information.hashCode();
return 31 * result + systemVariables.hashCode();
}

@Override
Expand All @@ -124,15 +140,19 @@ public String toString() {
", affectedRows=" + Long.toUnsignedString(affectedRows) +
", lastInsertId=" + Long.toUnsignedString(lastInsertId) +
", serverStatuses=" + Integer.toHexString(serverStatuses) +
", information='" + information + "'}";
", information='" + information +
"', systemVariables=" + systemVariables +
'}';
}

return "OkMessage{isEndOfRows=" + isEndOfRows +
", affectedRows=" + Long.toUnsignedString(affectedRows) +
", lastInsertId=" + Long.toUnsignedString(lastInsertId) +
", serverStatuses=" + Integer.toHexString(serverStatuses) +
", warnings=" + warnings +
", information='" + information + "'}";
", information='" + information +
"', systemVariables=" + systemVariables +
"}";
}

static boolean isValidSize(int bytes) {
Expand Down Expand Up @@ -164,26 +184,79 @@ static OkMessage decode(boolean isEndOfRows, ByteBuf buf, ConnectionContext cont

if (sizeAfterVarInt < 0) {
return new OkMessage(isEndOfRows, affectedRows, lastInsertId, serverStatuses,
warnings, buf.toString(charset));
warnings, buf.toString(charset), Collections.emptyMap());
}

int oldReaderIndex = buf.readerIndex();
long infoSize = VarIntUtils.readVarInt(buf);

if (infoSize > sizeAfterVarInt) {
// Compatible code, the information may be an EOF encoded string at early versions of MySQL.
String info = buf.toString(oldReaderIndex, buf.writerIndex() - oldReaderIndex, charset);

return new OkMessage(isEndOfRows, affectedRows, lastInsertId, serverStatuses, warnings,
info, Collections.emptyMap());
}

int readerIndex = buf.readerIndex();
long size = VarIntUtils.readVarInt(buf);
String information;
// All the following have lengths should be less than Integer.MAX_VALUE
String information = buf.readCharSequence((int) infoSize, charset).toString();
Map<String, String> systemVariables = Collections.emptyMap();

while (VarIntUtils.checkNextVarInt(buf) >= 0) {
int stateInfoSize = (int) VarIntUtils.readVarInt(buf);
ByteBuf stateInfo = buf.readSlice(stateInfoSize);

while (stateInfo.isReadable()) {
if (stateInfo.readByte() == SESSION_TRACK_SYSTEM_VARIABLES) {
systemVariables = readServerVariables(stateInfo, context);
} else {
// Ignore other state info
int skipBytes = (int) VarIntUtils.readVarInt(stateInfo);

if (size > sizeAfterVarInt) {
information = buf.toString(readerIndex, buf.writerIndex() - readerIndex, charset);
} else {
// JVM does NOT support strings longer than Integer.MAX_VALUE
information = buf.toString(buf.readerIndex(), (int) size, charset);
stateInfo.skipBytes(skipBytes);
}
}
}

// Ignore session track, it is not human-readable and useless for R2DBC client.
return new OkMessage(isEndOfRows, affectedRows, lastInsertId, serverStatuses, warnings,
information);
information, systemVariables);
}

// Maybe have no human-readable message
return new OkMessage(isEndOfRows, affectedRows, lastInsertId, serverStatuses, warnings, "");
return new OkMessage(isEndOfRows, affectedRows, lastInsertId, serverStatuses, warnings, "",
Collections.emptyMap());
}

private static Map<String, String> readServerVariables(ByteBuf buf, ConnectionContext context) {
// All lengths should NOT be greater than Integer.MAX_VALUE
Map<String, String> map = new HashMap<>();
Charset charset = context.getClientCollation().getCharset();
int size = (int) VarIntUtils.readVarInt(buf);
ByteBuf sessionVar = buf.readSlice(size);

while (sessionVar.readableBytes() > 0) {
int variableSize = (int) VarIntUtils.readVarInt(sessionVar);
String variable = sessionVar.toString(sessionVar.readerIndex(), variableSize, charset);

sessionVar.skipBytes(variableSize);

int valueSize = (int) VarIntUtils.readVarInt(sessionVar);
String value = sessionVar.toString(sessionVar.readerIndex(), valueSize, charset);

sessionVar.skipBytes(valueSize);
map.put(variable, value);
}

switch (map.size()) {
case 0:
return Collections.emptyMap();
case 1: {
Map.Entry<String, String> entry = map.entrySet().iterator().next();
return Collections.singletonMap(entry.getKey(), entry.getValue());
}
default:
return map;
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
/*
* Copyright 2024 asyncer.io projects
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package io.asyncer.r2dbc.mysql.message.server;

import io.asyncer.r2dbc.mysql.ConnectionContext;
import io.asyncer.r2dbc.mysql.ConnectionContextTest;
import io.netty.buffer.ByteBuf;
import io.netty.buffer.Unpooled;
import org.junit.jupiter.api.Test;

import static org.assertj.core.api.Assertions.assertThat;

/**
* Unit tests for {@link OkMessage}.
*/
class OkMessageTest {

@Test
void decodeSessionVariables() {
boolean isMariaDb = "mariadb".equalsIgnoreCase(System.getProperty("test.db.type"));
ConnectionContext context = ConnectionContextTest.mock(isMariaDb);
OkMessage message = OkMessage.decode(true, sessionVariablesOk(), context);

assertThat(message.getAffectedRows()).isOne();
assertThat(message.getLastInsertId()).isEqualTo(2);
assertThat(message.getServerStatuses()).isEqualTo((short) 0x4000);
assertThat(message.getWarnings()).isEqualTo(3);
assertThat(message.getSystemVariable("autocommit")).isEqualTo("OFF");
}

private static ByteBuf sessionVariablesOk() {
return Unpooled.wrappedBuffer(new byte[] {
0,
1, 2, 0, 0x40, 3, 0, 0, 0x11, 0, 0xf, 0xa,
0x61, 0x75, 0x74, 0x6f, 0x63, 0x6f, 0x6d, 0x6d, 0x69, 0x74, 0x3, 0x4f, 0x46, 0x46,
});
}
}

0 comments on commit fbdd900

Please sign in to comment.