From 5e2c3afd3d7b2ba05a7e814ed1de0576788968e2 Mon Sep 17 00:00:00 2001 From: rusher Date: Thu, 11 Jul 2024 16:21:51 +0200 Subject: [PATCH] [misc] OkPacket fast path --- .../jdbc/client/impl/StandardClient.java | 2 +- .../mariadb/jdbc/message/ClientMessage.java | 2 +- .../mariadb/jdbc/message/server/OkPacket.java | 111 +++++++++++++++++- .../jdbc/integration/StatementTest.java | 1 - 4 files changed, 107 insertions(+), 9 deletions(-) diff --git a/src/main/java/org/mariadb/jdbc/client/impl/StandardClient.java b/src/main/java/org/mariadb/jdbc/client/impl/StandardClient.java index 5b84dc09b..dbaeaf019 100644 --- a/src/main/java/org/mariadb/jdbc/client/impl/StandardClient.java +++ b/src/main/java/org/mariadb/jdbc/client/impl/StandardClient.java @@ -311,7 +311,7 @@ public void authenticationHandler(Credential credential, HostAddress hostAddress // OK_Packet -> Authenticated ! // see https://mariadb.com/kb/en/library/ok_packet/ // ************************************************************************************* - OkPacket okPacket = new OkPacket(buf, context); + OkPacket okPacket = OkPacket.parseWithInfo(buf, context); // ssl certificates validation using client password if (certFingerprint != null) { diff --git a/src/main/java/org/mariadb/jdbc/message/ClientMessage.java b/src/main/java/org/mariadb/jdbc/message/ClientMessage.java index 67f6e7a7f..6137d8be8 100644 --- a/src/main/java/org/mariadb/jdbc/message/ClientMessage.java +++ b/src/main/java/org/mariadb/jdbc/message/ClientMessage.java @@ -169,7 +169,7 @@ default Completion readPacket( // * OK response // ********************************************************************************************************* case (byte) 0x00: - OkPacket ok = new OkPacket(buf, context); + OkPacket ok = OkPacket.parse(buf, context); if (context.getRedirectUrl() != null && (context.getServerStatus() & ServerStatus.IN_TRANSACTION) == 0 && (context.getServerStatus() & ServerStatus.MORE_RESULTS_EXISTS) == 0) { diff --git a/src/main/java/org/mariadb/jdbc/message/server/OkPacket.java b/src/main/java/org/mariadb/jdbc/message/server/OkPacket.java index 9cc66dc71..c2859d971 100644 --- a/src/main/java/org/mariadb/jdbc/message/server/OkPacket.java +++ b/src/main/java/org/mariadb/jdbc/message/server/OkPacket.java @@ -13,25 +13,125 @@ /** Ok packet parser see https://mariadb.com/kb/en/ok_packet/ */ public class OkPacket implements Completion { + private static final OkPacket BASIC_OK = new OkPacket(0, 0, null); private static final Logger logger = Loggers.getLogger(OkPacket.class); private final long affectedRows; private final long lastInsertId; private final byte[] info; + private OkPacket(long affectedRows, long lastInsertId, byte[] info) { + this.affectedRows = affectedRows; + this.lastInsertId = lastInsertId; + this.info = info; + } + /** * Parser * * @param buf packet buffer * @param context connection context */ - public OkPacket(ReadableByteBuf buf, Context context) { + public static OkPacket parse(ReadableByteBuf buf, Context context) { buf.skip(); // ok header - this.affectedRows = buf.readLongLengthEncodedNotNull(); - this.lastInsertId = buf.readLongLengthEncodedNotNull(); + long affectedRows = buf.readLongLengthEncodedNotNull(); + long lastInsertId = buf.readLongLengthEncodedNotNull(); context.setServerStatus(buf.readUnsignedShort()); context.setWarning(buf.readUnsignedShort()); + if (buf.readableBytes() > 0) { + buf.skip(buf.readIntLengthEncodedNotNull()); // skip info + if (context.hasClientCapability(Capabilities.CLIENT_SESSION_TRACK)) { + while (buf.readableBytes() > 0) { + ReadableByteBuf sessionStateBuf = buf.readLengthBuffer(); + while (sessionStateBuf.readableBytes() > 0) { + switch (sessionStateBuf.readByte()) { + case StateChange.SESSION_TRACK_SYSTEM_VARIABLES: + ReadableByteBuf tmpBufsv; + do { + tmpBufsv = sessionStateBuf.readLengthBuffer(); + String variableSv = tmpBufsv.readString(tmpBufsv.readIntLengthEncodedNotNull()); + Integer lenSv = tmpBufsv.readLength(); + String valueSv = lenSv == null ? null : tmpBufsv.readString(lenSv); + logger.debug("System variable change: {} = {}", variableSv, valueSv); + switch (variableSv) { + case "character_set_client": + context.setCharset(valueSv); + break; + case "connection_id": + context.setThreadId(Long.parseLong(valueSv)); + break; + case "threads_Connected": + context.setTreadsConnected(Long.parseLong(valueSv)); + break; + case "auto_increment_increment": + context.setAutoIncrement(Long.parseLong(valueSv)); + break; + case "redirect_url": + if (!"".equals(valueSv)) context.setRedirectUrl(valueSv); + break; + case "tx_isolation": + case "transaction_isolation": + switch (valueSv) { + case "REPEATABLE-READ": + context.setTransactionIsolationLevel( + java.sql.Connection.TRANSACTION_REPEATABLE_READ); + break; + case "READ-UNCOMMITTED": + context.setTransactionIsolationLevel( + java.sql.Connection.TRANSACTION_READ_UNCOMMITTED); + break; + case "READ-COMMITTED": + context.setTransactionIsolationLevel( + java.sql.Connection.TRANSACTION_READ_COMMITTED); + break; + case "SERIALIZABLE": + context.setTransactionIsolationLevel( + java.sql.Connection.TRANSACTION_SERIALIZABLE); + break; + default: + context.setTransactionIsolationLevel(null); + break; + } + break; + } + } while (tmpBufsv.readableBytes() > 0); + break; + + case StateChange.SESSION_TRACK_SCHEMA: + sessionStateBuf.readIntLengthEncodedNotNull(); + Integer dbLen = sessionStateBuf.readLength(); + String database = + dbLen == null || dbLen == 0 ? null : sessionStateBuf.readString(dbLen); + context.setDatabase(database); + logger.debug("Database change: is '{}'", database); + break; + + default: + buf.skip(buf.readIntLengthEncodedNotNull()); + break; + } + } + } + } + } + if (affectedRows == 0 && lastInsertId == 0) return BASIC_OK; + return new OkPacket(affectedRows, lastInsertId, null); + } + + /** + * Parser + * + * @param buf packet buffer + * @param context connection context + */ + public static OkPacket parseWithInfo(ReadableByteBuf buf, Context context) { + buf.skip(); // ok header + long affectedRows = buf.readLongLengthEncodedNotNull(); + long lastInsertId = buf.readLongLengthEncodedNotNull(); + context.setServerStatus(buf.readUnsignedShort()); + context.setWarning(buf.readUnsignedShort()); + byte[] info; if (buf.readableBytes() > 0) { info = new byte[buf.readIntLengthEncodedNotNull()]; buf.readBytes(info); @@ -108,9 +208,8 @@ public OkPacket(ReadableByteBuf buf, Context context) { } } } - } else { - info = null; - } + } else info = new byte[0]; + return new OkPacket(affectedRows, lastInsertId, info); } /** diff --git a/src/test/java/org/mariadb/jdbc/integration/StatementTest.java b/src/test/java/org/mariadb/jdbc/integration/StatementTest.java index a82c9a926..701b2217d 100644 --- a/src/test/java/org/mariadb/jdbc/integration/StatementTest.java +++ b/src/test/java/org/mariadb/jdbc/integration/StatementTest.java @@ -643,7 +643,6 @@ public void maxRows() throws SQLException { assertEquals(i, rs.getInt(1)); } assertEquals(10, i); - } }