diff --git a/README.md b/README.md index 04c5ec123..d39cf40e7 100644 --- a/README.md +++ b/README.md @@ -20,6 +20,7 @@ Refer to the table below to determine the appropriate version of r2dbc-mysql for This driver provides the following features: - [x] Unix domain socket. +- [x] Compression protocols, including zstd and zlib. - [x] Execution of simple or batch statements without bindings. - [x] Execution of prepared statements with bindings. - [x] Reactive LOB types (e.g. BLOB, CLOB) @@ -143,6 +144,7 @@ ConnectionFactoryOptions options = ConnectionFactoryOptions.builder() .option(Option.valueOf("allowLoadLocalInfileInPath"), "/opt") // optional, default null, null means LOCAL INFILE not be allowed (since 1.1.0) .option(Option.valueOf("tcpKeepAlive"), true) // optional, default false .option(Option.valueOf("tcpNoDelay"), true) // optional, default false + .option(Option.valueOf("compressionAlgorithms"), "zstd") // optional, default UNCOMPRESSED .option(Option.valueOf("autodetectExtensions"), false) // optional, default false .option(Option.valueOf("passwordPublisher"), Mono.just("password")) // optional, default null, null means has no passwordPublisher (since 1.0.5 / 0.9.6) .build(); @@ -191,6 +193,7 @@ MySqlConnectionConfiguration configuration = MySqlConnectionConfiguration.builde .allowLoadLocalInfileInPath("/opt") // optional, default null, null means LOCAL INFILE not be allowed .tcpKeepAlive(true) // optional, controls TCP Keep Alive, default is false .tcpNoDelay(true) // optional, controls TCP No Delay, default is false + .compressionAlgorithms(CompressionAlgorithm.ZSTD, CompressionAlgotihm.ZLIB) // optional, default is UNCOMPRESSED .autodetectExtensions(false) // optional, controls extension auto-detect, default is true .extendWith(MyExtension.INSTANCE) // optional, manual extend an extension into extensions, default using auto-detect .passwordPublisher(Mono.just("password")) // optional, default null, null means has no password publisher (since 1.0.5 / 0.9.6) @@ -242,6 +245,7 @@ Mono connectionMono = Mono.from(connectionFactory.create()); | autodetectExtensions | `true` or `false` | Optional, default is `true` | Controls auto-detect `Extension`s | | useServerPrepareStatement | `true`, `false` or `Predicate` | Optional, default is `false` | See following notice | | allowLoadLocalInfileInPath | A path | Optional, default is `null` | The path that allows `LOAD DATA LOCAL INFILE` to load file data | +| compressionAlgorithms | A list of `CompressionAlgorithm` | Optional, default is `UNCOMPRESSED` | The compression algorithms for MySQL connection | | passwordPublisher | A `Publisher` | Optional, default is `null` | The password publisher, see following notice | - `SslMode` Considers security level and verification for SSL, make sure the database server supports SSL before you want change SSL mode to `REQUIRED` or higher. **The Unix Domain Socket only offers "DISABLED" available** @@ -269,6 +273,11 @@ Mono connectionMono = Mono.from(connectionFactory.create()); - The `Extensions` will not remove duplicates, make sure it would be not extended twice or more - The auto-detected `Extension`s will not affect manual extends and will not remove duplicates - `passwordPublisher` Every time the client attempts to authenticate, it will use the password provided by the `passwordPublisher`.(Since `1.0.5` / `0.9.6`) e.g., You can employ this method for IAM-based authentication when connecting to an AWS Aurora RDS database. +- `compressionAlgorithms` Considers compression protocol for MySQL connection, it is **NOT** RECOMMENDED to use compression protocol in the general case, because it will increase the CPU usage and decrease the performance. + - `UNCOMPRESSED` (default) No compression + - `ZLIB` Use Zlib compression protocol, it is available on almost all MySQL versions (`5.x` and above) + - `ZSTD` Use Z-standard compression protocol, it is available since MySQL `8.0.18` or above, requires an extern dependency `com.github.luben:zstd-jni` + - For scenarios where the network environment is poor or the amount of data is always large, using a compression protocol may be useful Should use `enum` in [Programmatic](#programmatic-configuration) configuration that not like discovery configurations, except `TlsVersions` (All elements of `TlsVersions` will be always `String` which is case-sensitive). diff --git a/pom.xml b/pom.xml index 233ad4756..286b61581 100644 --- a/pom.xml +++ b/pom.xml @@ -78,6 +78,7 @@ 2.16.0 0.3.0.RELEASE 3.0.2 + 1.5.5-11 24.1.0 1.77 @@ -153,6 +154,13 @@ provided + + com.github.luben + zstd-jni + ${zstd-jni.version} + true + + ch.qos.logback logback-classic diff --git a/src/main/java/io/asyncer/r2dbc/mysql/Capability.java b/src/main/java/io/asyncer/r2dbc/mysql/Capability.java index 1bff4247f..67ebc3711 100644 --- a/src/main/java/io/asyncer/r2dbc/mysql/Capability.java +++ b/src/main/java/io/asyncer/r2dbc/mysql/Capability.java @@ -158,7 +158,7 @@ public final class Capability { // Allow the server not to send column metadata in result set, // should NEVER enable this option. // private static final long OPTIONAL_RESULT_SET_METADATA = 1L << 25; -// private static final long Z_STD_COMPRESSION = 1L << 26; + private static final long ZSTD_COMPRESS = 1L << 26; // A reserved flag, used to extend the 32-bits capability bitmap to 64-bits. // There is no available MySql server version/edition to support it. @@ -175,7 +175,7 @@ public final class Capability { private static final long ALL_SUPPORTED = CLIENT_MYSQL | FOUND_ROWS | LONG_FLAG | CONNECT_WITH_DB | NO_SCHEMA | COMPRESS | LOCAL_FILES | IGNORE_SPACE | PROTOCOL_41 | INTERACTIVE | SSL | TRANSACTIONS | SECURE_SALT | MULTI_STATEMENTS | MULTI_RESULTS | PS_MULTI_RESULTS | - PLUGIN_AUTH | CONNECT_ATTRS | VAR_INT_SIZED_AUTH | SESSION_TRACK | DEPRECATE_EOF; + PLUGIN_AUTH | CONNECT_ATTRS | VAR_INT_SIZED_AUTH | SESSION_TRACK | DEPRECATE_EOF | ZSTD_COMPRESS; private final long bitmap; @@ -278,6 +278,33 @@ public boolean isTransactionAllowed() { return (bitmap & TRANSACTIONS) != 0; } + /** + * Checks if any compression enabled. + * + * @return if any compression enabled. + */ + public boolean isCompression() { + return (bitmap & (COMPRESS | ZSTD_COMPRESS)) != 0; + } + + /** + * Checks if zlib compression enabled. + * + * @return if zlib compression enabled. + */ + public boolean isZlibCompression() { + return (bitmap & COMPRESS) != 0; + } + + /** + * Checks if zstd compression enabled. + * + * @return if zstd compression enabled. + */ + public boolean isZstdCompression() { + return (bitmap & ZSTD_COMPRESS) != 0; + } + /** * Extends MariaDB capabilities. * @@ -362,9 +389,17 @@ void disableDatabasePinned() { } void disableCompression() { + this.bitmap &= ~(COMPRESS | ZSTD_COMPRESS); + } + + void disableZlibCompression() { this.bitmap &= ~COMPRESS; } + void disableZstdCompression() { + this.bitmap &= ~ZSTD_COMPRESS; + } + void disableLoadDataLocalInfile() { this.bitmap &= ~LOCAL_FILES; } diff --git a/src/main/java/io/asyncer/r2dbc/mysql/MySqlConnectionConfiguration.java b/src/main/java/io/asyncer/r2dbc/mysql/MySqlConnectionConfiguration.java index 0eec8645c..83f11b4cf 100644 --- a/src/main/java/io/asyncer/r2dbc/mysql/MySqlConnectionConfiguration.java +++ b/src/main/java/io/asyncer/r2dbc/mysql/MySqlConnectionConfiguration.java @@ -16,6 +16,7 @@ package io.asyncer.r2dbc.mysql; +import io.asyncer.r2dbc.mysql.constant.CompressionAlgorithm; import io.asyncer.r2dbc.mysql.constant.SslMode; import io.asyncer.r2dbc.mysql.constant.ZeroDateOption; import io.asyncer.r2dbc.mysql.extension.Extension; @@ -30,9 +31,12 @@ import java.time.Duration; import java.time.ZoneId; import java.util.ArrayList; +import java.util.Collections; +import java.util.EnumSet; import java.util.List; import java.util.Objects; import java.util.ServiceLoader; +import java.util.Set; import java.util.function.Function; import java.util.function.Predicate; @@ -97,6 +101,10 @@ public final class MySqlConnectionConfiguration { private final int prepareCacheSize; + private final Set compressionAlgorithms; + + private final int zstdCompressionLevel; + private final Extensions extensions; @Nullable @@ -109,8 +117,9 @@ private MySqlConnectionConfiguration( String user, @Nullable CharSequence password, @Nullable String database, boolean createDatabaseIfNotExist, @Nullable Predicate preferPrepareStatement, @Nullable Path loadLocalInfilePath, int localInfileBufferSize, - int queryCacheSize, int prepareCacheSize, Extensions extensions, - @Nullable Publisher passwordPublisher + int queryCacheSize, int prepareCacheSize, + Set compressionAlgorithms, int zstdCompressionLevel, + Extensions extensions, @Nullable Publisher passwordPublisher ) { this.isHost = isHost; this.domain = domain; @@ -130,6 +139,8 @@ private MySqlConnectionConfiguration( this.localInfileBufferSize = localInfileBufferSize; this.queryCacheSize = queryCacheSize; this.prepareCacheSize = prepareCacheSize; + this.compressionAlgorithms = compressionAlgorithms; + this.zstdCompressionLevel = zstdCompressionLevel; this.extensions = extensions; this.passwordPublisher = passwordPublisher; } @@ -220,6 +231,14 @@ int getPrepareCacheSize() { return prepareCacheSize; } + Set getCompressionAlgorithms() { + return compressionAlgorithms; + } + + int getZstdCompressionLevel() { + return zstdCompressionLevel; + } + Extensions getExtensions() { return extensions; } @@ -256,6 +275,8 @@ public boolean equals(Object o) { localInfileBufferSize == that.localInfileBufferSize && queryCacheSize == that.queryCacheSize && prepareCacheSize == that.prepareCacheSize && + compressionAlgorithms.equals(that.compressionAlgorithms) && + zstdCompressionLevel == that.zstdCompressionLevel && extensions.equals(that.extensions) && Objects.equals(passwordPublisher, that.passwordPublisher); } @@ -265,7 +286,7 @@ public int hashCode() { return Objects.hash(isHost, domain, port, ssl, tcpKeepAlive, tcpNoDelay, connectTimeout, serverZoneId, zeroDateOption, user, password, database, createDatabaseIfNotExist, preferPrepareStatement, loadLocalInfilePath, localInfileBufferSize, queryCacheSize, - prepareCacheSize, extensions, passwordPublisher); + prepareCacheSize, compressionAlgorithms, zstdCompressionLevel, extensions, passwordPublisher); } @Override @@ -280,6 +301,8 @@ public String toString() { ", loadLocalInfilePath=" + loadLocalInfilePath + ", localInfileBufferSize=" + localInfileBufferSize + ", queryCacheSize=" + queryCacheSize + ", prepareCacheSize=" + prepareCacheSize + + ", compressionAlgorithms=" + compressionAlgorithms + + ", zstdCompressionLevel=" + zstdCompressionLevel + ", extensions=" + extensions + ", passwordPublisher=" + passwordPublisher + '}'; } @@ -291,8 +314,10 @@ public String toString() { ", loadLocalInfilePath=" + loadLocalInfilePath + ", localInfileBufferSize=" + localInfileBufferSize + ", queryCacheSize=" + queryCacheSize + - ", prepareCacheSize=" + prepareCacheSize + ", extensions=" + extensions + - ", passwordPublisher=" + passwordPublisher + '}'; + ", prepareCacheSize=" + prepareCacheSize + + ", compressionAlgorithms=" + compressionAlgorithms + + ", zstdCompressionLevel=" + zstdCompressionLevel + + ", extensions=" + extensions + ", passwordPublisher=" + passwordPublisher + '}'; } /** @@ -363,6 +388,11 @@ public static final class Builder { private int prepareCacheSize = 256; + private Set compressionAlgorithms = + Collections.singleton(CompressionAlgorithm.UNCOMPRESSED); + + private int zstdCompressionLevel = 3; + private boolean autodetectExtensions = true; private final List extensions = new ArrayList<>(); @@ -395,6 +425,7 @@ public MySqlConnectionConfiguration build() { connectTimeout, zeroDateOption, serverZoneId, user, password, database, createDatabaseIfNotExist, preferPrepareStatement, loadLocalInfilePath, localInfileBufferSize, queryCacheSize, prepareCacheSize, + compressionAlgorithms, zstdCompressionLevel, Extensions.from(extensions, autodetectExtensions), passwordPublisher); } @@ -822,6 +853,64 @@ public Builder prepareCacheSize(int prepareCacheSize) { return this; } + /** + * Configures the compression algorithms. Default to [{@link CompressionAlgorithm#UNCOMPRESSED}]. + *

+ * It will auto choose an algorithm that's contained in the list and supported by the server, + * preferring zstd, then zlib. If the list does not contain {@link CompressionAlgorithm#UNCOMPRESSED} + * and the server does not support any algorithm in the list, an exception will be thrown when + * connecting. + *

+ * Note: zstd requires a dependency {@code com.github.luben:zstd-jni}. + * + * @param compressionAlgorithms the list of compression algorithms. + * @return {@link Builder this}. + * @throws IllegalArgumentException if {@code compressionAlgorithms} is {@code null} or empty. + * @since 1.1.2 + */ + public Builder compressionAlgorithms(CompressionAlgorithm... compressionAlgorithms) { + requireNonNull(compressionAlgorithms, "compressionAlgorithms must not be null"); + require(compressionAlgorithms.length != 0, "compressionAlgorithms must not be empty"); + + if (compressionAlgorithms.length == 1) { + requireNonNull(compressionAlgorithms[0], "compressionAlgorithms must not contain null"); + this.compressionAlgorithms = Collections.singleton(compressionAlgorithms[0]); + } else { + Set algorithms = EnumSet.noneOf(CompressionAlgorithm.class); + + for (CompressionAlgorithm algorithm : compressionAlgorithms) { + requireNonNull(algorithm, "compressionAlgorithms must not contain null"); + algorithms.add(algorithm); + } + + this.compressionAlgorithms = algorithms; + } + + return this; + } + + /** + * Configures the zstd compression level. Default to {@code 3}. + *

+ * It is only used if zstd is chosen for the connection. + *

+ * Note: MySQL protocol does not allow to set the zlib compression level of the server, only zstd is + * configurable. + * + * @param level the compression level. + * @return {@link Builder this}. + * @throws IllegalArgumentException if {@code level} is not between 1 and 22. + * @since 1.1.2 + * @see + * MySQL Connection Options --zstd-compression-level + */ + public Builder zstdCompressionLevel(int level) { + require(level >= 1 && level <= 22, "level must be between 1 and 22"); + + this.zstdCompressionLevel = level; + return this; + } + /** * Configures whether to use {@link ServiceLoader} to discover and register extensions. Defaults to * {@code true}. diff --git a/src/main/java/io/asyncer/r2dbc/mysql/MySqlConnectionFactory.java b/src/main/java/io/asyncer/r2dbc/mysql/MySqlConnectionFactory.java index 6070d195c..4c2d69cd4 100644 --- a/src/main/java/io/asyncer/r2dbc/mysql/MySqlConnectionFactory.java +++ b/src/main/java/io/asyncer/r2dbc/mysql/MySqlConnectionFactory.java @@ -22,6 +22,7 @@ import io.asyncer.r2dbc.mysql.client.Client; import io.asyncer.r2dbc.mysql.codec.Codecs; import io.asyncer.r2dbc.mysql.codec.CodecsBuilder; +import io.asyncer.r2dbc.mysql.constant.CompressionAlgorithm; import io.asyncer.r2dbc.mysql.constant.SslMode; import io.asyncer.r2dbc.mysql.extension.CodecRegistrar; import io.netty.buffer.ByteBufAllocator; @@ -35,6 +36,7 @@ import java.net.InetSocketAddress; import java.net.SocketAddress; import java.util.Objects; +import java.util.Set; import java.util.concurrent.locks.ReentrantLock; import java.util.function.Predicate; @@ -90,12 +92,14 @@ public static MySqlConnectionFactory from(MySqlConnectionConfiguration configura String user = configuration.getUser(); CharSequence password = configuration.getPassword(); SslMode sslMode = ssl.getSslMode(); + int zstdCompressionLevel = configuration.getZstdCompressionLevel(); ConnectionContext context = new ConnectionContext( configuration.getZeroDateOption(), configuration.getLoadLocalInfilePath(), configuration.getLocalInfileBufferSize(), configuration.getServerZoneId() ); + Set compressionAlgorithms = configuration.getCompressionAlgorithms(); Extensions extensions = configuration.getExtensions(); Predicate prepare = configuration.getPreferPrepareStatement(); int prepareCacheSize = configuration.getPrepareCacheSize(); @@ -106,8 +110,9 @@ public static MySqlConnectionFactory from(MySqlConnectionConfiguration configura configuration, queryCache, ssl, address, database, createDbIfNotExist, - user, sslMode, context, - extensions, prepare, + user, sslMode, + compressionAlgorithms, zstdCompressionLevel, + context, extensions, prepare, prepareCacheSize, token )); } @@ -116,8 +121,9 @@ public static MySqlConnectionFactory from(MySqlConnectionConfiguration configura configuration, queryCache, ssl, address, database, createDbIfNotExist, - user, sslMode, context, - extensions, prepare, + user, sslMode, + compressionAlgorithms, zstdCompressionLevel, + context, extensions, prepare, prepareCacheSize, password ); })); @@ -132,6 +138,8 @@ private static Mono getMySqlConnection( final boolean createDbIfNotExist, final String user, final SslMode sslMode, + final Set compressionAlgorithms, + final int zstdCompressionLevel, final ConnectionContext context, final Extensions extensions, @Nullable final Predicate prepare, @@ -142,7 +150,8 @@ private static Mono getMySqlConnection( .flatMap(client -> { // Lazy init database after handshake/login String db = createDbIfNotExist ? "" : database; - return QueryFlow.login(client, sslMode, db, user, password, context); + return QueryFlow.login(client, sslMode, db, user, password, compressionAlgorithms, + zstdCompressionLevel, context); }) .flatMap(client -> { ByteBufAllocator allocator = client.getByteBufAllocator(); diff --git a/src/main/java/io/asyncer/r2dbc/mysql/MySqlConnectionFactoryProvider.java b/src/main/java/io/asyncer/r2dbc/mysql/MySqlConnectionFactoryProvider.java index eba16df09..a0e818664 100644 --- a/src/main/java/io/asyncer/r2dbc/mysql/MySqlConnectionFactoryProvider.java +++ b/src/main/java/io/asyncer/r2dbc/mysql/MySqlConnectionFactoryProvider.java @@ -16,6 +16,7 @@ package io.asyncer.r2dbc.mysql; +import io.asyncer.r2dbc.mysql.constant.CompressionAlgorithm; import io.asyncer.r2dbc.mysql.constant.SslMode; import io.asyncer.r2dbc.mysql.constant.ZeroDateOption; import io.netty.handler.ssl.SslContextBuilder; @@ -181,9 +182,41 @@ public final class MySqlConnectionFactoryProvider implements ConnectionFactoryPr public static final Option USE_SERVER_PREPARE_STATEMENT = Option.valueOf("useServerPrepareStatement"); + /** + * Option to set the allowed local infile path. + * + * @since 1.1.0 + */ public static final Option ALLOW_LOAD_LOCAL_INFILE_IN_PATH = Option.valueOf("allowLoadLocalInfileInPath"); + /** + * Option to set compression algorithms. Default to [{@link CompressionAlgorithm#UNCOMPRESSED}]. + *

+ * It will auto choose an algorithm that's contained in the list and supported by the server, preferring + * zstd, then zlib. If the list does not contain {@link CompressionAlgorithm#UNCOMPRESSED} and the server + * does not support any algorithm in the list, an exception will be thrown when connecting. + *

+ * Note: zstd requires a dependency {@code com.github.luben:zstd-jni}. + * + * @since 1.1.2 + */ + public static final Option COMPRESSION_ALGORITHMS = + Option.valueOf("compressionAlgorithms"); + + /** + * Option to set the zstd compression level. Default to {@code 3}. + *

+ * It is only used if zstd is chosen for the connection. + *

+ * Note: MySQL protocol does not allow to set the zlib compression level of the server, only zstd is + * configurable. + * + * @since 1.1.2 + */ + public static final Option ZSTD_COMPRESSION_LEVEL = + Option.valueOf("zstdCompressionLevel"); + /** * Option to set the maximum size of the {@link Query} parsing cache. Default to {@code 256}. * @@ -206,10 +239,9 @@ public final class MySqlConnectionFactoryProvider implements ConnectionFactoryPr public static final Option AUTODETECT_EXTENSIONS = Option.valueOf("autodetectExtensions"); /** - * Password Publisher function can be used to retrieve password before creating a connection. - * This can be used with Amazon RDS Aurora IAM authentication, wherein it requires token to be generated. - * The token is valid for 15 minutes, and this token will be used as password. - * + * Password Publisher function can be used to retrieve password before creating a connection. This can be + * used with Amazon RDS Aurora IAM authentication, wherein it requires token to be generated. The token is + * valid for 15 minutes, and this token will be used as password. */ public static final Option> PASSWORD_PUBLISHER = Option.valueOf("passwordPublisher"); @@ -273,6 +305,13 @@ static MySqlConnectionConfiguration setup(ConnectionFactoryOptions options) { .to(builder::database); mapper.optional(CREATE_DATABASE_IF_NOT_EXIST).asBoolean() .to(builder::createDatabaseIfNotExist); + mapper.optional(COMPRESSION_ALGORITHMS).asArray( + CompressionAlgorithm[].class, + it -> CompressionAlgorithm.valueOf(it.toUpperCase()), + CompressionAlgorithm[]::new + ).to(builder::compressionAlgorithms); + mapper.optional(ZSTD_COMPRESSION_LEVEL).asInt() + .to(builder::zstdCompressionLevel); mapper.optional(PASSWORD_PUBLISHER).as(Publisher.class) .to(builder::passwordPublisher); @@ -295,7 +334,7 @@ private static void setupHost(MySqlConnectionConfiguration.Builder builder, Opti .to(isSsl -> builder.sslMode(isSsl ? SslMode.REQUIRED : SslMode.DISABLED)); mapper.optional(SSL_MODE).as(SslMode.class, id -> SslMode.valueOf(id.toUpperCase())) .to(builder::sslMode); - mapper.optional(TLS_VERSION).asStrings() + mapper.optional(TLS_VERSION).asArray(String[].class, Function.identity(), String[]::new) .to(builder::tlsVersion); mapper.optional(SSL_HOSTNAME_VERIFIER).as(HostnameVerifier.class) .to(builder::sslHostnameVerifier); diff --git a/src/main/java/io/asyncer/r2dbc/mysql/OptionMapper.java b/src/main/java/io/asyncer/r2dbc/mysql/OptionMapper.java index 218cc11df..8a95cea1d 100644 --- a/src/main/java/io/asyncer/r2dbc/mysql/OptionMapper.java +++ b/src/main/java/io/asyncer/r2dbc/mysql/OptionMapper.java @@ -23,6 +23,7 @@ import java.util.Collection; import java.util.function.Consumer; import java.util.function.Function; +import java.util.function.IntFunction; import java.util.function.Predicate; /** @@ -105,21 +106,38 @@ Source as(Class type, Function mapping) { throw new IllegalArgumentException(toMessage(value, type.getTypeName())); } - Source asStrings() { + Source asArray(Class arrayType, Function mapper, IntFunction generator) { if (value == null) { return nilSource(); } - if (value instanceof String[]) { - return new Source<>((String[]) value); + if (arrayType.isInstance(value)) { + return new Source<>(arrayType.cast(value)); + } else if (value instanceof String[]) { + return new Source<>(mapArray((String[]) value, mapper, generator)); } else if (value instanceof String) { - return new Source<>(((String) value).split(",")); + String[] strings = ((String) value).split(","); + + if (arrayType.isInstance(strings)) { + return new Source<>(arrayType.cast(strings)); + } + + return new Source<>(mapArray(strings, mapper, generator)); } else if (value instanceof Collection) { - return new Source<>(((Collection) value).stream() - .map(String.class::cast).toArray(String[]::new)); + @SuppressWarnings("unchecked") + Class type = (Class) arrayType.getComponentType(); + R[] array = ((Collection) value).stream().map(e -> { + if (type.isInstance(e)) { + return type.cast(e); + } else { + return mapper.apply(e.toString()); + } + }).toArray(generator); + + return new Source<>(array); } - throw new IllegalArgumentException(toMessage(value, "String[]")); + throw new IllegalArgumentException(toMessage(value, arrayType.getTypeName())); } Source asBoolean() { @@ -236,6 +254,16 @@ private static Source nilSource() { private static String toMessage(Object value, String type) { return "Cannot convert value " + value + " to " + type; } + + private static O[] mapArray(String[] input, Function mapper, IntFunction generator) { + O[] output = generator.apply(input.length); + + for (int i = 0; i < input.length; i++) { + output[i] = mapper.apply(input[i]); + } + + return output; + } } enum Otherwise { diff --git a/src/main/java/io/asyncer/r2dbc/mysql/QueryFlow.java b/src/main/java/io/asyncer/r2dbc/mysql/QueryFlow.java index d8caa6a31..4ad301974 100644 --- a/src/main/java/io/asyncer/r2dbc/mysql/QueryFlow.java +++ b/src/main/java/io/asyncer/r2dbc/mysql/QueryFlow.java @@ -20,6 +20,7 @@ import io.asyncer.r2dbc.mysql.cache.PrepareCache; import io.asyncer.r2dbc.mysql.client.Client; import io.asyncer.r2dbc.mysql.client.FluxExchangeable; +import io.asyncer.r2dbc.mysql.constant.CompressionAlgorithm; import io.asyncer.r2dbc.mysql.constant.ServerStatuses; import io.asyncer.r2dbc.mysql.constant.SslMode; import io.asyncer.r2dbc.mysql.internal.util.StringUtils; @@ -56,6 +57,7 @@ import io.netty.util.internal.logging.InternalLogger; import io.netty.util.internal.logging.InternalLoggerFactory; import io.r2dbc.spi.IsolationLevel; +import io.r2dbc.spi.R2dbcNonTransientResourceException; import io.r2dbc.spi.R2dbcPermissionDeniedException; import io.r2dbc.spi.TransactionDefinition; import org.jetbrains.annotations.Nullable; @@ -68,12 +70,15 @@ import reactor.core.publisher.SynchronousSink; import reactor.util.concurrent.Queues; +import java.security.AccessController; +import java.security.PrivilegedAction; import java.time.Duration; import java.util.ArrayList; import java.util.Collections; import java.util.Iterator; import java.util.List; import java.util.Map; +import java.util.Set; import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicInteger; import java.util.function.Consumer; @@ -191,17 +196,22 @@ static Flux> execute(Client client, List statements) * Login a {@link Client} and receive the {@code client} after logon. It will emit an exception when * client receives a {@link ErrorMessage}. * - * @param client the {@link Client} to exchange messages with. - * @param sslMode the {@link SslMode} defines SSL capability and behavior. - * @param database the database that will be connected. - * @param user the user that will be login. - * @param password the password of the {@code user}. - * @param context the {@link ConnectionContext} for initialization. + * @param client the {@link Client} to exchange messages with. + * @param sslMode the {@link SslMode} defines SSL capability and behavior. + * @param database the database that will be connected. + * @param user the user that will be login. + * @param password the password of the {@code user}. + * @param compressionAlgorithms the list of compression algorithms. + * @param zstdCompressionLevel the zstd compression level. + * @param context the {@link ConnectionContext} for initialization. * @return the messages received in response to the login exchange. */ static Mono login(Client client, SslMode sslMode, String database, String user, - @Nullable CharSequence password, ConnectionContext context) { - return client.exchange(new LoginExchangeable(client, sslMode, database, user, password, context)) + @Nullable CharSequence password, + Set compressionAlgorithms, int zstdCompressionLevel, + ConnectionContext context) { + return client.exchange(new LoginExchangeable(client, sslMode, database, user, password, + compressionAlgorithms, zstdCompressionLevel, context)) .onErrorResume(e -> client.forceClose().then(Mono.error(e))) .then(Mono.just(client)); } @@ -327,7 +337,7 @@ public final void accept(ServerMessage message, SynchronousSink s QueryLogger.logLocalInfile(path); requests.emitNext( - new LocalInfileResponse(request.getEnvelopeId() + 1, path, sink), + new LocalInfileResponse(path, sink), Sinks.EmitFailureHandler.FAIL_FAST ); } else { @@ -828,6 +838,10 @@ final class LoginExchangeable extends FluxExchangeable { @Nullable private final CharSequence password; + private final Set compressions; + + private final int zstdCompressionLevel; + private final ConnectionContext context; private boolean handshake = true; @@ -838,15 +852,16 @@ final class LoginExchangeable extends FluxExchangeable { private boolean sslCompleted; - private int lastEnvelopeId; - LoginExchangeable(Client client, SslMode sslMode, String database, String user, - @Nullable CharSequence password, ConnectionContext context) { + @Nullable CharSequence password, Set compressions, + int zstdCompressionLevel, ConnectionContext context) { this.client = client; this.sslMode = sslMode; this.database = database; this.user = user; this.password = password; + this.compressions = compressions; + this.zstdCompressionLevel = zstdCompressionLevel; this.context = context; this.sslCompleted = sslMode == SslMode.TUNNEL; } @@ -870,13 +885,10 @@ public void accept(ServerMessage message, SynchronousSink sink) { HandshakeRequest request = (HandshakeRequest) message; Capability capability = initHandshake(request); - lastEnvelopeId = request.getEnvelopeId() + 1; - if (capability.isSslEnabled()) { - emitNext(SslRequest.from(lastEnvelopeId, capability, - context.getClientCollation().getId()), sink); + emitNext(SslRequest.from(capability, context.getClientCollation().getId()), sink); } else { - emitNext(createHandshakeResponse(lastEnvelopeId, capability), sink); + emitNext(createHandshakeResponse(capability), sink); } } else { sink.error(new R2dbcPermissionDeniedException("Unexpected message type '" + @@ -891,10 +903,9 @@ public void accept(ServerMessage message, SynchronousSink sink) { sink.complete(); } else if (message instanceof SyntheticSslResponseMessage) { sslCompleted = true; - emitNext(createHandshakeResponse(++lastEnvelopeId, context.getCapability()), sink); + emitNext(createHandshakeResponse(context.getCapability()), sink); } else if (message instanceof AuthMoreDataMessage) { AuthMoreDataMessage msg = (AuthMoreDataMessage) message; - lastEnvelopeId = msg.getEnvelopeId() + 1; if (msg.isFailed()) { if (logger.isDebugEnabled()) { @@ -902,15 +913,15 @@ public void accept(ServerMessage message, SynchronousSink sink) { context.getConnectionId()); } - emitNext(createAuthResponse(lastEnvelopeId, "full authentication"), sink); + emitNext(createAuthResponse("full authentication"), sink); } // Otherwise success, wait until OK message or Error message. } else if (message instanceof ChangeAuthMessage) { ChangeAuthMessage msg = (ChangeAuthMessage) message; - lastEnvelopeId = msg.getEnvelopeId() + 1; + authProvider = MySqlAuthProvider.build(msg.getAuthType()); salt = msg.getSalt(); - emitNext(createAuthResponse(lastEnvelopeId, "change authentication"), sink); + emitNext(createAuthResponse("change authentication"), sink); } else { sink.error(new R2dbcPermissionDeniedException("Unexpected message type '" + message.getClass().getSimpleName() + "' in login phase")); @@ -931,15 +942,14 @@ private void emitNext(SubsequenceClientMessage message, SynchronousSink si } } - private AuthResponse createAuthResponse(int envelopeId, String phase) { + private AuthResponse createAuthResponse(String phase) { MySqlAuthProvider authProvider = getAndNextProvider(); if (authProvider.isSslNecessary() && !sslCompleted) { throw new R2dbcPermissionDeniedException(authFails(authProvider.getType(), phase), CLI_SPECIFIC); } - return new AuthResponse(envelopeId, - authProvider.authentication(password, salt, context.getClientCollation())); + return new AuthResponse(authProvider.authentication(password, salt, context.getClientCollation())); } private Capability clientCapability(Capability serverCapability) { @@ -947,7 +957,6 @@ private Capability clientCapability(Capability serverCapability) { builder.disableSessionTrack(); builder.disableDatabasePinned(); - builder.disableCompression(); builder.disableIgnoreAmbiguitySpace(); builder.disableInteractiveTimeout(); @@ -970,6 +979,32 @@ private Capability clientCapability(Capability serverCapability) { } } + if (isZstdAllowed(serverCapability)) { + if (isZstdSupported()) { + builder.disableZlibCompression(); + } else { + logger.warn("Server supports zstd, but zstd-jni dependency is missing"); + + if (isZlibAllowed(serverCapability)) { + builder.disableZstdCompression(); + } else if (compressions.contains(CompressionAlgorithm.UNCOMPRESSED)) { + builder.disableCompression(); + } else { + throw new R2dbcNonTransientResourceException( + "Environment does not support a compression algorithm in " + compressions + + ", config does not allow uncompressed mode", CLI_SPECIFIC); + } + } + } else if (isZlibAllowed(serverCapability)) { + builder.disableZstdCompression(); + } else if (compressions.contains(CompressionAlgorithm.UNCOMPRESSED)) { + builder.disableCompression(); + } else { + throw new R2dbcPermissionDeniedException( + "Environment does not support a compression algorithm in " + compressions + + ", config does not allow uncompressed mode", CLI_SPECIFIC); + } + if (database.isEmpty()) { builder.disableConnectWithDatabase(); } @@ -1011,7 +1046,7 @@ private MySqlAuthProvider getAndNextProvider() { return authProvider; } - private HandshakeResponse createHandshakeResponse(int envelopeId, Capability capability) { + private HandshakeResponse createHandshakeResponse(Capability capability) { MySqlAuthProvider authProvider = getAndNextProvider(); if (authProvider.isSslNecessary() && !sslCompleted) { @@ -1028,13 +1063,34 @@ private HandshakeResponse createHandshakeResponse(int envelopeId, Capability cap authType = MySqlAuthProvider.CACHING_SHA2_PASSWORD; } - return HandshakeResponse.from(envelopeId, capability, context.getClientCollation().getId(), - user, authorization, authType, database, ATTRIBUTES); + return HandshakeResponse.from(capability, context.getClientCollation().getId(), user, authorization, + authType, database, ATTRIBUTES, zstdCompressionLevel); + } + + private boolean isZstdAllowed(Capability capability) { + return capability.isZstdCompression() && compressions.contains(CompressionAlgorithm.ZSTD); + } + + private boolean isZlibAllowed(Capability capability) { + return capability.isZlibCompression() && compressions.contains(CompressionAlgorithm.ZLIB); } private static String authFails(String authType, String phase) { return "Authentication type '" + authType + "' must require SSL in " + phase + " phase"; } + + private static boolean isZstdSupported() { + try { + ClassLoader loader = AccessController.doPrivileged((PrivilegedAction) () -> { + ClassLoader cl = Thread.currentThread().getContextClassLoader(); + return cl == null ? ClassLoader.getSystemClassLoader() : cl; + }); + Class.forName("com.github.luben.zstd.Zstd", false, loader); + return true; + } catch (ClassNotFoundException e) { + return false; + } + } } abstract class AbstractTransactionState { diff --git a/src/main/java/io/asyncer/r2dbc/mysql/authentication/AuthUtils.java b/src/main/java/io/asyncer/r2dbc/mysql/authentication/AuthUtils.java index 43d18fd32..e3a952ee6 100644 --- a/src/main/java/io/asyncer/r2dbc/mysql/authentication/AuthUtils.java +++ b/src/main/java/io/asyncer/r2dbc/mysql/authentication/AuthUtils.java @@ -22,7 +22,7 @@ import java.security.MessageDigest; import java.security.NoSuchAlgorithmException; -import static io.asyncer.r2dbc.mysql.constant.Envelopes.TERMINAL; +import static io.asyncer.r2dbc.mysql.constant.Packets.TERMINAL; /** * An utility for general authentication hashing algorithm. diff --git a/src/main/java/io/asyncer/r2dbc/mysql/authentication/CachingSha2FastAuthProvider.java b/src/main/java/io/asyncer/r2dbc/mysql/authentication/CachingSha2FastAuthProvider.java index bf6919701..0d070dd00 100644 --- a/src/main/java/io/asyncer/r2dbc/mysql/authentication/CachingSha2FastAuthProvider.java +++ b/src/main/java/io/asyncer/r2dbc/mysql/authentication/CachingSha2FastAuthProvider.java @@ -19,7 +19,7 @@ import io.asyncer.r2dbc.mysql.collation.CharCollation; import org.jetbrains.annotations.Nullable; -import static io.asyncer.r2dbc.mysql.constant.Envelopes.TERMINAL; +import static io.asyncer.r2dbc.mysql.constant.Packets.TERMINAL; import static io.asyncer.r2dbc.mysql.internal.util.AssertUtils.requireNonNull; /** diff --git a/src/main/java/io/asyncer/r2dbc/mysql/authentication/CachingSha2FullAuthProvider.java b/src/main/java/io/asyncer/r2dbc/mysql/authentication/CachingSha2FullAuthProvider.java index e00cf8292..52c92b969 100644 --- a/src/main/java/io/asyncer/r2dbc/mysql/authentication/CachingSha2FullAuthProvider.java +++ b/src/main/java/io/asyncer/r2dbc/mysql/authentication/CachingSha2FullAuthProvider.java @@ -21,7 +21,7 @@ import java.nio.CharBuffer; -import static io.asyncer.r2dbc.mysql.constant.Envelopes.TERMINAL; +import static io.asyncer.r2dbc.mysql.constant.Packets.TERMINAL; import static io.asyncer.r2dbc.mysql.internal.util.AssertUtils.requireNonNull; /** diff --git a/src/main/java/io/asyncer/r2dbc/mysql/authentication/MySqlClearAuthProvider.java b/src/main/java/io/asyncer/r2dbc/mysql/authentication/MySqlClearAuthProvider.java index f2a17047c..cc90da4a2 100644 --- a/src/main/java/io/asyncer/r2dbc/mysql/authentication/MySqlClearAuthProvider.java +++ b/src/main/java/io/asyncer/r2dbc/mysql/authentication/MySqlClearAuthProvider.java @@ -21,7 +21,7 @@ import java.nio.CharBuffer; -import static io.asyncer.r2dbc.mysql.constant.Envelopes.TERMINAL; +import static io.asyncer.r2dbc.mysql.constant.Packets.TERMINAL; import static io.asyncer.r2dbc.mysql.internal.util.AssertUtils.requireNonNull; /** diff --git a/src/main/java/io/asyncer/r2dbc/mysql/authentication/Sha256AuthProvider.java b/src/main/java/io/asyncer/r2dbc/mysql/authentication/Sha256AuthProvider.java index bf91b5921..6a88fdebd 100644 --- a/src/main/java/io/asyncer/r2dbc/mysql/authentication/Sha256AuthProvider.java +++ b/src/main/java/io/asyncer/r2dbc/mysql/authentication/Sha256AuthProvider.java @@ -21,7 +21,7 @@ import java.nio.CharBuffer; -import static io.asyncer.r2dbc.mysql.constant.Envelopes.TERMINAL; +import static io.asyncer.r2dbc.mysql.constant.Packets.TERMINAL; import static io.asyncer.r2dbc.mysql.internal.util.AssertUtils.requireNonNull; /** diff --git a/src/main/java/io/asyncer/r2dbc/mysql/client/CompressionDuplexCodec.java b/src/main/java/io/asyncer/r2dbc/mysql/client/CompressionDuplexCodec.java new file mode 100644 index 000000000..f71a85ba6 --- /dev/null +++ b/src/main/java/io/asyncer/r2dbc/mysql/client/CompressionDuplexCodec.java @@ -0,0 +1,243 @@ +/* + * Copyright 2024 asyncer.io projects + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.asyncer.r2dbc.mysql.client; + +import io.asyncer.r2dbc.mysql.constant.Packets; +import io.netty.buffer.ByteBuf; +import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.ChannelOutboundHandler; +import io.netty.channel.ChannelPromise; +import io.netty.handler.codec.ByteToMessageDecoder; +import io.netty.util.internal.logging.InternalLogger; +import io.netty.util.internal.logging.InternalLoggerFactory; +import org.jetbrains.annotations.Nullable; + +import java.net.SocketAddress; +import java.util.List; +import java.util.concurrent.atomic.AtomicInteger; + +/** + * A codec that compresses and decompresses packets. + *

    + *
  • Read: compression {@link ByteBuf} -> compression-framed {@link ByteBuf} -> + * decompressed {@link ByteBuf}
  • + *
  • Write: uncompressed-framed {@link ByteBuf} -> compression-framed {@link ByteBuf}
  • + *
+ */ +final class CompressionDuplexCodec extends ByteToMessageDecoder implements ChannelOutboundHandler { + + static final String NAME = "R2dbcMysqlCompressionDuplexCodec"; + + private static final InternalLogger logger = + InternalLoggerFactory.getInstance(CompressionDuplexCodec.class); + + private static final int MIN_COMPRESS_LENGTH = 50; + + /** + * Compression packet sequence id, incremented independently of the normal sequence id. + */ + private final AtomicInteger sequenceId = new AtomicInteger(0); + + private final Compressor compressor; + + @Nullable + private ByteBuf writeCumulated; + + private final Cumulator writeCumulator = MERGE_CUMULATOR; + + private int frameLength = -1; + + CompressionDuplexCodec(Compressor compressor) { + this.compressor = compressor; + } + + @Override + public void write(ChannelHandlerContext ctx, Object msg, ChannelPromise promise) { + if (msg instanceof ByteBuf) { + ByteBuf cumulated = this.writeCumulated == null ? ctx.alloc().buffer(0, 0) : + this.writeCumulated; + + this.writeCumulated = cumulated = writeCumulator.cumulate(ctx.alloc(), cumulated, (ByteBuf) msg); + + while (cumulated.readableBytes() >= Packets.MAX_PAYLOAD_SIZE) { + logger.trace("Accumulated to the maximum payload, compressing"); + + ByteBuf slice = cumulated.readSlice(Packets.MAX_PAYLOAD_SIZE); + ByteBuf compressed = compressor.compress(slice); + + if (compressed.readableBytes() >= slice.readableBytes()) { + logger.trace("Sending uncompressed due to compressed payload is larger than original"); + compressed.release(); + ctx.write(buildHeader(ctx, slice.readableBytes(), 0)); + ctx.write(slice.retain()); + } else { + logger.trace("Sending compressed payload"); + ctx.write(buildHeader(ctx, compressed.readableBytes(), Packets.MAX_PAYLOAD_SIZE)); + ctx.write(compressed); + } + } + + if (!cumulated.isReadable()) { + this.writeCumulated = null; + cumulated.release(); + } else { + logger.trace("Accumulated writing buffers, waiting for flush"); + } + } else { + ctx.write(msg, promise); + } + } + + private ByteBuf buildHeader(ChannelHandlerContext ctx, int compressedSize, int uncompressedSize) { + return ctx.alloc().ioBuffer(Packets.COMPRESS_HEADER_SIZE) + .writeMediumLE(compressedSize) + .writeByte(sequenceId.getAndIncrement()) + .writeMediumLE(uncompressedSize); + } + + @Override + public void flush(ChannelHandlerContext ctx) { + ByteBuf cumulated = this.writeCumulated; + + this.writeCumulated = null; + + if (cumulated == null) { + ctx.flush(); + return; + } + + int uncompressedSize = cumulated.readableBytes(); + + if (uncompressedSize < MIN_COMPRESS_LENGTH) { + logger.trace("flushing, payload is too small to compress, sending uncompressed"); + ctx.write(buildHeader(ctx, uncompressedSize, 0)); + ctx.writeAndFlush(cumulated); + } else { + try { + logger.trace("flushing, compressing payload"); + + ByteBuf compressed = compressor.compress(cumulated); + + if (compressed.readableBytes() >= uncompressedSize) { + logger.trace("Sending uncompressed due to compressed payload is larger than original"); + compressed.release(); + ctx.write(buildHeader(ctx, uncompressedSize, 0)); + ctx.writeAndFlush(cumulated.retain()); + } else { + logger.trace("Sending compressed payload"); + ctx.write(buildHeader(ctx, compressed.readableBytes(), uncompressedSize)); + ctx.writeAndFlush(compressed); + } + } finally { + cumulated.release(); + } + } + } + + @Override + protected void decode(ChannelHandlerContext ctx, ByteBuf in, List out) { + ByteBuf frame = decode(in); + + if (frame != null) { + out.add(frame); + } + } + + @Nullable + private ByteBuf decode(ByteBuf in) { + if (frameLength == -1) { + // New frame + if (in.readableBytes() < Packets.SIZE_FIELD_SIZE) { + return null; + } + + frameLength = in.getUnsignedMediumLE(in.readerIndex()) + Packets.COMPRESS_HEADER_SIZE; + } + + if (in.readableBytes() < frameLength) { + return null; + } + + in.skipBytes(Packets.SIZE_FIELD_SIZE); + + int sequenceId = in.readUnsignedByte(); + int uncompressedSize = in.readUnsignedMediumLE(); + ByteBuf frame = in.readRetainedSlice(frameLength - Packets.COMPRESS_HEADER_SIZE); + + logger.trace("Decoded frame with sequence id: {}, total size: {}, uncompressed size: {}", + sequenceId, frameLength, uncompressedSize); + this.frameLength = -1; + this.sequenceId.set(sequenceId + 1); + + if (uncompressedSize == 0) { + return frame; + } else { + try { + return compressor.decompress(frame, uncompressedSize); + } finally { + frame.release(); + } + } + } + + @Override + public void userEventTriggered(ChannelHandlerContext ctx, Object evt) { + if (PacketEvent.RESET_SEQUENCE == evt) { + logger.debug("Reset sequence id"); + this.sequenceId.set(0); + } + + ctx.fireUserEventTriggered(evt); + } + + @Override + public void bind(ChannelHandlerContext ctx, SocketAddress localAddress, + ChannelPromise promise) { + ctx.bind(localAddress, promise); + } + + @Override + public void connect(ChannelHandlerContext ctx, SocketAddress remoteAddress, SocketAddress localAddress, + ChannelPromise promise) { + ctx.connect(remoteAddress, localAddress, promise); + } + + @Override + public void disconnect(ChannelHandlerContext ctx, ChannelPromise promise) { + ctx.disconnect(promise); + } + + @Override + public void close(ChannelHandlerContext ctx, ChannelPromise promise) { + ctx.close(promise); + } + + @Override + public void deregister(ChannelHandlerContext ctx, ChannelPromise promise) { + ctx.deregister(promise); + } + + @Override + public void read(ChannelHandlerContext ctx) { + ctx.read(); + } + + @Override + protected void handlerRemoved0(ChannelHandlerContext ctx) { + this.compressor.dispose(); + } +} diff --git a/src/main/java/io/asyncer/r2dbc/mysql/client/Compressor.java b/src/main/java/io/asyncer/r2dbc/mysql/client/Compressor.java new file mode 100644 index 000000000..ee7adbdf1 --- /dev/null +++ b/src/main/java/io/asyncer/r2dbc/mysql/client/Compressor.java @@ -0,0 +1,45 @@ +/* + * Copyright 2024 asyncer.io projects + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.asyncer.r2dbc.mysql.client; + +import io.netty.buffer.ByteBuf; +import reactor.core.Disposable; + +/** + * An abstraction considers to compress and decompress data. + */ +interface Compressor extends Disposable { + + /** + * Compresses the given {@link ByteBuf}. It does not guarantee that the compressed data is smaller than + * the original. It will not change the reader index of the given {@link ByteBuf}. It may return early if + * the compressed data is not smaller than the original. + * + * @param buf the {@link ByteBuf} to compress + * @return the compressed {@link ByteBuf} + */ + ByteBuf compress(ByteBuf buf); + + /** + * Decompresses the given {@link ByteBuf}. + * + * @param buf the {@link ByteBuf} to decompress + * @param uncompressedSize the size of the uncompressed data + * @return the decompressed {@link ByteBuf} + */ + ByteBuf decompress(ByteBuf buf, int uncompressedSize); +} diff --git a/src/main/java/io/asyncer/r2dbc/mysql/client/EnvelopeSlicer.java b/src/main/java/io/asyncer/r2dbc/mysql/client/EnvelopeSlicer.java deleted file mode 100644 index 5e6c2bb51..000000000 --- a/src/main/java/io/asyncer/r2dbc/mysql/client/EnvelopeSlicer.java +++ /dev/null @@ -1,58 +0,0 @@ -/* - * Copyright 2023 asyncer.io projects - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * https://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package io.asyncer.r2dbc.mysql.client; - -import io.asyncer.r2dbc.mysql.constant.Envelopes; -import io.netty.buffer.ByteBuf; -import io.netty.handler.codec.DecoderException; -import io.netty.handler.codec.LengthFieldBasedFrameDecoder; - -import java.nio.ByteOrder; - -/** - * Slice server message envelope of MySQL protocol. - */ -final class EnvelopeSlicer extends LengthFieldBasedFrameDecoder { - - static final String NAME = "R2dbcMySqlEnvelopeSlicer"; - - EnvelopeSlicer() { - super(ByteOrder.LITTLE_ENDIAN, Envelopes.MAX_ENVELOPE_SIZE + Envelopes.PART_HEADER_SIZE, 0, - Envelopes.SIZE_FIELD_SIZE, - 1, // byte size of sequence Id field - 0, // do NOT strip header - true - ); - } - - /** - * Override this method because {@code ByteBuf.order(order)} will create temporary {@code SwappedByteBuf}, - * and {@code ByteBuf.order(order)} has also been deprecated. - *

- * {@inheritDoc} - */ - @Override - protected long getUnadjustedFrameLength(ByteBuf buf, int offset, int length, ByteOrder order) { - if (length != Envelopes.SIZE_FIELD_SIZE || order != ByteOrder.LITTLE_ENDIAN) { - // impossible length or order, only BUG or hack of reflect - throw new DecoderException("Unsupported lengthFieldLength: " + length + - " (only 3) or byteOrder: " + order + " (only LITTLE_ENDIAN)"); - } - - return buf.getUnsignedMediumLE(offset); - } -} diff --git a/src/main/java/io/asyncer/r2dbc/mysql/client/MessageDuplexCodec.java b/src/main/java/io/asyncer/r2dbc/mysql/client/MessageDuplexCodec.java index 1641e480c..09f231b67 100644 --- a/src/main/java/io/asyncer/r2dbc/mysql/client/MessageDuplexCodec.java +++ b/src/main/java/io/asyncer/r2dbc/mysql/client/MessageDuplexCodec.java @@ -17,9 +17,9 @@ package io.asyncer.r2dbc.mysql.client; import io.asyncer.r2dbc.mysql.ConnectionContext; +import io.asyncer.r2dbc.mysql.constant.Packets; import io.asyncer.r2dbc.mysql.internal.util.OperatorUtils; import io.asyncer.r2dbc.mysql.message.client.ClientMessage; -import io.asyncer.r2dbc.mysql.message.client.SubsequenceClientMessage; import io.asyncer.r2dbc.mysql.message.client.PrepareQueryMessage; import io.asyncer.r2dbc.mysql.message.client.PreparedFetchMessage; import io.asyncer.r2dbc.mysql.message.client.SslRequest; @@ -34,51 +34,60 @@ import io.asyncer.r2dbc.mysql.message.server.SyntheticMetadataMessage; import io.netty.buffer.ByteBuf; import io.netty.buffer.ByteBufAllocator; -import io.netty.channel.ChannelDuplexHandler; import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.ChannelOutboundHandler; import io.netty.channel.ChannelPromise; +import io.netty.handler.codec.ByteToMessageDecoder; import io.netty.util.ReferenceCountUtil; import io.netty.util.internal.logging.InternalLogger; import io.netty.util.internal.logging.InternalLoggerFactory; +import org.jetbrains.annotations.Nullable; import reactor.core.publisher.Flux; +import java.net.SocketAddress; +import java.util.List; +import java.util.concurrent.atomic.AtomicInteger; + import static io.asyncer.r2dbc.mysql.internal.util.AssertUtils.requireNonNull; /** - * Client/server messages encode/decode logic. + * A codec that encodes and decodes MySQL messages. + *

    + *
  • Read: {@link ByteBuf} -> framed {@link ByteBuf} -> {@link ServerMessage}
  • + *
  • Write: {@link ClientMessage} -> framed {@link ByteBuf} with last flush
  • + *
*/ -final class MessageDuplexCodec extends ChannelDuplexHandler { +final class MessageDuplexCodec extends ByteToMessageDecoder implements ChannelOutboundHandler { static final String NAME = "R2dbcMySqlMessageDuplexCodec"; private static final InternalLogger logger = InternalLoggerFactory.getInstance(MessageDuplexCodec.class); + private final AtomicInteger sequenceId = new AtomicInteger(0); + private DecodeContext decodeContext = DecodeContext.login(); private final ConnectionContext context; private final ServerMessageDecoder decoder = new ServerMessageDecoder(); + private int frameLength = -1; + MessageDuplexCodec(ConnectionContext context) { this.context = requireNonNull(context, "context must not be null"); } @Override - public void channelRead(ChannelHandlerContext ctx, Object msg) { - if (msg instanceof ByteBuf) { + protected void decode(ChannelHandlerContext ctx, ByteBuf in, List out) { + ByteBuf frame = decode(in); + + if (frame != null) { DecodeContext context = this.decodeContext; - ServerMessage message = this.decoder.decode((ByteBuf) msg, this.context, context); + ServerMessage message = this.decoder.decode(frame, this.context, context); if (message != null) { - handleDecoded(ctx, message); + handleDecoded(out, message); } - } else if (msg instanceof ServerMessage) { - ctx.fireChannelRead(msg); - } else { - if (logger.isWarnEnabled()) { - logger.warn("Unknown message type {} on reading", msg.getClass()); - } - ReferenceCountUtil.release(msg); } } @@ -86,22 +95,11 @@ public void channelRead(ChannelHandlerContext ctx, Object msg) { public void write(ChannelHandlerContext ctx, Object msg, ChannelPromise promise) { if (msg instanceof ClientMessage) { ByteBufAllocator allocator = ctx.alloc(); - Flux encoded; + ClientMessage message = (ClientMessage) msg; + Flux encoded = Flux.from(message.encode(allocator, this.context)); - if (msg instanceof SubsequenceClientMessage) { - SubsequenceClientMessage message = (SubsequenceClientMessage) msg; - - encoded = Flux.from(message.encode(allocator, this.context)); - int envelopeId = message.getEnvelopeId(); - - OperatorUtils.envelope(encoded, allocator, envelopeId, false) - .subscribe(new WriteSubscriber(ctx, promise)); - } else { - encoded = Flux.from(((ClientMessage) msg).encode(allocator, this.context)); - - OperatorUtils.envelope(encoded, allocator, 0, true) - .subscribe(new WriteSubscriber(ctx, promise)); - } + OperatorUtils.envelope(encoded, allocator, sequenceId, message.isCumulative()) + .subscribe(new WriteSubscriber(ctx, promise)); if (msg instanceof PrepareQueryMessage) { setDecodeContext(DecodeContext.prepareQuery()); @@ -118,13 +116,74 @@ public void write(ChannelHandlerContext ctx, Object msg, ChannelPromise promise) } } + @Override + public void userEventTriggered(ChannelHandlerContext ctx, Object evt) { + if (evt instanceof PacketEvent) { + switch ((PacketEvent) evt) { + case RESET_SEQUENCE: + logger.trace("Reset sequence id"); + this.sequenceId.set(0); + break; + case USE_COMPRESSION: + logger.trace("Reset sequence id"); + this.sequenceId.set(0); + + if (context.getCapability().isZstdCompression()) { + enableZstdCompression(ctx); + } else if (context.getCapability().isZlibCompression()) { + enableZlibCompression(ctx); + } else { + logger.warn("Unexpected event compression triggered, no capability found"); + } + break; + default: + // Ignore unknown event + break; + } + } + + ctx.fireUserEventTriggered(evt); + } + + @Override + public void flush(ChannelHandlerContext ctx) { + ctx.flush(); + } + @Override public void channelInactive(ChannelHandlerContext ctx) { decoder.dispose(); ctx.fireChannelInactive(); } - private void handleDecoded(ChannelHandlerContext ctx, ServerMessage msg) { + @Nullable + private ByteBuf decode(ByteBuf in) { + if (frameLength == -1) { + // New frame + if (in.readableBytes() < Packets.SIZE_FIELD_SIZE) { + return null; + } + + frameLength = in.getUnsignedMediumLE(in.readerIndex()) + Packets.NORMAL_HEADER_SIZE; + } + + if (in.readableBytes() < frameLength) { + return null; + } + + in.skipBytes(Packets.SIZE_FIELD_SIZE); + + int sequenceId = in.readUnsignedByte(); + ByteBuf frame = in.readRetainedSlice(frameLength - Packets.NORMAL_HEADER_SIZE); + + logger.trace("Decoded frame with sequence id: {}, total size: {}", sequenceId, frameLength); + this.sequenceId.set(sequenceId + 1); + this.frameLength = -1; + + return frame; + } + + private void handleDecoded(List out, ServerMessage msg) { if (msg instanceof ServerStatusMessage) { this.context.setServerStatuses(((ServerStatusMessage) msg).getServerStatuses()); } @@ -159,7 +218,7 @@ private void handleDecoded(ChannelHandlerContext ctx, ServerMessage msg) { } // Generic handle. - ctx.fireChannelRead(msg); + out.add(msg); } private void setDecodeContext(DecodeContext context) { @@ -168,4 +227,59 @@ private void setDecodeContext(DecodeContext context) { logger.debug("Decode context change to {}", context); } } + + @Override + public void bind(ChannelHandlerContext ctx, SocketAddress localAddress, + ChannelPromise promise) { + ctx.bind(localAddress, promise); + } + + @Override + public void connect(ChannelHandlerContext ctx, SocketAddress remoteAddress, SocketAddress localAddress, + ChannelPromise promise) { + ctx.connect(remoteAddress, localAddress, promise); + } + + @Override + public void disconnect(ChannelHandlerContext ctx, ChannelPromise promise) { + ctx.disconnect(promise); + } + + @Override + public void close(ChannelHandlerContext ctx, ChannelPromise promise) { + ctx.close(promise); + } + + @Override + public void deregister(ChannelHandlerContext ctx, ChannelPromise promise) { + ctx.deregister(promise); + } + + @Override + public void read(ChannelHandlerContext ctx) { + ctx.read(); + } + + private static void enableZstdCompression(ChannelHandlerContext ctx) { + CompressionDuplexCodec handler = new CompressionDuplexCodec( + new ZstdCompressor(3)); + + if (ctx.pipeline().get(CompressionDuplexCodec.NAME) != null) { + logger.warn("Unexpected event, compression already enabled"); + } else { + logger.debug("Compression zstd enabled for subsequent packets"); + ctx.pipeline().addBefore(NAME, CompressionDuplexCodec.NAME, handler); + } + } + + private static void enableZlibCompression(ChannelHandlerContext ctx) { + CompressionDuplexCodec handler = new CompressionDuplexCodec(new ZlibCompressor()); + + if (ctx.pipeline().get(CompressionDuplexCodec.NAME) != null) { + logger.warn("Unexpected event, compression already enabled"); + } else { + logger.debug("Compression zlib enabled for subsequent packets"); + ctx.pipeline().addBefore(NAME, CompressionDuplexCodec.NAME, handler); + } + } } diff --git a/src/main/java/io/asyncer/r2dbc/mysql/client/PacketEvent.java b/src/main/java/io/asyncer/r2dbc/mysql/client/PacketEvent.java new file mode 100644 index 000000000..c8cd53906 --- /dev/null +++ b/src/main/java/io/asyncer/r2dbc/mysql/client/PacketEvent.java @@ -0,0 +1,35 @@ +/* + * Copyright 2024 asyncer.io projects + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.asyncer.r2dbc.mysql.client; + +/** + * A packet event considers how the handler should handle subsequent packets. + */ +enum PacketEvent { + + /** + * Sequence is reset, all sequence IDs should be reset to 0. + */ + RESET_SEQUENCE, + + /** + * Compression is enabled, the handler should decode the next packet as a compression packet. + *

+ * It should just reset the normal sequence ID to 0. + */ + USE_COMPRESSION, +} diff --git a/src/main/java/io/asyncer/r2dbc/mysql/client/ReactorNettyClient.java b/src/main/java/io/asyncer/r2dbc/mysql/client/ReactorNettyClient.java index e1917aacb..8abf17c10 100644 --- a/src/main/java/io/asyncer/r2dbc/mysql/client/ReactorNettyClient.java +++ b/src/main/java/io/asyncer/r2dbc/mysql/client/ReactorNettyClient.java @@ -95,9 +95,7 @@ final class ReactorNettyClient implements Client { this.context = context; // Note: encoder/decoder should before reactor bridge. - connection.addHandlerLast(EnvelopeSlicer.NAME, new EnvelopeSlicer()) - .addHandlerLast(MessageDuplexCodec.NAME, - new MessageDuplexCodec(context)); + connection.addHandlerLast(MessageDuplexCodec.NAME, new MessageDuplexCodec(context)); if (ssl.getSslMode().startSsl()) { connection.addHandlerFirst(SslBridgeHandler.NAME, new SslBridgeHandler(context, ssl)); @@ -133,6 +131,10 @@ final class ReactorNettyClient implements Client { logger.debug("Request: {}", message); } + if (message.isSequenceReset()) { + resetSequence(connection); + } + return connection.outbound().sendObject(message); }) .onErrorResume(this::resumeError) @@ -250,7 +252,15 @@ public void sslUnsupported() { @Override public void loginSuccess() { - connection.channel().pipeline().fireUserEventTriggered(Lifecycle.COMMAND); + if (context.getCapability().isCompression()) { + connection.channel().pipeline().fireUserEventTriggered(PacketEvent.USE_COMPRESSION); + } else { + resetSequence(connection); + } + } + + private static void resetSequence(Connection connection) { + connection.channel().pipeline().fireUserEventTriggered(PacketEvent.RESET_SEQUENCE); } @Override diff --git a/src/main/java/io/asyncer/r2dbc/mysql/client/SslBridgeHandler.java b/src/main/java/io/asyncer/r2dbc/mysql/client/SslBridgeHandler.java index 952be3917..6ba3f2844 100644 --- a/src/main/java/io/asyncer/r2dbc/mysql/client/SslBridgeHandler.java +++ b/src/main/java/io/asyncer/r2dbc/mysql/client/SslBridgeHandler.java @@ -96,7 +96,7 @@ public void handlerAdded(ChannelHandlerContext ctx) { } @Override - public void userEventTriggered(ChannelHandlerContext ctx, Object evt) throws Exception { + public void userEventTriggered(ChannelHandlerContext ctx, Object evt) { if (evt instanceof SslState) { handleSslState(ctx, (SslState) evt); // Ignore event trigger for next handler, because it used only by this handler. @@ -105,7 +105,7 @@ public void userEventTriggered(ChannelHandlerContext ctx, Object evt) throws Exc handleSslCompleted(ctx, (SslHandshakeCompletionEvent) evt); } - super.userEventTriggered(ctx, evt); + ctx.fireUserEventTriggered(evt); } private void handleSslCompleted(ChannelHandlerContext ctx, SslHandshakeCompletionEvent evt) { diff --git a/src/main/java/io/asyncer/r2dbc/mysql/client/WriteSubscriber.java b/src/main/java/io/asyncer/r2dbc/mysql/client/WriteSubscriber.java index cc5e14cbb..ee085cfa3 100644 --- a/src/main/java/io/asyncer/r2dbc/mysql/client/WriteSubscriber.java +++ b/src/main/java/io/asyncer/r2dbc/mysql/client/WriteSubscriber.java @@ -27,6 +27,11 @@ * streaming {@link ByteBuf}s. *

* It ensures {@link #promise} will be complete. + *

+ * Note: flush is required due to the message may be encoded by another thread, like: + * {@link io.asyncer.r2dbc.mysql.message.client.LocalInfileResponse LocalInfileResponse}, + * {@link io.asyncer.r2dbc.mysql.message.client.PreparedExecuteMessage PreparedExecuteMessage} (Blob/Clob), + * etc. */ final class WriteSubscriber implements CoreSubscriber { diff --git a/src/main/java/io/asyncer/r2dbc/mysql/client/ZlibCompressor.java b/src/main/java/io/asyncer/r2dbc/mysql/client/ZlibCompressor.java new file mode 100644 index 000000000..5bd749a44 --- /dev/null +++ b/src/main/java/io/asyncer/r2dbc/mysql/client/ZlibCompressor.java @@ -0,0 +1,179 @@ +/* + * Copyright 2024 asyncer.io projects + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.asyncer.r2dbc.mysql.client; + +import io.netty.buffer.ByteBuf; +import io.netty.handler.codec.DecoderException; + +import java.util.zip.DataFormatException; +import java.util.zip.Deflater; +import java.util.zip.Inflater; + +/** + * An implementation of {@link Compressor} that uses the zlib compression algorithm. + * + * @see io.netty.handler.codec.compression Netty Compression Codecs + */ +final class ZlibCompressor implements Compressor { + + /** + * The maximum size of input buffer and the maximum initial capacity of the compressed data buffer. + *

+ * Note: uncompressed size is already known, so the buffer should be allocated with the exact size. + */ + private static final int MAX_CHUNK_SIZE = 65536; + + private final Deflater deflater = new Deflater(); + + private final Inflater inflater = new Inflater(); + + @Override + public ByteBuf compress(ByteBuf buf) { + int len = buf.readableBytes(); + + if (len == 0) { + return buf.alloc().buffer(0, 0); + } + + try { + if (buf.hasArray()) { + byte[] input = buf.array(); + int offset = buf.arrayOffset() + buf.readerIndex(); + ByteBuf out = buf.alloc().heapBuffer(Math.min(len, MAX_CHUNK_SIZE)); + + deflater.setInput(input, offset, len); + deflater.finish(); + deflateAll(out, len); + + return out; + } else { + byte[] input = new byte[Math.min(len, MAX_CHUNK_SIZE)]; + int readerIndex = buf.readerIndex(); + int writerIndex = buf.writerIndex(); + ByteBuf out = buf.alloc().heapBuffer(Math.min(len, MAX_CHUNK_SIZE)); + + while (writerIndex - readerIndex > 0) { + int numBytes = Math.min(input.length, writerIndex - readerIndex); + + buf.getBytes(readerIndex, input, 0, numBytes); + deflater.setInput(input, 0, numBytes); + readerIndex += numBytes; + deflateAll(out, len); + } + + deflater.finish(); + deflateAll(out, len); + + return out; + } + } finally { + deflater.reset(); + } + } + + @Override + public ByteBuf decompress(ByteBuf buf, int uncompressedSize) { + int len = buf.readableBytes(); + + if (len == 0) { + return buf.alloc().buffer(0, 0); + } + + try { + if (buf.hasArray()) { + byte[] input = buf.array(); + int offset = buf.arrayOffset() + buf.readerIndex(); + ByteBuf out = buf.alloc().heapBuffer(uncompressedSize); + + inflater.setInput(input, offset, len); + inflateAll(out); + + return out; + } else { + byte[] input = new byte[Math.min(len, MAX_CHUNK_SIZE)]; + + int readerIndex = buf.readerIndex(); + int writerIndex = buf.writerIndex(); + ByteBuf out = buf.alloc().heapBuffer(uncompressedSize); + + while (writerIndex - readerIndex > 0) { + int numBytes = Math.min(input.length, writerIndex - readerIndex); + + buf.getBytes(readerIndex, input, 0, numBytes); + inflater.setInput(input, 0, numBytes); + readerIndex += numBytes; + inflateAll(out); + } + + return out; + } + } catch (DataFormatException e) { + throw new DecoderException("zlib decompress failed", e); + } finally { + inflater.reset(); + } + } + + @Override + public void dispose() { + deflater.end(); + inflater.end(); + } + + private void deflateAll(ByteBuf out, int maxSize) { + while (true) { + deflate(out); + + if (!out.isWritable()) { + int size = out.readableBytes(); + + if (size >= maxSize) { + break; + } + + // Capacity = written size * 2 + if (size > (maxSize >> 1)) { + out.ensureWritable(maxSize - size); + } else { + out.ensureWritable(size); + } + } else if (deflater.needsInput()) { + break; + } + } + } + + private void inflateAll(ByteBuf out) throws DataFormatException { + while (out.isWritable() && !inflater.finished()) { + int wid = out.writerIndex(); + int numBytes = inflater.inflate(out.array(), out.arrayOffset() + wid, out.writableBytes()); + + out.writerIndex(wid + numBytes); + } + } + + private void deflate(ByteBuf out) { + int wid = out.writerIndex(); + int written = deflater.deflate(out.array(), out.arrayOffset() + wid, out.writableBytes()); + + while (written > 0) { + wid += written; + out.writerIndex(wid); + written = deflater.deflate(out.array(), out.arrayOffset() + wid, out.writableBytes()); + } + } +} diff --git a/src/main/java/io/asyncer/r2dbc/mysql/client/ZstdCompressor.java b/src/main/java/io/asyncer/r2dbc/mysql/client/ZstdCompressor.java new file mode 100644 index 000000000..b25c8ed0c --- /dev/null +++ b/src/main/java/io/asyncer/r2dbc/mysql/client/ZstdCompressor.java @@ -0,0 +1,60 @@ +/* + * Copyright 2024 asyncer.io projects + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.asyncer.r2dbc.mysql.client; + +import com.github.luben.zstd.Zstd; +import io.netty.buffer.ByteBuf; +import io.netty.buffer.Unpooled; + +import java.nio.ByteBuffer; + +import static io.asyncer.r2dbc.mysql.internal.util.AssertUtils.require; + +/** + * An implementation of {@link Compressor} that uses the Z-standard compression algorithm. + * + * @see Zstandard + */ +final class ZstdCompressor implements Compressor { + + private final int compressionLevel; + + ZstdCompressor(int compressionLevel) { + require( + compressionLevel >= Zstd.minCompressionLevel() && compressionLevel <= Zstd.maxCompressionLevel(), + "compressionLevel must be a value of Z standard compression levels"); + + this.compressionLevel = compressionLevel; + } + + @Override + public ByteBuf compress(ByteBuf buf) { + ByteBuffer buffer = Zstd.compress(buf.nioBuffer(), compressionLevel); + return Unpooled.wrappedBuffer(buffer); + } + + @Override + public ByteBuf decompress(ByteBuf buf, int uncompressedSize) { + ByteBuffer buffer = Zstd.decompress(buf.nioBuffer(), uncompressedSize); + return Unpooled.wrappedBuffer(buffer); + } + + @Override + public void dispose() { + // Do nothing + } +} diff --git a/src/main/java/io/asyncer/r2dbc/mysql/constant/CompressionAlgorithm.java b/src/main/java/io/asyncer/r2dbc/mysql/constant/CompressionAlgorithm.java new file mode 100644 index 000000000..05945bf74 --- /dev/null +++ b/src/main/java/io/asyncer/r2dbc/mysql/constant/CompressionAlgorithm.java @@ -0,0 +1,42 @@ +/* + * Copyright 2024 asyncer.io projects + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.asyncer.r2dbc.mysql.constant; + +/** + * The compression algorithm for client/server communication. + */ +public enum CompressionAlgorithm { + + /** + * Do not use compression protocol. + */ + UNCOMPRESSED, + + /** + * Use zlib compression algorithm for client/server communication. + *

+ * If zlib is not available, the connection will throw an exception when logging in. + */ + ZLIB, + + /** + * Use Z-Standard compression algorithm for client/server communication. + *

+ * If zstd is not available, the connection will throw an exception when logging in. + */ + ZSTD, +} diff --git a/src/main/java/io/asyncer/r2dbc/mysql/constant/Envelopes.java b/src/main/java/io/asyncer/r2dbc/mysql/constant/Packets.java similarity index 57% rename from src/main/java/io/asyncer/r2dbc/mysql/constant/Envelopes.java rename to src/main/java/io/asyncer/r2dbc/mysql/constant/Packets.java index 0bf5db0fa..b36a9d77f 100644 --- a/src/main/java/io/asyncer/r2dbc/mysql/constant/Envelopes.java +++ b/src/main/java/io/asyncer/r2dbc/mysql/constant/Packets.java @@ -17,11 +17,11 @@ package io.asyncer.r2dbc.mysql.constant; /** - * Constants for MySQL protocol envelopes (e.g. business layer packages). + * Constants for MySQL protocol packets. *

* WARNING: do NOT use it outer than {@literal r2dbc-mysql}. */ -public final class Envelopes { +public final class Packets { /** * The length of the byte size field, it is 3 bytes. @@ -29,19 +29,26 @@ public final class Envelopes { public static final int SIZE_FIELD_SIZE = 3; /** - * The byte size of header part. + * The max bytes size of payload, value is 16777215. (i.e. max value of int24, (2 ** 24) - 1) */ - public static final int PART_HEADER_SIZE = SIZE_FIELD_SIZE + 1; + public static final int MAX_PAYLOAD_SIZE = 0xFFFFFF; /** - * The max bytes size of each envelope, value is 16777215. (i.e. max value of int24, (2 ** 24) - 1) + * The header size of a compression frame, which includes entire frame size (unsigned int24), compression + * sequence id (unsigned int8) and compressed size (unsigned int24). */ - public static final int MAX_ENVELOPE_SIZE = (1 << (SIZE_FIELD_SIZE << 3)) - 1; + public static final int COMPRESS_HEADER_SIZE = SIZE_FIELD_SIZE + 1 + SIZE_FIELD_SIZE; + + /** + * The header size of a normal frame, which includes entire frame size (unsigned int24) and normal + * sequence id (unsigned int8). + */ + public static final int NORMAL_HEADER_SIZE = SIZE_FIELD_SIZE + 1; /** * The terminal of C-style string or C-style binary data. */ public static final byte TERMINAL = 0; - private Envelopes() { } + private Packets() { } } diff --git a/src/main/java/io/asyncer/r2dbc/mysql/internal/util/FluxEnvelope.java b/src/main/java/io/asyncer/r2dbc/mysql/internal/util/FluxEnvelope.java index 47aa878ed..895b64f27 100644 --- a/src/main/java/io/asyncer/r2dbc/mysql/internal/util/FluxEnvelope.java +++ b/src/main/java/io/asyncer/r2dbc/mysql/internal/util/FluxEnvelope.java @@ -16,7 +16,7 @@ package io.asyncer.r2dbc.mysql.internal.util; -import io.asyncer.r2dbc.mysql.constant.Envelopes; +import io.asyncer.r2dbc.mysql.constant.Packets; import io.netty.buffer.ByteBuf; import io.netty.buffer.ByteBufAllocator; import org.jetbrains.annotations.Nullable; @@ -28,6 +28,8 @@ import reactor.core.publisher.Operators; import reactor.util.context.Context; +import java.util.concurrent.atomic.AtomicInteger; + /** * An implementation of {@link Flux}{@code <}{@link ByteBuf}{@code >} that considers cumulate buffers as * envelopes of the MySQL socket protocol. @@ -38,26 +40,26 @@ final class FluxEnvelope extends FluxOperator { private final int size; - private final int start; + private final AtomicInteger sequenceId; private final boolean cumulate; - FluxEnvelope(Flux source, ByteBufAllocator alloc, int size, int start, + FluxEnvelope(Flux source, ByteBufAllocator alloc, int size, AtomicInteger sequenceId, boolean cumulate) { super(source); this.alloc = alloc; this.size = size; - this.start = start; + this.sequenceId = sequenceId; this.cumulate = cumulate; } @Override public void subscribe(CoreSubscriber actual) { if (cumulate) { - this.source.subscribe(new CumulateEnvelopeSubscriber(actual, alloc, size, start)); + this.source.subscribe(new CumulateEnvelopeSubscriber(actual, alloc, size, sequenceId)); } else { - this.source.subscribe(new DirectEnvelopeSubscriber(actual, alloc, start)); + this.source.subscribe(new DirectEnvelopeSubscriber(actual, alloc, sequenceId)); } } } @@ -68,16 +70,17 @@ final class DirectEnvelopeSubscriber implements CoreSubscriber, Scannab private final ByteBufAllocator alloc; + private final AtomicInteger sequenceId; + private boolean done; private Subscription s; - private int envelopeId; - - DirectEnvelopeSubscriber(CoreSubscriber actual, ByteBufAllocator alloc, int start) { + DirectEnvelopeSubscriber(CoreSubscriber actual, ByteBufAllocator alloc, + AtomicInteger sequenceId) { this.actual = actual; this.alloc = alloc; - this.envelopeId = start; + this.sequenceId = sequenceId; } @Override @@ -97,9 +100,9 @@ public void onNext(ByteBuf buf) { } try { - ByteBuf header = this.alloc.buffer(Envelopes.PART_HEADER_SIZE) + ByteBuf header = this.alloc.ioBuffer(Packets.NORMAL_HEADER_SIZE) .writeMediumLE(buf.readableBytes()) - .writeByte(this.envelopeId++); + .writeByte(this.sequenceId.getAndIncrement()); this.actual.onNext(header); this.actual.onNext(buf); @@ -172,20 +175,20 @@ final class CumulateEnvelopeSubscriber implements CoreSubscriber, Scann private final int size; + private final AtomicInteger sequenceId; + private boolean done; private Subscription s; private ByteBuf cumulated; - private int envelopeId; - CumulateEnvelopeSubscriber(CoreSubscriber actual, ByteBufAllocator alloc, int size, - int start) { + AtomicInteger sequenceId) { this.actual = actual; this.alloc = alloc; this.size = size; - this.envelopeId = start; + this.sequenceId = sequenceId; } @Override @@ -217,9 +220,9 @@ public void onNext(ByteBuf buf) { while (cumulated.readableBytes() >= this.size) { // It will make the cumulated be shared (e.g. refCnt() > 1), that means // the reallocation of the cumulated may not be safe, see cumulate(...). - this.actual.onNext(this.alloc.buffer(Envelopes.PART_HEADER_SIZE) + this.actual.onNext(this.alloc.ioBuffer(Packets.NORMAL_HEADER_SIZE) .writeMediumLE(this.size) - .writeByte(this.envelopeId++)); + .writeByte(this.sequenceId.getAndIncrement())); this.actual.onNext(cumulated.readRetainedSlice(this.size)); } @@ -275,8 +278,8 @@ public void onComplete() { ByteBuf header = null; try { - header = this.alloc.buffer(Envelopes.PART_HEADER_SIZE); - header.writeMediumLE(size).writeByte(this.envelopeId++); + header = this.alloc.ioBuffer(Packets.NORMAL_HEADER_SIZE); + header.writeMediumLE(size).writeByte(this.sequenceId.getAndIncrement()); } catch (Throwable e) { if (cumulated != null) { cumulated.release(); @@ -356,8 +359,8 @@ private static ByteBuf cumulate(ByteBufAllocator alloc, @Nullable ByteBuf cumula int oldBytes = cumulated.readableBytes(); int bufBytes = buf.readableBytes(); int newBytes = oldBytes + bufBytes; - ByteBuf result = releasing = alloc.buffer(alloc.calculateNewCapacity(newBytes, - Integer.MAX_VALUE)); + int newCapacity = alloc.calculateNewCapacity(newBytes, Integer.MAX_VALUE); + ByteBuf result = releasing = alloc.ioBuffer(newCapacity); // Avoid to calling writeBytes(...) with redundancy check and stack depth comparison. result.setBytes(0, cumulated, cumulated.readerIndex(), oldBytes) diff --git a/src/main/java/io/asyncer/r2dbc/mysql/internal/util/OperatorUtils.java b/src/main/java/io/asyncer/r2dbc/mysql/internal/util/OperatorUtils.java index 299f82790..a0fdce06b 100644 --- a/src/main/java/io/asyncer/r2dbc/mysql/internal/util/OperatorUtils.java +++ b/src/main/java/io/asyncer/r2dbc/mysql/internal/util/OperatorUtils.java @@ -16,12 +16,14 @@ package io.asyncer.r2dbc.mysql.internal.util; -import io.asyncer.r2dbc.mysql.constant.Envelopes; +import io.asyncer.r2dbc.mysql.constant.Packets; import io.netty.buffer.ByteBuf; import io.netty.buffer.ByteBufAllocator; import reactor.core.Fuseable; import reactor.core.publisher.Flux; +import java.util.concurrent.atomic.AtomicInteger; + import static io.asyncer.r2dbc.mysql.internal.util.AssertUtils.requireNonNull; /** @@ -56,12 +58,12 @@ public static Flux discardOnCancel(Flux source) { } public static Flux envelope(Flux source, ByteBufAllocator allocator, - int envelopeIdStart, boolean cumulate) { + AtomicInteger sequenceId, boolean cumulate) { requireNonNull(source, "source must not be null"); requireNonNull(allocator, "allocator must not be null"); + requireNonNull(sequenceId, "sequenceId must not be null"); - return new FluxEnvelope(source, allocator, Envelopes.MAX_ENVELOPE_SIZE, - envelopeIdStart & 0xFF, cumulate); + return new FluxEnvelope(source, allocator, Packets.MAX_PAYLOAD_SIZE, sequenceId, cumulate); } private OperatorUtils() { } diff --git a/src/main/java/io/asyncer/r2dbc/mysql/internal/util/ReadCompletionHandler.java b/src/main/java/io/asyncer/r2dbc/mysql/internal/util/ReadCompletionHandler.java index 1c72a49df..0d8e89294 100644 --- a/src/main/java/io/asyncer/r2dbc/mysql/internal/util/ReadCompletionHandler.java +++ b/src/main/java/io/asyncer/r2dbc/mysql/internal/util/ReadCompletionHandler.java @@ -80,7 +80,7 @@ private void tryRead() { } private void read() { - ByteBuf buf = this.allocator.buffer(this.bufferSize); + ByteBuf buf = this.allocator.ioBuffer(this.bufferSize); ByteBuffer byteBuffer = buf.nioBuffer(buf.writerIndex(), buf.writableBytes()); this.channel.read(byteBuffer, this.position.get(), buf, this); diff --git a/src/main/java/io/asyncer/r2dbc/mysql/message/client/AuthResponse.java b/src/main/java/io/asyncer/r2dbc/mysql/message/client/AuthResponse.java index 4c6d2b1de..3aa96cd70 100644 --- a/src/main/java/io/asyncer/r2dbc/mysql/message/client/AuthResponse.java +++ b/src/main/java/io/asyncer/r2dbc/mysql/message/client/AuthResponse.java @@ -28,20 +28,12 @@ */ public final class AuthResponse extends SizedClientMessage implements SubsequenceClientMessage { - private final int envelopeId; - private final byte[] authentication; - public AuthResponse(int envelopeId, byte[] authentication) { - this.envelopeId = envelopeId; + public AuthResponse(byte[] authentication) { this.authentication = requireNonNull(authentication, "authentication must not be null"); } - @Override - public int getEnvelopeId() { - return envelopeId; - } - @Override protected int size() { return authentication.length; @@ -58,17 +50,17 @@ public boolean equals(Object o) { AuthResponse that = (AuthResponse) o; - return envelopeId == that.envelopeId && Arrays.equals(authentication, that.authentication); + return Arrays.equals(authentication, that.authentication); } @Override public int hashCode() { - return 31 * envelopeId + Arrays.hashCode(authentication); + return Arrays.hashCode(authentication); } @Override public String toString() { - return "AuthResponse{envelopeId=" + envelopeId + ", authentication=REDACTED}"; + return "AuthResponse{authentication=REDACTED}"; } @Override diff --git a/src/main/java/io/asyncer/r2dbc/mysql/message/client/ClientMessage.java b/src/main/java/io/asyncer/r2dbc/mysql/message/client/ClientMessage.java index 3080da66f..047884a17 100644 --- a/src/main/java/io/asyncer/r2dbc/mysql/message/client/ClientMessage.java +++ b/src/main/java/io/asyncer/r2dbc/mysql/message/client/ClientMessage.java @@ -27,6 +27,24 @@ */ public interface ClientMessage { + /** + * Returns whether the sequence should be reset before encoding this message. + * + * @return {@code true} if the sequence should be reset. + */ + default boolean isSequenceReset() { + return true; + } + + /** + * Returns whether the encoded buffers can be cumulated to maximize the payload size. + * + * @return {@code true} if can be cumulated. + */ + default boolean isCumulative() { + return true; + } + /** * Encode a message into {@link ByteBuf}s. * diff --git a/src/main/java/io/asyncer/r2dbc/mysql/message/client/HandshakeResponse.java b/src/main/java/io/asyncer/r2dbc/mysql/message/client/HandshakeResponse.java index ca05f9831..c7b135a34 100644 --- a/src/main/java/io/asyncer/r2dbc/mysql/message/client/HandshakeResponse.java +++ b/src/main/java/io/asyncer/r2dbc/mysql/message/client/HandshakeResponse.java @@ -22,7 +22,7 @@ import java.nio.charset.Charset; import java.util.Map; -import static io.asyncer.r2dbc.mysql.constant.Envelopes.TERMINAL; +import static io.asyncer.r2dbc.mysql.constant.Packets.TERMINAL; /** * An abstraction of {@link SubsequenceClientMessage} considers handshake response. @@ -33,24 +33,24 @@ public interface HandshakeResponse extends SubsequenceClientMessage { * Construct an instance of {@link HandshakeResponse}, it is implemented by the protocol version that is * given by {@link Capability}. * - * @param envelopeId the beginning envelope ID of this message. - * @param capability the current {@link Capability}. - * @param collationId the {@code CharCollation} ID, or 0 if server does not return a collation ID. - * @param user the username for login. - * @param authentication the password authentication for login. - * @param authType the authentication plugin type. - * @param database the connecting database, may be empty. - * @param attributes the connecting attributes. + * @param capability the current {@link Capability}. + * @param collationId the {@code CharCollation} ID, or 0 if server does not return. + * @param user the username for login. + * @param authentication the password authentication for login. + * @param authType the authentication plugin type. + * @param database the connecting database, may be empty. + * @param attributes the connecting attributes. + * @param zstdCompressionLevel the Zstd compression level. * @return the instance implemented by the specified protocol version. */ - static HandshakeResponse from(int envelopeId, Capability capability, int collationId, String user, - byte[] authentication, String authType, String database, Map attributes) { + static HandshakeResponse from(Capability capability, int collationId, String user, byte[] authentication, + String authType, String database, Map attributes, int zstdCompressionLevel) { if (capability.isProtocol41()) { - return new HandshakeResponse41(envelopeId, capability, collationId, user, authentication, - authType, database, attributes); + return new HandshakeResponse41(capability, collationId, user, authentication, authType, database, + attributes, zstdCompressionLevel); } - return new HandshakeResponse320(envelopeId, capability, user, authentication, database); + return new HandshakeResponse320(capability, user, authentication, database); } /** diff --git a/src/main/java/io/asyncer/r2dbc/mysql/message/client/HandshakeResponse320.java b/src/main/java/io/asyncer/r2dbc/mysql/message/client/HandshakeResponse320.java index c93705ea6..e3547faeb 100644 --- a/src/main/java/io/asyncer/r2dbc/mysql/message/client/HandshakeResponse320.java +++ b/src/main/java/io/asyncer/r2dbc/mysql/message/client/HandshakeResponse320.java @@ -23,7 +23,7 @@ import java.nio.charset.Charset; import java.util.Arrays; -import static io.asyncer.r2dbc.mysql.constant.Envelopes.TERMINAL; +import static io.asyncer.r2dbc.mysql.constant.Packets.TERMINAL; import static io.asyncer.r2dbc.mysql.internal.util.AssertUtils.requireNonNull; /** @@ -41,19 +41,14 @@ final class HandshakeResponse320 extends ScalarClientMessage implements Handshak private final String database; - HandshakeResponse320(int envelopeId, Capability capability, String user, byte[] authentication, + HandshakeResponse320(Capability capability, String user, byte[] authentication, String database) { - this.header = new SslRequest320(envelopeId, capability); + this.header = new SslRequest320(capability); this.user = requireNonNull(user, "user must not be null"); this.authentication = requireNonNull(authentication, "authentication must not be null"); this.database = requireNonNull(database, "database must not be null"); } - @Override - public int getEnvelopeId() { - return header.getEnvelopeId(); - } - @Override public boolean equals(Object o) { if (this == o) { @@ -79,8 +74,7 @@ public int hashCode() { @Override public String toString() { - return "HandshakeResponse320{envelopeId=" + header.getEnvelopeId() + - ", capability=" + header.getCapability() + ", user='" + user + + return "HandshakeResponse320{capability=" + header.getCapability() + ", user='" + user + "', authentication=REDACTED, database='" + database + "'}"; } diff --git a/src/main/java/io/asyncer/r2dbc/mysql/message/client/HandshakeResponse41.java b/src/main/java/io/asyncer/r2dbc/mysql/message/client/HandshakeResponse41.java index f727364f5..810126bba 100644 --- a/src/main/java/io/asyncer/r2dbc/mysql/message/client/HandshakeResponse41.java +++ b/src/main/java/io/asyncer/r2dbc/mysql/message/client/HandshakeResponse41.java @@ -49,21 +49,17 @@ final class HandshakeResponse41 extends ScalarClientMessage implements Handshake private final Map attributes; - // private final byte zStdCompressionLevel; // When Z-Standard compression supporting + private final int zstdCompressionLevel; - HandshakeResponse41(int envelopeId, Capability capability, int collationId, String user, - byte[] authentication, String authType, String database, Map attributes) { - this.header = new SslRequest41(envelopeId, capability, collationId); + HandshakeResponse41(Capability capability, int collationId, String user, byte[] authentication, + String authType, String database, Map attributes, int zstdCompressionLevel) { + this.header = new SslRequest41(capability, collationId); this.user = requireNonNull(user, "user must not be null"); this.authentication = requireNonNull(authentication, "authentication must not be null"); this.database = requireNonNull(database, "database must not be null"); this.authType = requireNonNull(authType, "authType must not be null"); this.attributes = requireNonNull(attributes, "attributes must not be null"); - } - - @Override - public int getEnvelopeId() { - return header.getEnvelopeId(); + this.zstdCompressionLevel = zstdCompressionLevel; } @Override @@ -71,15 +67,16 @@ public boolean equals(Object o) { if (this == o) { return true; } - if (o == null || getClass() != o.getClass()) { + if (!(o instanceof HandshakeResponse41)) { return false; } HandshakeResponse41 that = (HandshakeResponse41) o; - return header.equals(that.header) && user.equals(that.user) && - Arrays.equals(authentication, that.authentication) && authType.equals(that.authType) && - database.equals(that.database) && attributes.equals(that.attributes); + return zstdCompressionLevel == that.zstdCompressionLevel && header.equals(that.header) && + user.equals(that.user) && Arrays.equals(authentication, that.authentication) && + authType.equals(that.authType) && database.equals(that.database) && + attributes.equals(that.attributes); } @Override @@ -89,16 +86,18 @@ public int hashCode() { result = 31 * result + Arrays.hashCode(authentication); result = 31 * result + authType.hashCode(); result = 31 * result + database.hashCode(); - return 31 * result + attributes.hashCode(); + result = 31 * result + attributes.hashCode(); + return 31 * result + zstdCompressionLevel; } @Override public String toString() { - return "HandshakeResponse41{envelopeId=" + header.getEnvelopeId() + - ", capability=" + header.getCapability() + + return "HandshakeResponse41{capability=" + header.getCapability() + ", collationId=" + header.getCollationId() + ", user='" + user + "', authentication=REDACTED, authType='" + authType + - "', database='" + database + "', attributes=" + attributes + '}'; + "', database='" + database + "', attributes=" + attributes + + ", zstdCompressionLevel=" + zstdCompressionLevel + + '}'; } @Override @@ -131,6 +130,10 @@ protected void writeTo(ByteBuf buf, ConnectionContext context) { if (capability.isConnectionAttributesAllowed()) { writeAttrs(buf, charset); } + + if (capability.isZstdCompression()) { + buf.writeByte(zstdCompressionLevel); + } } private void writeAttrs(ByteBuf buf, Charset charset) { diff --git a/src/main/java/io/asyncer/r2dbc/mysql/message/client/LocalInfileResponse.java b/src/main/java/io/asyncer/r2dbc/mysql/message/client/LocalInfileResponse.java index c9ca2419f..63a360f49 100644 --- a/src/main/java/io/asyncer/r2dbc/mysql/message/client/LocalInfileResponse.java +++ b/src/main/java/io/asyncer/r2dbc/mysql/message/client/LocalInfileResponse.java @@ -38,20 +38,22 @@ */ public final class LocalInfileResponse implements SubsequenceClientMessage { - private final int envelopeId; - private final String path; private final SynchronousSink errorSink; - public LocalInfileResponse(int envelopeId, String path, SynchronousSink errorSink) { + public LocalInfileResponse(String path, SynchronousSink errorSink) { requireNonNull(path, "path must not be null"); - this.envelopeId = envelopeId; this.path = path; this.errorSink = errorSink; } + @Override + public boolean isCumulative() { + return false; + } + @Override public Flux encode(ByteBufAllocator allocator, ConnectionContext context) { return Flux.defer(() -> { @@ -93,11 +95,6 @@ public Flux encode(ByteBufAllocator allocator, ConnectionContext contex }); } - @Override - public int getEnvelopeId() { - return envelopeId; - } - @Override public boolean equals(Object o) { if (this == o) { @@ -109,17 +106,16 @@ public boolean equals(Object o) { LocalInfileResponse that = (LocalInfileResponse) o; - return envelopeId == that.envelopeId && path.equals(that.path); + return path.equals(that.path); } @Override public int hashCode() { - return 31 * envelopeId + path.hashCode(); + return path.hashCode(); } @Override public String toString() { - return "LocalInfileResponse{envelopeId=" + envelopeId + - ", path='" + path + "'}"; + return "LocalInfileResponse{path='" + path + "'}"; } } diff --git a/src/main/java/io/asyncer/r2dbc/mysql/message/client/SslRequest.java b/src/main/java/io/asyncer/r2dbc/mysql/message/client/SslRequest.java index 16420a412..81cf4eebf 100644 --- a/src/main/java/io/asyncer/r2dbc/mysql/message/client/SslRequest.java +++ b/src/main/java/io/asyncer/r2dbc/mysql/message/client/SslRequest.java @@ -36,18 +36,17 @@ public interface SslRequest extends SubsequenceClientMessage { * Construct an instance of {@link SslRequest}, it is implemented by the protocol version that is given by * {@link Capability}. * - * @param envelopeId the beginning envelope ID of this message. * @param capability the current {@link Capability}. * @param collationId the {@code CharCollation} ID, or 0 if server does not return a collation ID. * @return the instance implemented by the specified protocol version. */ - static SslRequest from(int envelopeId, Capability capability, int collationId) { + static SslRequest from(Capability capability, int collationId) { require(capability.isSslEnabled(), "capability must be SSL enabled"); if (capability.isProtocol41()) { - return new SslRequest41(envelopeId, capability, collationId); + return new SslRequest41(capability, collationId); } - return new SslRequest320(envelopeId, capability); + return new SslRequest320(capability); } } diff --git a/src/main/java/io/asyncer/r2dbc/mysql/message/client/SslRequest320.java b/src/main/java/io/asyncer/r2dbc/mysql/message/client/SslRequest320.java index 096b2c46e..216476189 100644 --- a/src/main/java/io/asyncer/r2dbc/mysql/message/client/SslRequest320.java +++ b/src/main/java/io/asyncer/r2dbc/mysql/message/client/SslRequest320.java @@ -17,7 +17,7 @@ package io.asyncer.r2dbc.mysql.message.client; import io.asyncer.r2dbc.mysql.Capability; -import io.asyncer.r2dbc.mysql.constant.Envelopes; +import io.asyncer.r2dbc.mysql.constant.Packets; import io.netty.buffer.ByteBuf; import static io.asyncer.r2dbc.mysql.internal.util.AssertUtils.require; @@ -27,24 +27,16 @@ */ final class SslRequest320 extends SizedClientMessage implements SslRequest { - private static final int SIZE = Short.BYTES + Envelopes.SIZE_FIELD_SIZE; - - private final int envelopeId; + private static final int SIZE = Short.BYTES + Packets.SIZE_FIELD_SIZE; private final Capability capability; - SslRequest320(int envelopeId, Capability capability) { + SslRequest320(Capability capability) { require(!capability.isProtocol41(), "protocol 4.1 capability should never be set"); - this.envelopeId = envelopeId; this.capability = capability; } - @Override - public int getEnvelopeId() { - return envelopeId; - } - @Override public Capability getCapability() { return capability; @@ -61,18 +53,17 @@ public boolean equals(Object o) { SslRequest320 that = (SslRequest320) o; - return envelopeId == that.envelopeId && capability.equals(that.capability); + return capability.equals(that.capability); } @Override public int hashCode() { - return 31 * envelopeId + capability.hashCode(); + return capability.hashCode(); } @Override public String toString() { - return "SslRequest320{envelopeId=" + envelopeId + - ", capability=" + capability + '}'; + return "SslRequest320{capability=" + capability + '}'; } @Override @@ -84,6 +75,6 @@ protected int size() { protected void writeTo(ByteBuf buf) { // Protocol 3.20 only allows low 16-bits capabilities. buf.writeShortLE(capability.getBaseBitmap() & 0xFFFF) - .writeMediumLE(Envelopes.MAX_ENVELOPE_SIZE); + .writeMediumLE(Packets.MAX_PAYLOAD_SIZE); } } diff --git a/src/main/java/io/asyncer/r2dbc/mysql/message/client/SslRequest41.java b/src/main/java/io/asyncer/r2dbc/mysql/message/client/SslRequest41.java index bca7e099c..2e270ae29 100644 --- a/src/main/java/io/asyncer/r2dbc/mysql/message/client/SslRequest41.java +++ b/src/main/java/io/asyncer/r2dbc/mysql/message/client/SslRequest41.java @@ -17,7 +17,7 @@ package io.asyncer.r2dbc.mysql.message.client; import io.asyncer.r2dbc.mysql.Capability; -import io.asyncer.r2dbc.mysql.constant.Envelopes; +import io.asyncer.r2dbc.mysql.constant.Packets; import io.netty.buffer.ByteBuf; import static io.asyncer.r2dbc.mysql.internal.util.AssertUtils.require; @@ -34,25 +34,17 @@ final class SslRequest41 extends SizedClientMessage implements SslRequest { private static final int BUF_SIZE = Integer.BYTES + Integer.BYTES + Byte.BYTES + RESERVED_SIZE + MARIA_DB_CAPABILITY_SIZE; - private final int envelopeId; - private final Capability capability; private final int collationId; - SslRequest41(int envelopeId, Capability capability, int collationId) { + SslRequest41(Capability capability, int collationId) { require(collationId > 0, "collationId must be a positive integer"); - this.envelopeId = envelopeId; this.capability = capability; this.collationId = collationId; } - @Override - public int getEnvelopeId() { - return envelopeId; - } - @Override public boolean equals(Object o) { if (this == o) { @@ -64,22 +56,19 @@ public boolean equals(Object o) { SslRequest41 that = (SslRequest41) o; - return envelopeId == that.envelopeId && - collationId == that.collationId && + return collationId == that.collationId && capability.equals(that.capability); } @Override public int hashCode() { - int result = 31 * envelopeId + capability.hashCode(); + int result = capability.hashCode(); return 31 * result + collationId; } @Override public String toString() { - return "SslRequest41{envelopeId=" + envelopeId + - ", capability=" + capability + - ", collationId=" + collationId + '}'; + return "SslRequest41{capability=" + capability + ", collationId=" + collationId + '}'; } @Override @@ -95,7 +84,7 @@ protected int size() { @Override protected void writeTo(ByteBuf buf) { buf.writeIntLE(capability.getBaseBitmap()) - .writeIntLE(Envelopes.MAX_ENVELOPE_SIZE) + .writeIntLE(Packets.MAX_PAYLOAD_SIZE) .writeByte(collationId & 0xFF); // only low 8-bits if (capability.isMariaDb()) { diff --git a/src/main/java/io/asyncer/r2dbc/mysql/message/client/SubsequenceClientMessage.java b/src/main/java/io/asyncer/r2dbc/mysql/message/client/SubsequenceClientMessage.java index 091e7a57b..fb2364b92 100644 --- a/src/main/java/io/asyncer/r2dbc/mysql/message/client/SubsequenceClientMessage.java +++ b/src/main/java/io/asyncer/r2dbc/mysql/message/client/SubsequenceClientMessage.java @@ -24,10 +24,8 @@ */ public interface SubsequenceClientMessage extends ClientMessage { - /** - * Gets the current envelope ID used to serialize subsequent request messages. - * - * @return the current envelope ID. - */ - int getEnvelopeId(); + @Override + default boolean isSequenceReset() { + return false; + } } diff --git a/src/main/java/io/asyncer/r2dbc/mysql/message/server/AuthMoreDataMessage.java b/src/main/java/io/asyncer/r2dbc/mysql/message/server/AuthMoreDataMessage.java index 543cbb8e3..ba987287e 100644 --- a/src/main/java/io/asyncer/r2dbc/mysql/message/server/AuthMoreDataMessage.java +++ b/src/main/java/io/asyncer/r2dbc/mysql/message/server/AuthMoreDataMessage.java @@ -25,27 +25,20 @@ public final class AuthMoreDataMessage implements ServerMessage { private static final byte AUTH_SUCCEED = 3; - private final int envelopeId; - private final boolean failed; - private AuthMoreDataMessage(int envelopeId, boolean failed) { - this.envelopeId = envelopeId; + private AuthMoreDataMessage(boolean failed) { this.failed = failed; } - public int getEnvelopeId() { - return envelopeId; - } - public boolean isFailed() { return failed; } - static AuthMoreDataMessage decode(int envelopeId, ByteBuf buf) { + static AuthMoreDataMessage decode(ByteBuf buf) { buf.skipBytes(1); // auth more data message header, 0x01 - return new AuthMoreDataMessage(envelopeId, buf.readByte() != AUTH_SUCCEED); + return new AuthMoreDataMessage(buf.readByte() != AUTH_SUCCEED); } @Override @@ -59,16 +52,16 @@ public boolean equals(Object o) { AuthMoreDataMessage that = (AuthMoreDataMessage) o; - return envelopeId == that.envelopeId && failed == that.failed; + return failed == that.failed; } @Override public int hashCode() { - return (envelopeId << 1) | (failed ? 1 : 0); + return (failed ? 1 : 0); } @Override public String toString() { - return "AuthMoreDataMessage{envelopeId=" + envelopeId + ", failed=" + failed + '}'; + return "AuthMoreDataMessage{failed=" + failed + '}'; } } diff --git a/src/main/java/io/asyncer/r2dbc/mysql/message/server/ChangeAuthMessage.java b/src/main/java/io/asyncer/r2dbc/mysql/message/server/ChangeAuthMessage.java index dc3a1a142..ba1d38479 100644 --- a/src/main/java/io/asyncer/r2dbc/mysql/message/server/ChangeAuthMessage.java +++ b/src/main/java/io/asyncer/r2dbc/mysql/message/server/ChangeAuthMessage.java @@ -21,7 +21,7 @@ import java.util.Arrays; -import static io.asyncer.r2dbc.mysql.constant.Envelopes.TERMINAL; +import static io.asyncer.r2dbc.mysql.constant.Packets.TERMINAL; import static io.asyncer.r2dbc.mysql.internal.util.AssertUtils.requireNonNull; /** @@ -29,22 +29,15 @@ */ public final class ChangeAuthMessage implements ServerMessage { - private final int envelopeId; - private final String authType; private final byte[] salt; - private ChangeAuthMessage(int envelopeId, String authType, byte[] salt) { - this.envelopeId = envelopeId; + private ChangeAuthMessage(String authType, byte[] salt) { this.authType = requireNonNull(authType, "authType must not be null"); this.salt = requireNonNull(salt, "salt must not be null"); } - public int getEnvelopeId() { - return envelopeId; - } - public String getAuthType() { return authType; } @@ -58,29 +51,26 @@ public boolean equals(Object o) { if (this == o) { return true; } - if (o == null || getClass() != o.getClass()) { + if (!(o instanceof ChangeAuthMessage)) { return false; } ChangeAuthMessage that = (ChangeAuthMessage) o; - return envelopeId == that.envelopeId && authType.equals(that.authType) && - Arrays.equals(salt, that.salt); + return authType.equals(that.authType) && Arrays.equals(salt, that.salt); } @Override public int hashCode() { - int result = envelopeId; - result = 31 * result + authType.hashCode(); - return 31 * result + Arrays.hashCode(salt); + return 31 * authType.hashCode() + Arrays.hashCode(salt); } @Override public String toString() { - return "ChangeAuthMessage{envelopeId=" + envelopeId + ", authType='" + authType + "', salt=REDACTED}"; + return "ChangeAuthMessage{authType='" + authType + "', salt=REDACTED}"; } - static ChangeAuthMessage decode(int envelopeId, ByteBuf buf) { + static ChangeAuthMessage decode(ByteBuf buf) { buf.skipBytes(1); // skip generic header 0xFE of change authentication messages String authType = HandshakeHeader.readCStringAscii(buf); @@ -90,6 +80,6 @@ static ChangeAuthMessage decode(int envelopeId, ByteBuf buf) { ByteBufUtil.getBytes(buf); // The terminal character has been removed from salt. - return new ChangeAuthMessage(envelopeId, authType, salt); + return new ChangeAuthMessage(authType, salt); } } diff --git a/src/main/java/io/asyncer/r2dbc/mysql/message/server/HandshakeHeader.java b/src/main/java/io/asyncer/r2dbc/mysql/message/server/HandshakeHeader.java index 6ad92fea9..6785a930f 100644 --- a/src/main/java/io/asyncer/r2dbc/mysql/message/server/HandshakeHeader.java +++ b/src/main/java/io/asyncer/r2dbc/mysql/message/server/HandshakeHeader.java @@ -21,7 +21,7 @@ import java.nio.charset.StandardCharsets; -import static io.asyncer.r2dbc.mysql.constant.Envelopes.TERMINAL; +import static io.asyncer.r2dbc.mysql.constant.Packets.TERMINAL; import static io.asyncer.r2dbc.mysql.internal.util.AssertUtils.requireNonNull; /** diff --git a/src/main/java/io/asyncer/r2dbc/mysql/message/server/HandshakeRequest.java b/src/main/java/io/asyncer/r2dbc/mysql/message/server/HandshakeRequest.java index 81cb22eb1..eed96afca 100644 --- a/src/main/java/io/asyncer/r2dbc/mysql/message/server/HandshakeRequest.java +++ b/src/main/java/io/asyncer/r2dbc/mysql/message/server/HandshakeRequest.java @@ -32,13 +32,6 @@ public interface HandshakeRequest extends ServerMessage { */ HandshakeHeader getHeader(); - /** - * Get the envelope identifier of this message packet. - * - * @return envelope identifier. - */ - int getEnvelopeId(); - /** * Get the server-side capability. * @@ -63,19 +56,18 @@ public interface HandshakeRequest extends ServerMessage { /** * Decode a {@link HandshakeRequest} from a envelope {@link ByteBuf}. * - * @param envelopeId envelope identifier. - * @param buf the {@link ByteBuf}. + * @param buf the {@link ByteBuf}. * @return decoded {@link HandshakeRequest}. */ - static HandshakeRequest decode(int envelopeId, ByteBuf buf) { + static HandshakeRequest decode(ByteBuf buf) { HandshakeHeader header = HandshakeHeader.decode(buf); int version = header.getProtocolVersion(); switch (version) { case 10: - return HandshakeV10Request.decode(envelopeId, buf, header); + return HandshakeV10Request.decode(buf, header); case 9: - return HandshakeV9Request.decode(envelopeId, buf, header); + return HandshakeV9Request.decode(buf, header); } throw new R2dbcPermissionDeniedException("Does not support handshake protocol version " + version); diff --git a/src/main/java/io/asyncer/r2dbc/mysql/message/server/HandshakeV10Request.java b/src/main/java/io/asyncer/r2dbc/mysql/message/server/HandshakeV10Request.java index 47cd0bdc4..5f5a1de67 100644 --- a/src/main/java/io/asyncer/r2dbc/mysql/message/server/HandshakeV10Request.java +++ b/src/main/java/io/asyncer/r2dbc/mysql/message/server/HandshakeV10Request.java @@ -24,7 +24,7 @@ import java.nio.charset.StandardCharsets; import java.util.Arrays; -import static io.asyncer.r2dbc.mysql.constant.Envelopes.TERMINAL; +import static io.asyncer.r2dbc.mysql.constant.Packets.TERMINAL; import static io.asyncer.r2dbc.mysql.internal.util.AssertUtils.requireNonNull; /** @@ -42,8 +42,6 @@ final class HandshakeV10Request implements HandshakeRequest, ServerStatusMessage private final HandshakeHeader header; - private final int envelopeId; - private final byte[] salt; private final Capability serverCapability; @@ -52,10 +50,9 @@ final class HandshakeV10Request implements HandshakeRequest, ServerStatusMessage private final String authType; - private HandshakeV10Request(HandshakeHeader header, int envelopeId, byte[] salt, + private HandshakeV10Request(HandshakeHeader header, byte[] salt, Capability serverCapability, short serverStatuses, String authType) { this.header = requireNonNull(header, "header must not be null"); - this.envelopeId = envelopeId; this.salt = requireNonNull(salt, "salt must not be null"); this.serverCapability = requireNonNull(serverCapability, "serverCapability must not be null"); this.serverStatuses = serverStatuses; @@ -67,11 +64,6 @@ public HandshakeHeader getHeader() { return header; } - @Override - public int getEnvelopeId() { - return envelopeId; - } - @Override public byte[] getSalt() { return salt; @@ -97,35 +89,35 @@ public boolean equals(Object o) { if (this == o) { return true; } - if (o == null || getClass() != o.getClass()) { + if (!(o instanceof HandshakeV10Request)) { return false; } HandshakeV10Request that = (HandshakeV10Request) o; - return envelopeId == that.envelopeId && serverStatuses == that.serverStatuses && - header.equals(that.header) && Arrays.equals(salt, that.salt) && - serverCapability.equals(that.serverCapability) && authType.equals(that.authType); + return serverStatuses == that.serverStatuses && header.equals(that.header) && + Arrays.equals(salt, that.salt) && serverCapability.equals(that.serverCapability) && + authType.equals(that.authType); } @Override public int hashCode() { - int hash = 31 * header.hashCode() + envelopeId; - hash = 31 * hash + Arrays.hashCode(salt); - hash = 31 * hash + serverCapability.hashCode(); - hash = 31 * hash + serverStatuses; - return 31 * hash + authType.hashCode(); + int result = header.hashCode(); + result = 31 * result + Arrays.hashCode(salt); + result = 31 * result + serverCapability.hashCode(); + result = 31 * result + (int) serverStatuses; + return 31 * result + authType.hashCode(); } @Override public String toString() { - return "HandshakeV10Request{header=" + header + ", envelopeId=" + envelopeId + + return "HandshakeV10Request{header=" + header + ", salt=REDACTED, serverCapability=" + serverCapability + ", serverStatuses=" + serverStatuses + ", authType='" + authType + "'}"; } - static HandshakeV10Request decode(int envelopeId, ByteBuf buf, HandshakeHeader header) { - Builder builder = new Builder(envelopeId, header); + static HandshakeV10Request decode(ByteBuf buf, HandshakeHeader header) { + Builder builder = new Builder(header); ByteBuf salt = buf.alloc().buffer(); try { @@ -194,8 +186,6 @@ static HandshakeV10Request decode(int envelopeId, ByteBuf buf, HandshakeHeader h private static final class Builder { - private final int envelopeId; - private final HandshakeHeader header; private String authType; @@ -206,14 +196,12 @@ private static final class Builder { private short serverStatuses; - private Builder(int envelopeId, HandshakeHeader header) { - this.envelopeId = envelopeId; + private Builder(HandshakeHeader header) { this.header = header; } HandshakeV10Request build() { - return new HandshakeV10Request(header, envelopeId, salt, serverCapability, serverStatuses, - authType); + return new HandshakeV10Request(header, salt, serverCapability, serverStatuses, authType); } void authType(String authType) { diff --git a/src/main/java/io/asyncer/r2dbc/mysql/message/server/HandshakeV9Request.java b/src/main/java/io/asyncer/r2dbc/mysql/message/server/HandshakeV9Request.java index ea34e1c55..b92d29256 100644 --- a/src/main/java/io/asyncer/r2dbc/mysql/message/server/HandshakeV9Request.java +++ b/src/main/java/io/asyncer/r2dbc/mysql/message/server/HandshakeV9Request.java @@ -23,7 +23,7 @@ import java.util.Arrays; -import static io.asyncer.r2dbc.mysql.constant.Envelopes.TERMINAL; +import static io.asyncer.r2dbc.mysql.constant.Packets.TERMINAL; import static io.asyncer.r2dbc.mysql.internal.util.AssertUtils.requireNonNull; import static io.asyncer.r2dbc.mysql.internal.util.InternalArrays.EMPTY_BYTES; @@ -36,13 +36,10 @@ final class HandshakeV9Request implements HandshakeRequest { private final HandshakeHeader header; - private final int envelopeId; - private final byte[] salt; - private HandshakeV9Request(HandshakeHeader header, int envelopeId, byte[] salt) { + private HandshakeV9Request(HandshakeHeader header, byte[] salt) { this.header = requireNonNull(header, "header must not be null"); - this.envelopeId = envelopeId; this.salt = requireNonNull(salt, "salt must not be null"); } @@ -51,11 +48,6 @@ public HandshakeHeader getHeader() { return header; } - @Override - public int getEnvelopeId() { - return envelopeId; - } - @Override public Capability getServerCapability() { return SERVER_CAPABILITY; @@ -76,31 +68,31 @@ public boolean equals(Object o) { if (this == o) { return true; } - if (o == null || getClass() != o.getClass()) { + if (!(o instanceof HandshakeV9Request)) { return false; } HandshakeV9Request that = (HandshakeV9Request) o; - return envelopeId == that.envelopeId && header.equals(that.header) && Arrays.equals(salt, that.salt); + return header.equals(that.header) && Arrays.equals(salt, that.salt); } @Override public int hashCode() { - int hash = 31 * header.hashCode() + envelopeId; - return 31 * hash + Arrays.hashCode(salt); + int result = header.hashCode(); + return 31 * result + Arrays.hashCode(salt); } @Override public String toString() { - return "HandshakeV9Request{header=" + header + ", envelopeId=" + envelopeId + ", salt=REDACTED}"; + return "HandshakeV9Request{header=" + header + ", salt=REDACTED}"; } - static HandshakeV9Request decode(int envelopeId, ByteBuf buf, HandshakeHeader header) { + static HandshakeV9Request decode(ByteBuf buf, HandshakeHeader header) { int bytes = buf.readableBytes(); if (bytes <= 0) { - return new HandshakeV9Request(header, envelopeId, EMPTY_BYTES); + return new HandshakeV9Request(header, EMPTY_BYTES); } byte[] salt; @@ -111,6 +103,6 @@ static HandshakeV9Request decode(int envelopeId, ByteBuf buf, HandshakeHeader he salt = ByteBufUtil.getBytes(buf); } - return new HandshakeV9Request(header, envelopeId, salt); + return new HandshakeV9Request(header, salt); } } diff --git a/src/main/java/io/asyncer/r2dbc/mysql/message/server/LargeFieldReader.java b/src/main/java/io/asyncer/r2dbc/mysql/message/server/LargeFieldReader.java index 9842687b1..c649f99bd 100644 --- a/src/main/java/io/asyncer/r2dbc/mysql/message/server/LargeFieldReader.java +++ b/src/main/java/io/asyncer/r2dbc/mysql/message/server/LargeFieldReader.java @@ -16,7 +16,7 @@ package io.asyncer.r2dbc.mysql.message.server; -import io.asyncer.r2dbc.mysql.constant.Envelopes; +import io.asyncer.r2dbc.mysql.constant.Packets; import io.asyncer.r2dbc.mysql.internal.util.NettyBufferUtils; import io.asyncer.r2dbc.mysql.internal.util.VarIntUtils; import io.asyncer.r2dbc.mysql.message.FieldValue; @@ -136,7 +136,7 @@ protected void deallocate() { private List readSlice(ByteBuf current, long length) { ByteBuf buf = current; List results = new ArrayList<>(Math.max( - (int) Math.min(Long.divideUnsigned(length, Envelopes.MAX_ENVELOPE_SIZE) + 2, Byte.MAX_VALUE), + (int) Math.min(Long.divideUnsigned(length, Packets.MAX_PAYLOAD_SIZE) + 2, Byte.MAX_VALUE), 10 )); long totalSize = 0; diff --git a/src/main/java/io/asyncer/r2dbc/mysql/message/server/LocalInfileRequest.java b/src/main/java/io/asyncer/r2dbc/mysql/message/server/LocalInfileRequest.java index e493e2bf8..433059f1b 100644 --- a/src/main/java/io/asyncer/r2dbc/mysql/message/server/LocalInfileRequest.java +++ b/src/main/java/io/asyncer/r2dbc/mysql/message/server/LocalInfileRequest.java @@ -25,26 +25,19 @@ */ public final class LocalInfileRequest implements ServerMessage { - private final int envelopeId; - private final String path; - private LocalInfileRequest(int envelopeId, String path) { - this.envelopeId = envelopeId; + private LocalInfileRequest(String path) { this.path = path; } - public int getEnvelopeId() { - return envelopeId; - } - public String getPath() { return path; } - static LocalInfileRequest decode(int envelopeId, ByteBuf buf, ConnectionContext context) { + static LocalInfileRequest decode(ByteBuf buf, ConnectionContext context) { buf.skipBytes(1); // Constant 0xFB - return new LocalInfileRequest(envelopeId, buf.toString(context.getClientCollation().getCharset())); + return new LocalInfileRequest(buf.toString(context.getClientCollation().getCharset())); } @Override @@ -58,17 +51,16 @@ public boolean equals(Object o) { LocalInfileRequest that = (LocalInfileRequest) o; - return envelopeId == that.envelopeId && path.equals(that.path); + return path.equals(that.path); } @Override public int hashCode() { - return 31 * envelopeId + path.hashCode(); + return path.hashCode(); } @Override public String toString() { - return "LocalInfileRequest{envelopeId=" + envelopeId + - ", path='" + path + "'}"; + return "LocalInfileRequest{path='" + path + "'}"; } } diff --git a/src/main/java/io/asyncer/r2dbc/mysql/message/server/ServerMessageDecoder.java b/src/main/java/io/asyncer/r2dbc/mysql/message/server/ServerMessageDecoder.java index f81a06200..1c8577ab9 100644 --- a/src/main/java/io/asyncer/r2dbc/mysql/message/server/ServerMessageDecoder.java +++ b/src/main/java/io/asyncer/r2dbc/mysql/message/server/ServerMessageDecoder.java @@ -17,7 +17,7 @@ package io.asyncer.r2dbc.mysql.message.server; import io.asyncer.r2dbc.mysql.ConnectionContext; -import io.asyncer.r2dbc.mysql.constant.Envelopes; +import io.asyncer.r2dbc.mysql.constant.Packets; import io.asyncer.r2dbc.mysql.internal.util.NettyBufferUtils; import io.asyncer.r2dbc.mysql.internal.util.VarIntUtils; import io.netty.buffer.ByteBuf; @@ -58,24 +58,25 @@ public final class ServerMessageDecoder { /** * Decode a server-side message from {@link #parts} and current envelope. * - * @param envelope the current envelope. + * @param payload the payload of the current packet. * @param context the connection context. * @param decodeContext the decode context. * @return the server-side message, or {@code null} if {@code envelope} is not last packet. */ @Nullable - public ServerMessage decode(ByteBuf envelope, ConnectionContext context, DecodeContext decodeContext) { - requireNonNull(envelope, "envelope must not be null"); + public ServerMessage decode(ByteBuf payload, ConnectionContext context, DecodeContext decodeContext) { + requireNonNull(payload, "payload must not be null"); requireNonNull(context, "context must not be null"); requireNonNull(decodeContext, "decodeContext must not be null"); - List buffers = this.parts; - Byte id = readNotFinish(buffers, envelope); - if (id == null) { + parts.add(payload); + + if (payload.readableBytes() == Packets.MAX_PAYLOAD_SIZE) { + // Not last packet. return null; } - return decodeMessage(buffers, id.intValue() & 0xFF, context, decodeContext); + return decodeMessage(parts, context, decodeContext); } /** @@ -91,8 +92,8 @@ public void dispose() { } @Nullable - private static ServerMessage decodeMessage(List buffers, int envelopeId, - ConnectionContext context, DecodeContext decodeContext) { + private static ServerMessage decodeMessage(List buffers, ConnectionContext context, + DecodeContext decodeContext) { if (decodeContext instanceof ResultDecodeContext) { return decodeResult(buffers, context, (ResultDecodeContext) decodeContext); } @@ -104,14 +105,14 @@ private static ServerMessage decodeMessage(List buffers, int envelopeId try { if (decodeContext instanceof CommandDecodeContext) { - return decodeCommandMessage(envelopeId, combined, context); + return decodeCommandMessage(combined, context); } else if (decodeContext instanceof PreparedMetadataDecodeContext) { return decodePreparedMetadata(combined, context, (PreparedMetadataDecodeContext) decodeContext); } else if (decodeContext instanceof PrepareQueryDecodeContext) { return decodePrepareQuery(combined); } else if (decodeContext instanceof LoginDecodeContext) { - return decodeLogin(envelopeId, combined, context); + return decodeLogin(combined, context); } } finally { combined.release(); @@ -194,8 +195,7 @@ private static ServerMessage decodePrepareQuery(ByteBuf buf) { " on prepare query phase"); } - private static ServerMessage decodeCommandMessage(int envelopeId, ByteBuf buf, - ConnectionContext context) { + private static ServerMessage decodeCommandMessage(ByteBuf buf, ConnectionContext context) { short header = buf.getUnsignedByte(buf.readerIndex()); switch (header) { case ERROR: @@ -221,8 +221,9 @@ private static ServerMessage decodeCommandMessage(int envelopeId, ByteBuf buf, } case LOCAL_INFILE: if (buf.readableBytes() > 1) { - return LocalInfileRequest.decode(envelopeId, buf, context); + return LocalInfileRequest.decode(buf, context); } + break; } if (VarIntUtils.checkNextVarInt(buf) == 0) { @@ -236,7 +237,7 @@ private static ServerMessage decodeCommandMessage(int envelopeId, ByteBuf buf, " on command phase"); } - private static ServerMessage decodeLogin(int envelopeId, ByteBuf buf, ConnectionContext context) { + private static ServerMessage decodeLogin(ByteBuf buf, ConnectionContext context) { short header = buf.getUnsignedByte(buf.readerIndex()); switch (header) { case OK: @@ -246,10 +247,10 @@ private static ServerMessage decodeLogin(int envelopeId, ByteBuf buf, Connection break; case AUTH_MORE_DATA: // Auth more data - return AuthMoreDataMessage.decode(envelopeId, buf); + return AuthMoreDataMessage.decode(buf); case HANDSHAKE_V9: case HANDSHAKE_V10: // Handshake V9 (not supported) or V10 - return HandshakeRequest.decode(envelopeId, buf); + return HandshakeRequest.decode(buf); case ERROR: // Error return ErrorMessage.decode(buf); case EOF: // Auth exchange message or EOF message @@ -257,7 +258,7 @@ private static ServerMessage decodeLogin(int envelopeId, ByteBuf buf, Connection return EofMessage.decode(buf); } - return ChangeAuthMessage.decode(envelopeId, buf); + return ChangeAuthMessage.decode(buf); } throw new R2dbcPermissionDeniedException("Unknown message header 0x" + @@ -265,32 +266,6 @@ private static ServerMessage decodeLogin(int envelopeId, ByteBuf buf, Connection " on connection phase"); } - @Nullable - private static Byte readNotFinish(List buffers, ByteBuf envelope) { - try { - int size = envelope.readUnsignedMediumLE(); - if (size < Envelopes.MAX_ENVELOPE_SIZE) { - Byte envelopeId = envelope.readByte(); - - buffers.add(envelope); - // success, no need release - envelope = null; - return envelopeId; - } - - // skip the sequence Id - envelope.skipBytes(1); - buffers.add(envelope); - // success, no need release - envelope = null; - return null; - } finally { - if (envelope != null) { - envelope.release(); - } - } - } - private static boolean isRow(List buffers, ByteBuf firstBuf, short header) { switch (header) { case RowMessage.NULL_VALUE: diff --git a/src/test/java/io/asyncer/r2dbc/mysql/CompressionIntegrationTestSupport.java b/src/test/java/io/asyncer/r2dbc/mysql/CompressionIntegrationTestSupport.java new file mode 100644 index 000000000..3e7c5bdb7 --- /dev/null +++ b/src/test/java/io/asyncer/r2dbc/mysql/CompressionIntegrationTestSupport.java @@ -0,0 +1,88 @@ +/* + * Copyright 2024 asyncer.io projects + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.asyncer.r2dbc.mysql; + +import io.asyncer.r2dbc.mysql.constant.CompressionAlgorithm; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.ValueSource; + +import java.io.IOException; +import java.net.URISyntaxException; +import java.net.URL; +import java.nio.charset.StandardCharsets; +import java.nio.file.Files; +import java.nio.file.Paths; +import java.util.Collections; +import java.util.Objects; +import java.util.concurrent.atomic.AtomicInteger; + +import static org.assertj.core.api.Assertions.assertThat; + +/** + * Base class for compression integration tests. + */ +abstract class CompressionIntegrationTestSupport extends IntegrationTestSupport { + + CompressionIntegrationTestSupport(CompressionAlgorithm algorithm) { + super(configuration(builder -> builder.compressionAlgorithms(algorithm))); + } + + @Test + void simpleQuery() { + byte[] hello = "Hello".getBytes(StandardCharsets.US_ASCII); + byte[] repeatedBytes = new byte[hello.length * 50]; + + for (int i = 0; i < 50; i++) { + System.arraycopy(hello, 0, repeatedBytes, i * hello.length, hello.length); + } + + String repeated = new String(repeatedBytes, StandardCharsets.US_ASCII); + + complete(connection -> connection.createStatement("SELECT REPEAT('Hello', 50)").execute() + .flatMap(result -> result.map((row, rowMetadata) -> row.get(0, String.class))) + .collectList() + .doOnNext(actual -> assertThat(actual).isEqualTo(Collections.singletonList(repeated)))); + } + + @ParameterizedTest + @ValueSource(strings = { "stations", "users" }) + @SuppressWarnings("SqlSourceToSinkFlow") + void loadDataLocalInfile(String name) throws URISyntaxException, IOException { + URL tdlUrl = Objects.requireNonNull(getClass().getResource(String.format("/local/%s.sql", name))); + URL csvUrl = Objects.requireNonNull(getClass().getResource(String.format("/local/%s.csv", name))); + String tdl = new String(Files.readAllBytes(Paths.get(tdlUrl.toURI())), StandardCharsets.UTF_8); + String path = Paths.get(csvUrl.toURI()).toString(); + String loadData = String.format("LOAD DATA LOCAL INFILE '%s' INTO TABLE `%s` " + + "FIELDS TERMINATED BY ',' OPTIONALLY ENCLOSED BY '\"'", path, name); + String select = String.format("SELECT * FROM `%s` ORDER BY `id`", name); + AtomicInteger count = new AtomicInteger(-1); + + complete(conn -> conn.createStatement(tdl) + .execute() + .flatMap(IntegrationTestSupport::extractRowsUpdated) + .thenMany(conn.createStatement(loadData).execute()) + .flatMap(IntegrationTestSupport::extractRowsUpdated) + .reduce(0L, Long::sum) + .doOnNext(it -> count.set(it.intValue())) + .doOnNext(it -> assertThat(it).isGreaterThan(0)) + .thenMany(conn.createStatement(select).execute()) + .flatMap(result -> result.map(r -> 1)) + .reduce(0, Integer::sum) + .doOnNext(it -> assertThat(it).isEqualTo(count.get()))); + } +} diff --git a/src/test/java/io/asyncer/r2dbc/mysql/ConnectionIntegrationTest.java b/src/test/java/io/asyncer/r2dbc/mysql/ConnectionIntegrationTest.java index ed66f34ab..48d63c0ed 100644 --- a/src/test/java/io/asyncer/r2dbc/mysql/ConnectionIntegrationTest.java +++ b/src/test/java/io/asyncer/r2dbc/mysql/ConnectionIntegrationTest.java @@ -53,11 +53,8 @@ */ class ConnectionIntegrationTest extends IntegrationTestSupport { - private static final MySqlConnectionConfiguration config = configuration( - "r2dbc", false, false, null, null); - ConnectionIntegrationTest() { - super(config); + super(configuration(builder -> builder)); } @Test diff --git a/src/test/java/io/asyncer/r2dbc/mysql/InitDbIntegrationTest.java b/src/test/java/io/asyncer/r2dbc/mysql/InitDbIntegrationTest.java index afece8130..66fb46e5a 100644 --- a/src/test/java/io/asyncer/r2dbc/mysql/InitDbIntegrationTest.java +++ b/src/test/java/io/asyncer/r2dbc/mysql/InitDbIntegrationTest.java @@ -15,10 +15,7 @@ class InitDbIntegrationTest extends IntegrationTestSupport { private static final String DATABASE = "test-" + ThreadLocalRandom.current().nextInt(10000); InitDbIntegrationTest() { - super(configuration( - DATABASE, true, false, - null, null - )); + super(configuration(builder -> builder.database(DATABASE).createDatabaseIfNotExist(true))); } @Test diff --git a/src/test/java/io/asyncer/r2dbc/mysql/IntegrationTestSupport.java b/src/test/java/io/asyncer/r2dbc/mysql/IntegrationTestSupport.java index 2489d5732..fb83de493 100644 --- a/src/test/java/io/asyncer/r2dbc/mysql/IntegrationTestSupport.java +++ b/src/test/java/io/asyncer/r2dbc/mysql/IntegrationTestSupport.java @@ -19,7 +19,6 @@ import io.r2dbc.spi.R2dbcBadGrammarException; import io.r2dbc.spi.R2dbcTimeoutException; import io.r2dbc.spi.Result; -import org.jetbrains.annotations.Nullable; import org.reactivestreams.Publisher; import reactor.core.publisher.Flux; import reactor.core.publisher.Mono; @@ -30,10 +29,8 @@ import java.nio.file.Path; import java.nio.file.Paths; import java.time.Duration; -import java.time.ZoneId; import java.util.Objects; import java.util.function.Function; -import java.util.function.Predicate; import static org.assertj.core.api.Assertions.assertThat; @@ -82,8 +79,7 @@ static Mono extractRowsUpdated(Result result) { } static MySqlConnectionConfiguration configuration( - String database, boolean createDatabaseIfNotExist, boolean autodetectExtensions, - @Nullable ZoneId serverZoneId, @Nullable Predicate preferPrepared + Function customizer ) { String password = System.getProperty("test.mysql.password"); @@ -106,22 +102,10 @@ static MySqlConnectionConfiguration configuration( .connectTimeout(Duration.ofSeconds(3)) .user("root") .password(password) - .database(database) - .createDatabaseIfNotExist(createDatabaseIfNotExist) - .allowLoadLocalInfileInPath(localInfilePath) - .autodetectExtensions(autodetectExtensions); + .database("r2dbc") + .allowLoadLocalInfileInPath(localInfilePath); - if (serverZoneId != null) { - builder.serverZoneId(serverZoneId); - } - - if (preferPrepared == null) { - builder.useClientPrepareStatement(); - } else { - builder.useServerPrepareStatement(preferPrepared); - } - - return builder.build(); + return customizer.apply(builder).build(); } boolean envIsLessThanMySql56() { diff --git a/src/test/java/io/asyncer/r2dbc/mysql/JacksonPrepareIntegrationTest.java b/src/test/java/io/asyncer/r2dbc/mysql/JacksonPrepareIntegrationTest.java index 7205a82bf..f78c5f11b 100644 --- a/src/test/java/io/asyncer/r2dbc/mysql/JacksonPrepareIntegrationTest.java +++ b/src/test/java/io/asyncer/r2dbc/mysql/JacksonPrepareIntegrationTest.java @@ -22,7 +22,8 @@ class JacksonPrepareIntegrationTest extends JacksonIntegrationTestSupport { JacksonPrepareIntegrationTest() { - super(configuration("r2dbc", false, true, null, sql -> false)); + super(configuration(builder -> builder.autodetectExtensions(true) + .useServerPrepareStatement())); } } diff --git a/src/test/java/io/asyncer/r2dbc/mysql/JacksonTextIntegrationTest.java b/src/test/java/io/asyncer/r2dbc/mysql/JacksonTextIntegrationTest.java index 6d666e520..0b114e033 100644 --- a/src/test/java/io/asyncer/r2dbc/mysql/JacksonTextIntegrationTest.java +++ b/src/test/java/io/asyncer/r2dbc/mysql/JacksonTextIntegrationTest.java @@ -22,6 +22,6 @@ class JacksonTextIntegrationTest extends JacksonIntegrationTestSupport { JacksonTextIntegrationTest() { - super(configuration("r2dbc", false, true, null, null)); + super(configuration(builder -> builder.autodetectExtensions(true))); } } diff --git a/src/test/java/io/asyncer/r2dbc/mysql/MariaDbIntegrationTestSupport.java b/src/test/java/io/asyncer/r2dbc/mysql/MariaDbIntegrationTestSupport.java index 04c5fac2e..8b08b3ead 100644 --- a/src/test/java/io/asyncer/r2dbc/mysql/MariaDbIntegrationTestSupport.java +++ b/src/test/java/io/asyncer/r2dbc/mysql/MariaDbIntegrationTestSupport.java @@ -17,11 +17,10 @@ package io.asyncer.r2dbc.mysql; import io.r2dbc.spi.Readable; -import org.jetbrains.annotations.Nullable; import org.junit.jupiter.api.Test; import java.time.ZonedDateTime; -import java.util.function.Predicate; +import java.util.function.Function; import static java.util.Objects.requireNonNull; import static org.assertj.core.api.Assertions.assertThat; @@ -31,8 +30,10 @@ */ abstract class MariaDbIntegrationTestSupport extends IntegrationTestSupport { - MariaDbIntegrationTestSupport(@Nullable Predicate preferPrepared) { - super(configuration("r2dbc", false, false, null, preferPrepared)); + MariaDbIntegrationTestSupport( + Function customizer + ) { + super(configuration(customizer)); } @Test diff --git a/src/test/java/io/asyncer/r2dbc/mysql/MariaDbPrepareIntegrationTest.java b/src/test/java/io/asyncer/r2dbc/mysql/MariaDbPrepareIntegrationTest.java index b7ac81a8b..8f7ba2998 100644 --- a/src/test/java/io/asyncer/r2dbc/mysql/MariaDbPrepareIntegrationTest.java +++ b/src/test/java/io/asyncer/r2dbc/mysql/MariaDbPrepareIntegrationTest.java @@ -25,6 +25,6 @@ class MariaDbPrepareIntegrationTest extends MariaDbIntegrationTestSupport { MariaDbPrepareIntegrationTest() { - super(sql -> true); + super(builder -> builder.useServerPrepareStatement(sql -> true)); } } diff --git a/src/test/java/io/asyncer/r2dbc/mysql/MariaDbTextIntegrationTest.java b/src/test/java/io/asyncer/r2dbc/mysql/MariaDbTextIntegrationTest.java index 0ab886c5f..fc285ddb3 100644 --- a/src/test/java/io/asyncer/r2dbc/mysql/MariaDbTextIntegrationTest.java +++ b/src/test/java/io/asyncer/r2dbc/mysql/MariaDbTextIntegrationTest.java @@ -25,6 +25,6 @@ class MariaDbTextIntegrationTest extends MariaDbIntegrationTestSupport { MariaDbTextIntegrationTest() { - super(null); + super(builder -> builder); } } diff --git a/src/test/java/io/asyncer/r2dbc/mysql/MySqlConnectionConfigurationTest.java b/src/test/java/io/asyncer/r2dbc/mysql/MySqlConnectionConfigurationTest.java index aa3067a1e..e0baef7d0 100644 --- a/src/test/java/io/asyncer/r2dbc/mysql/MySqlConnectionConfigurationTest.java +++ b/src/test/java/io/asyncer/r2dbc/mysql/MySqlConnectionConfigurationTest.java @@ -16,6 +16,7 @@ package io.asyncer.r2dbc.mysql; +import io.asyncer.r2dbc.mysql.constant.CompressionAlgorithm; import io.asyncer.r2dbc.mysql.constant.SslMode; import io.asyncer.r2dbc.mysql.constant.TlsVersions; import io.asyncer.r2dbc.mysql.constant.ZeroDateOption; @@ -240,6 +241,8 @@ private static MySqlConnectionConfiguration filledUp() { .sslKey("/path/to/mysql/client-key.pem") .sslKeyPassword("pem-password-in-here") .tlsVersion(TlsVersions.TLS1_1, TlsVersions.TLS1_2, TlsVersions.TLS1_3) + .compressionAlgorithms(CompressionAlgorithm.ZSTD, CompressionAlgorithm.ZLIB, + CompressionAlgorithm.UNCOMPRESSED) .serverZoneId(ZoneId.systemDefault()) .zeroDateOption(ZeroDateOption.USE_NULL) .sslHostnameVerifier((host, s) -> true) diff --git a/src/test/java/io/asyncer/r2dbc/mysql/MySqlConnectionFactoryProviderTest.java b/src/test/java/io/asyncer/r2dbc/mysql/MySqlConnectionFactoryProviderTest.java index 17f515530..6c48a4e92 100644 --- a/src/test/java/io/asyncer/r2dbc/mysql/MySqlConnectionFactoryProviderTest.java +++ b/src/test/java/io/asyncer/r2dbc/mysql/MySqlConnectionFactoryProviderTest.java @@ -16,6 +16,7 @@ package io.asyncer.r2dbc.mysql; +import io.asyncer.r2dbc.mysql.constant.CompressionAlgorithm; import io.asyncer.r2dbc.mysql.constant.SslMode; import io.asyncer.r2dbc.mysql.constant.ZeroDateOption; import io.netty.handler.ssl.SslContextBuilder; @@ -24,6 +25,8 @@ import io.r2dbc.spi.Option; import org.assertj.core.api.Assert; import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.ValueSource; import org.reactivestreams.Publisher; import reactor.core.publisher.Mono; @@ -34,6 +37,7 @@ import java.time.Duration; import java.time.ZoneId; import java.util.Collections; +import java.util.Set; import java.util.function.Function; import java.util.function.Predicate; @@ -177,8 +181,8 @@ void validProgrammaticHost() { @Test void invalidProgrammatic() { - assertThatIllegalStateException().isThrownBy(() -> - MySqlConnectionFactoryProvider.setup(ConnectionFactoryOptions.builder() + assertThatIllegalStateException() + .isThrownBy(() -> MySqlConnectionFactoryProvider.setup(ConnectionFactoryOptions.builder() .option(DRIVER, "mysql") .option(PORT, 3307) .option(USER, "root") @@ -198,8 +202,8 @@ void invalidProgrammatic() { .build())) .withMessageContaining("host"); - assertThatIllegalStateException().isThrownBy(() -> - MySqlConnectionFactoryProvider.setup(ConnectionFactoryOptions.builder() + assertThatIllegalStateException() + .isThrownBy(() -> MySqlConnectionFactoryProvider.setup(ConnectionFactoryOptions.builder() .option(DRIVER, "mysql") .option(HOST, "127.0.0.1") .option(PORT, 3307) @@ -207,8 +211,8 @@ void invalidProgrammatic() { .build())) .withMessageContaining("user"); - assertThatIllegalArgumentException().isThrownBy(() -> - MySqlConnectionFactoryProvider.setup(ConnectionFactoryOptions.builder() + assertThatIllegalArgumentException() + .isThrownBy(() -> MySqlConnectionFactoryProvider.setup(ConnectionFactoryOptions.builder() .option(DRIVER, "mysql") .option(HOST, "127.0.0.1") .option(PORT, 3307) @@ -220,8 +224,8 @@ void invalidProgrammatic() { .build())) .withMessageContaining("sslCert and sslKey"); - assertThatIllegalArgumentException().isThrownBy(() -> - MySqlConnectionFactoryProvider.setup(ConnectionFactoryOptions.builder() + assertThatIllegalArgumentException() + .isThrownBy(() -> MySqlConnectionFactoryProvider.setup(ConnectionFactoryOptions.builder() .option(DRIVER, "mysql") .option(HOST, "127.0.0.1") .option(PORT, 3307) @@ -394,6 +398,40 @@ void invalidServerPreparing() { .build())); } + @ParameterizedTest + @ValueSource(strings = { + "uncompressed", + "zlib", + "zstd", + "zlib,uncompressed", + "zstd,uncompressed", + "zstd,zlib", + "zstd,zlib,uncompressed", + }) + void validCompressionAlgorithms(String name) { + Set algorithms = MySqlConnectionFactoryProvider.setup( + ConnectionFactoryOptions.builder() + .option(DRIVER, "mysql") + .option(HOST, "127.0.0.1") + .option(USER, "root") + .option(Option.valueOf("compressionAlgorithms"), name) + .build()).getCompressionAlgorithms(); + + assertThat(algorithms).hasSize(name.split(",").length); + } + + @ParameterizedTest + @ValueSource(strings = { "", "gzip", "lz4", "lz4hc", "none", "snappy", "zlib,none", "zstd,none" }) + void invalidCompressionAlgorithms(String name) { + assertThatIllegalArgumentException().isThrownBy(() -> MySqlConnectionFactoryProvider.setup( + ConnectionFactoryOptions.builder() + .option(DRIVER, "mysql") + .option(HOST, "127.0.0.1") + .option(USER, "root") + .option(Option.valueOf("compressionAlgorithms"), name) + .build())); + } + @Test void validPasswordSupplier() { final Publisher passwordSupplier = Mono.just("123456"); diff --git a/src/test/java/io/asyncer/r2dbc/mysql/MySqlPrepareTestKit.java b/src/test/java/io/asyncer/r2dbc/mysql/MySqlPrepareTestKit.java index 4a8878b92..08ed1fe28 100644 --- a/src/test/java/io/asyncer/r2dbc/mysql/MySqlPrepareTestKit.java +++ b/src/test/java/io/asyncer/r2dbc/mysql/MySqlPrepareTestKit.java @@ -22,7 +22,8 @@ class MySqlPrepareTestKit extends MySqlTestKitSupport { MySqlPrepareTestKit() { - super(IntegrationTestSupport.configuration("r2dbc", false, false, null, sql -> true)); + super(IntegrationTestSupport.configuration(builder -> + builder.useServerPrepareStatement(sql -> true))); } @Override diff --git a/src/test/java/io/asyncer/r2dbc/mysql/MySqlTextTestKit.java b/src/test/java/io/asyncer/r2dbc/mysql/MySqlTextTestKit.java index 04c32c719..68fe276fb 100644 --- a/src/test/java/io/asyncer/r2dbc/mysql/MySqlTextTestKit.java +++ b/src/test/java/io/asyncer/r2dbc/mysql/MySqlTextTestKit.java @@ -22,6 +22,6 @@ class MySqlTextTestKit extends MySqlTestKitSupport { MySqlTextTestKit() { - super(IntegrationTestSupport.configuration("r2dbc", false, false, null, null)); + super(IntegrationTestSupport.configuration(builder -> builder)); } } diff --git a/src/test/java/io/asyncer/r2dbc/mysql/PrepareQueryIntegrationTest.java b/src/test/java/io/asyncer/r2dbc/mysql/PrepareQueryIntegrationTest.java index 45d5a94d7..3d53c2965 100644 --- a/src/test/java/io/asyncer/r2dbc/mysql/PrepareQueryIntegrationTest.java +++ b/src/test/java/io/asyncer/r2dbc/mysql/PrepareQueryIntegrationTest.java @@ -29,7 +29,7 @@ class PrepareQueryIntegrationTest extends QueryIntegrationTestSupport { PrepareQueryIntegrationTest() { - super(configuration("r2dbc", false, false, null, sql -> true)); + super(configuration(builder -> builder.useServerPrepareStatement(sql -> true))); } @Test diff --git a/src/test/java/io/asyncer/r2dbc/mysql/PrepareTimeZoneIntegrationTest.java b/src/test/java/io/asyncer/r2dbc/mysql/PrepareTimeZoneIntegrationTest.java index ee6fbc391..6b53312e5 100644 --- a/src/test/java/io/asyncer/r2dbc/mysql/PrepareTimeZoneIntegrationTest.java +++ b/src/test/java/io/asyncer/r2dbc/mysql/PrepareTimeZoneIntegrationTest.java @@ -22,6 +22,6 @@ class PrepareTimeZoneIntegrationTest extends TimeZoneIntegrationTestSupport { PrepareTimeZoneIntegrationTest() { - super(sql -> true); + super(builder -> builder.useServerPrepareStatement(sql -> true)); } } diff --git a/src/test/java/io/asyncer/r2dbc/mysql/TextQueryIntegrationTest.java b/src/test/java/io/asyncer/r2dbc/mysql/TextQueryIntegrationTest.java index a4e7152b7..76f3b95c6 100644 --- a/src/test/java/io/asyncer/r2dbc/mysql/TextQueryIntegrationTest.java +++ b/src/test/java/io/asyncer/r2dbc/mysql/TextQueryIntegrationTest.java @@ -22,6 +22,6 @@ class TextQueryIntegrationTest extends QueryIntegrationTestSupport { TextQueryIntegrationTest() { - super(configuration("r2dbc", false, false, null, null)); + super(configuration(builder -> builder)); } } diff --git a/src/test/java/io/asyncer/r2dbc/mysql/TextTimeZoneIntegrationTest.java b/src/test/java/io/asyncer/r2dbc/mysql/TextTimeZoneIntegrationTest.java index 6b58ae1d4..336a0d5c1 100644 --- a/src/test/java/io/asyncer/r2dbc/mysql/TextTimeZoneIntegrationTest.java +++ b/src/test/java/io/asyncer/r2dbc/mysql/TextTimeZoneIntegrationTest.java @@ -22,6 +22,6 @@ class TextTimeZoneIntegrationTest extends TimeZoneIntegrationTestSupport { TextTimeZoneIntegrationTest() { - super(null); + super(builder -> builder); } } diff --git a/src/test/java/io/asyncer/r2dbc/mysql/TimeZoneIntegrationTestSupport.java b/src/test/java/io/asyncer/r2dbc/mysql/TimeZoneIntegrationTestSupport.java index 5bfb769f1..7e32d07e4 100644 --- a/src/test/java/io/asyncer/r2dbc/mysql/TimeZoneIntegrationTestSupport.java +++ b/src/test/java/io/asyncer/r2dbc/mysql/TimeZoneIntegrationTestSupport.java @@ -16,7 +16,6 @@ package io.asyncer.r2dbc.mysql; -import org.jetbrains.annotations.Nullable; import org.junit.jupiter.api.Test; import reactor.core.publisher.Flux; @@ -31,7 +30,7 @@ import java.util.Arrays; import java.util.Collections; import java.util.TimeZone; -import java.util.function.Predicate; +import java.util.function.Function; import static org.assertj.core.api.Assertions.assertThat; @@ -65,8 +64,10 @@ abstract class TimeZoneIntegrationTestSupport extends IntegrationTestSupport { .isEqualTo(DST.atZone(SERVER_ZONE).plusHours(1)); } - TimeZoneIntegrationTestSupport(@Nullable Predicate preferPrepared) { - super(configuration("r2dbc", false, false, SERVER_ZONE, preferPrepared)); + TimeZoneIntegrationTestSupport( + Function customizer + ) { + super(configuration(builder -> customizer.apply(builder.serverZoneId(SERVER_ZONE)))); } @Test diff --git a/src/main/java/io/asyncer/r2dbc/mysql/client/Lifecycle.java b/src/test/java/io/asyncer/r2dbc/mysql/ZlibCompressionIntegrationTest.java similarity index 62% rename from src/main/java/io/asyncer/r2dbc/mysql/client/Lifecycle.java rename to src/test/java/io/asyncer/r2dbc/mysql/ZlibCompressionIntegrationTest.java index 2ad0f6efd..df0e3c639 100644 --- a/src/main/java/io/asyncer/r2dbc/mysql/client/Lifecycle.java +++ b/src/test/java/io/asyncer/r2dbc/mysql/ZlibCompressionIntegrationTest.java @@ -1,5 +1,5 @@ /* - * Copyright 2023 asyncer.io projects + * Copyright 2024 asyncer.io projects * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -14,16 +14,16 @@ * limitations under the License. */ -package io.asyncer.r2dbc.mysql.client; +package io.asyncer.r2dbc.mysql; + +import io.asyncer.r2dbc.mysql.constant.CompressionAlgorithm; /** - * The lifecycle of connection. + * Integration tests for zstd compression. */ -enum Lifecycle { - -// CONNECTION, // Useless for signal - - COMMAND, +class ZlibCompressionIntegrationTest extends CompressionIntegrationTestSupport { -// REPLICATION // Useless for r2dbc driver, just ignore + ZlibCompressionIntegrationTest() { + super(CompressionAlgorithm.ZLIB); + } } diff --git a/src/test/java/io/asyncer/r2dbc/mysql/ZstdCompressionIntegrationTest.java b/src/test/java/io/asyncer/r2dbc/mysql/ZstdCompressionIntegrationTest.java new file mode 100644 index 000000000..8a0c339e7 --- /dev/null +++ b/src/test/java/io/asyncer/r2dbc/mysql/ZstdCompressionIntegrationTest.java @@ -0,0 +1,48 @@ +/* + * Copyright 2024 asyncer.io projects + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.asyncer.r2dbc.mysql; + +import io.asyncer.r2dbc.mysql.constant.CompressionAlgorithm; +import org.junit.jupiter.api.condition.EnabledIf; + +/** + * Integration tests for zstd compression. + */ +@EnabledIf("envIsZstdSupported") +class ZstdCompressionIntegrationTest extends CompressionIntegrationTestSupport { + + ZstdCompressionIntegrationTest() { + super(CompressionAlgorithm.ZSTD); + } + + static boolean envIsZstdSupported() { + String type = System.getProperty("test.db.type"); + + if ("mariadb".equalsIgnoreCase(type)) { + return false; + } + + String version = System.getProperty("test.mysql.version"); + + if (version == null || version.isEmpty()) { + return true; + } + + ServerVersion ver = ServerVersion.parse(version); + return ver.isGreaterThanOrEqualTo(ServerVersion.create(8, 0, 18)); + } +} diff --git a/src/test/java/io/asyncer/r2dbc/mysql/client/ZlibCompressorTest.java b/src/test/java/io/asyncer/r2dbc/mysql/client/ZlibCompressorTest.java new file mode 100644 index 000000000..9ae947959 --- /dev/null +++ b/src/test/java/io/asyncer/r2dbc/mysql/client/ZlibCompressorTest.java @@ -0,0 +1,101 @@ +/* + * Copyright 2024 asyncer.io projects + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.asyncer.r2dbc.mysql.client; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufUtil; +import io.netty.buffer.Unpooled; +import io.netty.handler.codec.DecoderException; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.Arguments; +import org.junit.jupiter.params.provider.MethodSource; +import org.junit.jupiter.params.provider.ValueSource; + +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.nio.charset.StandardCharsets; +import java.util.Arrays; +import java.util.stream.Stream; +import java.util.zip.DeflaterOutputStream; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatExceptionOfType; + +/** + * Unit tests for {@link ZlibCompressor}. + */ +class ZlibCompressorTest { + + private final ZlibCompressor compressor = new ZlibCompressor(); + + @ParameterizedTest + @MethodSource("uncompressedData") + void compress(String input) throws IOException { + byte[] bytes = input.getBytes(StandardCharsets.UTF_8); + ByteBuf compressed = compressor.compress(Unpooled.wrappedBuffer(bytes)); + byte[] nativeCompressed = nativeCompress(bytes); + + // It may return early if the compressed data is not smaller than the original. + assertThat(ByteBufUtil.getBytes(compressed)).hasSizeLessThanOrEqualTo(bytes.length) + .isEqualTo(Arrays.copyOf(nativeCompressed, compressed.readableBytes())); + } + + @ParameterizedTest + @MethodSource("uncompressedData") + void decompress(String input) throws IOException { + byte[] bytes = input.getBytes(StandardCharsets.UTF_8); + ByteBuf compressed = Unpooled.wrappedBuffer(nativeCompress(bytes)); + ByteBuf decompressed = compressor.decompress(compressed, bytes.length); + + assertThat(ByteBufUtil.getBytes(decompressed)).isEqualTo(bytes); + } + + @Test + void badDecompress() { + ByteBuf compressed = Unpooled.wrappedBuffer( + new byte[] { 0x78, 0x7c, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01 }); + + assertThatExceptionOfType(DecoderException.class) + .isThrownBy(() -> compressor.decompress(compressed, compressed.readableBytes() << 1)); + } + + static Stream uncompressedData() { + return Stream.of( + "", " ", + "Hello, world!", + "1234567890", + "ユニコードテスト、유니코드 테스트,Unicode测试,тест Юникода", + "Lorem ipsum dolor sit amet, consectetur adipiscing elit, " + + "sed do eiusmod tempor incididunt ut labore et dolore magna aliqua. " + + "Ut enim ad minim veniam, quis exercitation ullamco nisi ut aliquip ea commodo consequat. " + + "Duis aute irure dolor in reprehenderit en voluptate esse cillum eu fugiat nulla pariatur." + ); + } + + private static byte[] nativeCompress(byte[] input) throws IOException { + try (ByteArrayOutputStream r = new ByteArrayOutputStream(); + DeflaterOutputStream s = new DeflaterOutputStream(r)) { + + s.write(input); + s.finish(); + s.flush(); + + return r.toByteArray(); + } + } +} diff --git a/src/test/java/io/asyncer/r2dbc/mysql/internal/util/FluxEnvelopeTest.java b/src/test/java/io/asyncer/r2dbc/mysql/internal/util/FluxEnvelopeTest.java index 02e076a03..8c9d7325e 100644 --- a/src/test/java/io/asyncer/r2dbc/mysql/internal/util/FluxEnvelopeTest.java +++ b/src/test/java/io/asyncer/r2dbc/mysql/internal/util/FluxEnvelopeTest.java @@ -16,7 +16,7 @@ package io.asyncer.r2dbc.mysql.internal.util; -import io.asyncer.r2dbc.mysql.constant.Envelopes; +import io.asyncer.r2dbc.mysql.constant.Packets; import io.netty.buffer.ByteBuf; import io.netty.buffer.ByteBufAllocator; import io.netty.buffer.ByteBufUtil; @@ -30,6 +30,7 @@ import java.util.Arrays; import java.util.List; import java.util.concurrent.ThreadLocalRandom; +import java.util.concurrent.atomic.AtomicInteger; import java.util.function.Consumer; import java.util.stream.Collectors; @@ -262,7 +263,7 @@ void mergeIntegralWithLargeCrossIntegral() { } private Flux envelopes(Flux source, int envelopeSize) { - return new FluxEnvelope(source, allocator, envelopeSize, 0, true); + return new FluxEnvelope(source, allocator, envelopeSize, new AtomicInteger(0), true); } private Consumer> assertBuffers(String origin, int envelopeSize, int lastSize, @@ -273,7 +274,7 @@ private Consumer> assertBuffers(String origin, int envelopeSize, i for (int i = 0, n = originBuffers.size(); i < n; i += 2) { ByteBuf header = originBuffers.get(i); - assertThat(header.readableBytes()).isEqualTo(Envelopes.PART_HEADER_SIZE); + assertThat(header.readableBytes()).isEqualTo(Packets.NORMAL_HEADER_SIZE); int size = header.readMediumLE(); if (size > 0) { diff --git a/src/test/java/io/asyncer/r2dbc/mysql/message/server/ServerMessageDecoderTest.java b/src/test/java/io/asyncer/r2dbc/mysql/message/server/ServerMessageDecoderTest.java index c27dac59c..dd47e4678 100644 --- a/src/test/java/io/asyncer/r2dbc/mysql/message/server/ServerMessageDecoderTest.java +++ b/src/test/java/io/asyncer/r2dbc/mysql/message/server/ServerMessageDecoderTest.java @@ -21,7 +21,10 @@ import io.netty.buffer.Unpooled; import org.assertj.core.api.AbstractObjectAssert; import org.jetbrains.annotations.Nullable; -import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.MethodSource; + +import java.util.stream.Stream; import static org.assertj.core.api.Assertions.assertThat; @@ -30,21 +33,21 @@ */ class ServerMessageDecoderTest { - @Test - void okAndPreparedOk() { - AbstractObjectAssert ok = assertThat(decode(okLike(), DecodeContext.command())) - .isExactlyInstanceOf(OkMessage.class) - .extracting(message -> (OkMessage) message); + @ParameterizedTest + @MethodSource(value = { "okLikePayload" }) + void okAndPreparedOk(byte[] okLike) { + AbstractObjectAssert ok = assertThat(decode( + Unpooled.wrappedBuffer(okLike), DecodeContext.command() + )).isExactlyInstanceOf(OkMessage.class).extracting(message -> (OkMessage) message); ok.extracting(OkMessage::getAffectedRows).isEqualTo(1L); ok.extracting(OkMessage::getLastInsertId).isEqualTo(0x10000L); // 65536 ok.extracting(OkMessage::getServerStatuses).isEqualTo((short) 0x100); // 256 ok.extracting(OkMessage::getWarnings).isEqualTo(0); - AbstractObjectAssert preparedOk = assertThat(decode(okLike(), - DecodeContext.prepareQuery())) - .isExactlyInstanceOf(PreparedOkMessage.class) - .extracting(message -> (PreparedOkMessage) message); + AbstractObjectAssert preparedOk = assertThat(decode( + Unpooled.wrappedBuffer(okLike), DecodeContext.prepareQuery() + )).isExactlyInstanceOf(PreparedOkMessage.class).extracting(message -> (PreparedOkMessage) message); preparedOk.extracting(PreparedOkMessage::getStatementId).isEqualTo(0xFD01); // 64769 preparedOk.extracting(PreparedOkMessage::getTotalColumns).isEqualTo(1); @@ -56,10 +59,8 @@ private static ServerMessage decode(ByteBuf buf, DecodeContext decodeContext) { return new ServerMessageDecoder().decode(buf, ConnectionContextTest.mock(), decodeContext); } - private static ByteBuf okLike() { - return Unpooled.wrappedBuffer(new byte[] { - 10, 0, 0, // envelope size - 1, // sequence ID + static Stream okLikePayload() { + return Stream.of(new byte[] { 0, // Heading both of OK and Prepared OK 1, // OK: affected rows, Prepared OK: first byte of statement ID (byte) 0xFD,