diff --git a/rsocket-transport-netty/rsocket-transport-netty-websocket/build.gradle.kts b/rsocket-transport-netty/rsocket-transport-netty-websocket/build.gradle.kts index 2c7f381f..0519289b 100644 --- a/rsocket-transport-netty/rsocket-transport-netty-websocket/build.gradle.kts +++ b/rsocket-transport-netty/rsocket-transport-netty-websocket/build.gradle.kts @@ -29,6 +29,11 @@ kotlin { api(libs.netty.codec.http) } } + jvmTest { + dependencies { + implementation(libs.bouncycastle) + } + } } } diff --git a/rsocket-transport-netty/rsocket-transport-netty-websocket/src/jvmMain/kotlin/io/rsocket/kotlin/transport/netty/websocket/NettyWebSocketChannelHandler.kt b/rsocket-transport-netty/rsocket-transport-netty-websocket/src/jvmMain/kotlin/io/rsocket/kotlin/transport/netty/websocket/NettyWebSocketChannelHandler.kt new file mode 100644 index 00000000..41b8d6df --- /dev/null +++ b/rsocket-transport-netty/rsocket-transport-netty-websocket/src/jvmMain/kotlin/io/rsocket/kotlin/transport/netty/websocket/NettyWebSocketChannelHandler.kt @@ -0,0 +1,109 @@ +/* + * Copyright 2015-2023 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.kotlin.transport.netty.websocket + +import io.ktor.utils.io.core.* +import io.ktor.utils.io.core.internal.* +import io.ktor.utils.io.pool.* +import io.netty.channel.* +import io.netty.channel.socket.* +import io.netty.handler.codec.http.* +import io.netty.handler.codec.http.websocketx.* +import io.netty.handler.ssl.* +import io.rsocket.kotlin.internal.io.* +import io.rsocket.kotlin.transport.* +import kotlinx.coroutines.* +import kotlinx.coroutines.channels.Channel +import java.net.* +import kotlin.coroutines.* + + +internal class NettyWebSocketChannelHandler( + private val bufferPool: ObjectPool, + private val sslContext: SslContext?, + private val remoteAddress: SocketAddress?, + private val httpHandler: ChannelHandler, + private val webSocketHandler: ChannelHandler, +) : ChannelInitializer() { + private val frames = channelForCloseable(Channel.UNLIMITED) + private val handshakeDeferred = CompletableDeferred() + + @RSocketTransportApi + suspend fun connect( + context: CoroutineContext, + channel: DuplexChannel, + ): NettyWebSocketConnection { + handshakeDeferred.await() + + return NettyWebSocketConnection( + coroutineContext = context.childContext(), + bufferPool = bufferPool, + channel = channel, + frames = frames + ) + } + + override fun initChannel(ch: DuplexChannel): Unit = with(ch.pipeline()) { +// addFirst(LoggingHandler(LogLevel.INFO)) + if (sslContext != null) { + val sslHandler = if ( + remoteAddress is InetSocketAddress && + ch.parent() == null // not server + ) { + sslContext.newHandler(ch.alloc(), remoteAddress.hostName, remoteAddress.port) + } else { + sslContext.newHandler(ch.alloc()) + } + addLast("ssl", sslHandler) + } + addLast("http", httpHandler) + addLast(HttpObjectAggregator(65536)) //TODO size? + addLast("websocket", webSocketHandler) + + addLast( + "rsocket-frame-receiver", + IncomingFramesChannelHandler() + ) + } + + private inner class IncomingFramesChannelHandler : SimpleChannelInboundHandler() { + override fun channelInactive(ctx: ChannelHandlerContext) { + frames.close() //TODO? + super.channelInactive(ctx) + } + + override fun channelRead0(ctx: ChannelHandlerContext, msg: WebSocketFrame) { + if (msg !is BinaryWebSocketFrame && msg !is TextWebSocketFrame) { + error("wrong frame type") + } + + frames.trySend(buildPacket { + writeFully(msg.content().nioBuffer()) + }).getOrThrow() + } + + override fun userEventTriggered(ctx: ChannelHandlerContext, evt: Any) { + if ( + evt is WebSocketServerProtocolHandler.HandshakeComplete || + evt == WebSocketClientProtocolHandler.ClientHandshakeStateEvent.HANDSHAKE_COMPLETE + ) { + handshakeDeferred.complete(Unit) + } + //TODO: handle timeout + } + } +} diff --git a/rsocket-transport-netty/rsocket-transport-netty-websocket/src/jvmMain/kotlin/io/rsocket/kotlin/transport/netty/websocket/NettyWebSocketClientTransport.kt b/rsocket-transport-netty/rsocket-transport-netty-websocket/src/jvmMain/kotlin/io/rsocket/kotlin/transport/netty/websocket/NettyWebSocketClientTransport.kt new file mode 100644 index 00000000..4deb82d8 --- /dev/null +++ b/rsocket-transport-netty/rsocket-transport-netty-websocket/src/jvmMain/kotlin/io/rsocket/kotlin/transport/netty/websocket/NettyWebSocketClientTransport.kt @@ -0,0 +1,230 @@ +/* + * Copyright 2015-2023 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.kotlin.transport.netty.websocket + +import io.ktor.utils.io.core.internal.* +import io.ktor.utils.io.pool.* +import io.netty.bootstrap.* +import io.netty.channel.* +import io.netty.channel.ChannelFactory +import io.netty.channel.nio.* +import io.netty.channel.socket.* +import io.netty.channel.socket.nio.* +import io.netty.handler.codec.http.* +import io.netty.handler.codec.http.websocketx.* +import io.netty.handler.ssl.* +import io.rsocket.kotlin.internal.io.* +import io.rsocket.kotlin.transport.* +import kotlinx.coroutines.* +import java.net.* +import kotlin.coroutines.* +import kotlin.reflect.* + +public sealed interface NettyWebSocketClientTransport : RSocketClientTransport { + public val config: WebSocketClientProtocolConfig + + public sealed interface Engine : RSocketTransportEngine { + public fun createTransport( + target: WebSocketClientProtocolConfig.Builder.() -> Unit, + ): NettyWebSocketClientTransport { + return createTransport(WebSocketClientProtocolConfig.newBuilder().apply(target).build()) + } + } + + public sealed interface Builder { + public fun bufferPool(pool: ObjectPool) + + public fun channel(cls: KClass) + public fun channelFactory(factory: ChannelFactory) + public fun eventLoopGroup(group: EventLoopGroup, manage: Boolean) + + public fun bootstrap(block: Bootstrap.() -> Unit) + public fun ssl(block: SslContextBuilder.() -> Unit) + public fun webSockets(block: WebSocketClientProtocolConfig.Builder.() -> Unit) + } + + public companion object Factory : + RSocketTransportFactory { + public operator fun invoke( + context: CoroutineContext, + target: WebSocketClientProtocolConfig.Builder.() -> Unit, + block: Builder.() -> Unit, + ): NettyWebSocketClientTransport { + return invoke(context, WebSocketClientProtocolConfig.newBuilder().apply(target).build(), block) + } + + override fun invoke( + context: CoroutineContext, + target: WebSocketClientProtocolConfig, + block: Builder.() -> Unit, + ): NettyWebSocketClientTransport { + return NettyWebSocketClientTransportBuilderImpl().apply(block).build(context).buildTransport(target) + } + + override fun Engine(context: CoroutineContext, block: Builder.() -> Unit): Engine { + return NettyWebSocketClientTransportBuilderImpl().apply(block).build(context).buildEngine() + } + } +} + +private class NettyWebSocketClientTransportBuilderImpl : NettyWebSocketClientTransport.Builder { + private var bufferPool: ObjectPool = ChunkBuffer.Pool + private var channelFactory: ChannelFactory? = null + private var eventLoopGroup: EventLoopGroup? = null + private var manageEventLoopGroup: Boolean = false + private var bootstrap: (Bootstrap.() -> Unit)? = null + private var ssl: (SslContextBuilder.() -> Unit)? = null + private var webSockets: (WebSocketClientProtocolConfig.Builder.() -> Unit)? = null + + override fun bufferPool(pool: ObjectPool) { + this.bufferPool = pool + } + + override fun channel(cls: KClass) { + this.channelFactory = ReflectiveChannelFactory(cls.java) + } + + override fun channelFactory(factory: ChannelFactory) { + this.channelFactory = factory + } + + override fun eventLoopGroup(group: EventLoopGroup, manage: Boolean) { + this.eventLoopGroup = group + this.manageEventLoopGroup = manage + } + + override fun bootstrap(block: Bootstrap.() -> Unit) { + bootstrap = block + } + + override fun ssl(block: SslContextBuilder.() -> Unit) { + ssl = block + } + + override fun webSockets(block: WebSocketClientProtocolConfig.Builder.() -> Unit) { + webSockets = block + } + + fun build(context: CoroutineContext): NettyWebSocketClientTransportResources { + val group = eventLoopGroup ?: NioEventLoopGroup() + val factory = channelFactory ?: ReflectiveChannelFactory(NioSocketChannel::class.java) + + val transportContext = group.asCoroutineDispatcher() + context.supervisorContext() + if (manageEventLoopGroup) CoroutineScope(transportContext).invokeOnCancellation { + group.shutdownGracefully().awaitFuture() + } + + val sslContext = ssl?.let { + SslContextBuilder + .forClient() + .apply(it) + .build() + } + + val bootstrap = Bootstrap().apply { + bootstrap?.invoke(this) + group(group) + channelFactory(factory) + } + + return NettyWebSocketClientTransportResources( + coroutineContext = transportContext, + bufferPool = bufferPool, + sslContext = sslContext, + bootstrap = bootstrap, + webSocketConfig = webSockets + ) + } +} + +private class NettyWebSocketClientTransportResources( + private val coroutineContext: CoroutineContext, + private val bufferPool: ObjectPool, + private val sslContext: SslContext?, + private val bootstrap: Bootstrap, + private val webSocketConfig: (WebSocketClientProtocolConfig.Builder.() -> Unit)?, +) { + fun buildTransport(config: WebSocketClientProtocolConfig): NettyWebSocketClientTransport { + return NettyWebSocketClientTransportImpl( + coroutineContext = coroutineContext, + config = config.toBuilder().apply { + webSocketConfig?.invoke(this) + }.build(), + bufferPool = bufferPool, + sslContext = sslContext, + bootstrap = bootstrap, + ) + } + + fun buildEngine(): NettyWebSocketClientTransport.Engine { + return NettyWebSocketClientTransportEngineImpl( + coroutineContext = coroutineContext, + bufferPool = bufferPool, + sslContext = sslContext, + bootstrap = bootstrap, + webSocketConfig = webSocketConfig, + ) + } +} + +private class NettyWebSocketClientTransportEngineImpl( + override val coroutineContext: CoroutineContext, + private val bufferPool: ObjectPool, + private val sslContext: SslContext?, + private val bootstrap: Bootstrap, + private val webSocketConfig: (WebSocketClientProtocolConfig.Builder.() -> Unit)?, +) : NettyWebSocketClientTransport.Engine { + override fun createTransport(target: WebSocketClientProtocolConfig): NettyWebSocketClientTransport { + return NettyWebSocketClientTransportImpl( + coroutineContext = coroutineContext.supervisorContext(), + config = target.toBuilder().apply { + webSocketConfig?.invoke(this) + }.build(), + bufferPool = bufferPool, + sslContext = sslContext, + bootstrap = bootstrap, + ) + } +} + +private class NettyWebSocketClientTransportImpl( + override val coroutineContext: CoroutineContext, + override val config: WebSocketClientProtocolConfig, + private val bufferPool: ObjectPool, + private val sslContext: SslContext?, + private val bootstrap: Bootstrap, +) : NettyWebSocketClientTransport { + @RSocketTransportApi + override suspend fun connect(): RSocketTransportConnection { + val remoteAddress = InetSocketAddress(config.webSocketUri().host, config.webSocketUri().port) + val handler = NettyWebSocketChannelHandler( + bufferPool = bufferPool, + sslContext = sslContext, + remoteAddress = remoteAddress, + httpHandler = HttpClientCodec(), + webSocketHandler = WebSocketClientProtocolHandler(config), + ) + val future = bootstrap.clone().apply { + remoteAddress(remoteAddress) + handler(handler) + }.connect() + + future.awaitFuture() + + return handler.connect(coroutineContext, future.channel() as DuplexChannel) + } +} diff --git a/rsocket-transport-netty/rsocket-transport-netty-websocket/src/jvmMain/kotlin/io/rsocket/kotlin/transport/netty/websocket/NettyWebSocketConnection.kt b/rsocket-transport-netty/rsocket-transport-netty-websocket/src/jvmMain/kotlin/io/rsocket/kotlin/transport/netty/websocket/NettyWebSocketConnection.kt new file mode 100644 index 00000000..024bd900 --- /dev/null +++ b/rsocket-transport-netty/rsocket-transport-netty-websocket/src/jvmMain/kotlin/io/rsocket/kotlin/transport/netty/websocket/NettyWebSocketConnection.kt @@ -0,0 +1,48 @@ +/* + * Copyright 2015-2023 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.kotlin.transport.netty.websocket + +import io.ktor.utils.io.core.* +import io.ktor.utils.io.core.internal.* +import io.ktor.utils.io.pool.* +import io.netty.buffer.* +import io.netty.channel.socket.* +import io.netty.handler.codec.http.websocketx.* +import io.rsocket.kotlin.transport.* +import kotlinx.coroutines.channels.* +import kotlin.coroutines.* + +@RSocketTransportApi +internal class NettyWebSocketConnection( + override val coroutineContext: CoroutineContext, + override val bufferPool: ObjectPool, + private val channel: DuplexChannel, + private val frames: ReceiveChannel, +) : RSocketTransportConnection.Sequential { + + init { + linkCompletionWith(channel) + } + + override suspend fun sendFrame(frame: ByteReadPacket) { + channel.writeAndFlush(BinaryWebSocketFrame(Unpooled.wrappedBuffer(frame.readByteBuffer()))).awaitFuture() + } + + override suspend fun receiveFrame(): ByteReadPacket { + return frames.receive() + } +} diff --git a/rsocket-transport-netty/rsocket-transport-netty-websocket/src/jvmMain/kotlin/io/rsocket/kotlin/transport/netty/websocket/NettyWebSocketServerTransport.kt b/rsocket-transport-netty/rsocket-transport-netty-websocket/src/jvmMain/kotlin/io/rsocket/kotlin/transport/netty/websocket/NettyWebSocketServerTransport.kt new file mode 100644 index 00000000..7f085b6b --- /dev/null +++ b/rsocket-transport-netty/rsocket-transport-netty-websocket/src/jvmMain/kotlin/io/rsocket/kotlin/transport/netty/websocket/NettyWebSocketServerTransport.kt @@ -0,0 +1,324 @@ +/* + * Copyright 2015-2023 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.kotlin.transport.netty.websocket + +import io.ktor.utils.io.core.internal.* +import io.ktor.utils.io.pool.* +import io.netty.bootstrap.* +import io.netty.channel.* +import io.netty.channel.ChannelFactory +import io.netty.channel.nio.* +import io.netty.channel.socket.* +import io.netty.channel.socket.nio.* +import io.netty.handler.codec.http.* +import io.netty.handler.codec.http.websocketx.* +import io.netty.handler.ssl.* +import io.rsocket.kotlin.internal.io.* +import io.rsocket.kotlin.transport.* +import kotlinx.coroutines.* +import java.net.* +import javax.net.ssl.* +import kotlin.coroutines.* +import kotlin.reflect.* + +public sealed interface NettyWebSocketServerInstance : RSocketServerInstance { + public val localAddress: SocketAddress + public val config: WebSocketServerProtocolConfig +} + +public class NettyWebSocketServerTarget( + public val localAddress: SocketAddress?, + public val config: WebSocketServerProtocolConfig?, +) + +public sealed interface NettyWebSocketServerTransport : RSocketServerTransport { + public val localAddress: SocketAddress? + public val config: WebSocketServerProtocolConfig + + public sealed interface Engine : RSocketTransportEngine { + //TODO!!! + public fun createTransport( + hostname: String = "0.0.0.0", + port: Int = 0, + config: WebSocketServerProtocolConfig? = null, + ): NettyWebSocketServerTransport = createTransport( + NettyWebSocketServerTarget( + InetSocketAddress(hostname, port), + config + ) + ) + + public fun createTransport( + hostname: String = "0.0.0.0", + port: Int = 0, + config: WebSocketServerProtocolConfig.Builder.() -> Unit, + ): NettyWebSocketServerTransport = createTransport( + NettyWebSocketServerTarget( + InetSocketAddress(hostname, port), + WebSocketServerProtocolConfig.newBuilder().apply(config).build() + ) + ) + + public fun createTransport( + config: WebSocketServerProtocolConfig.Builder.() -> Unit, + ): NettyWebSocketServerTransport = createTransport( + NettyWebSocketServerTarget( + null, + WebSocketServerProtocolConfig.newBuilder().apply(config).build() + ) + ) + + public fun createTransport(): NettyWebSocketServerTransport = createTransport(NettyWebSocketServerTarget(null, null)) + } + + public sealed interface Builder { + public fun bufferPool(pool: ObjectPool) + + public fun channel(cls: KClass) + public fun channelFactory(factory: ChannelFactory) + public fun eventLoopGroup(parentGroup: EventLoopGroup, childGroup: EventLoopGroup, manage: Boolean) + public fun eventLoopGroup(group: EventLoopGroup, manage: Boolean) + + public fun bootstrap(block: ServerBootstrap.() -> Unit) + public fun ssl(block: SslContextBuilder.() -> Unit) + public fun webSockets(block: WebSocketServerProtocolConfig.Builder.() -> Unit) + } + + public companion object Factory : RSocketTransportFactory { + //TODO!!! + public operator fun invoke( + context: CoroutineContext, + hostname: String = "0.0.0.0", + port: Int = 0, + block: Builder.() -> Unit = {}, + ): NettyWebSocketServerTransport = invoke(context, NettyWebSocketServerTarget(InetSocketAddress(hostname, port), null), block) + + public operator fun invoke(context: CoroutineContext, block: Builder.() -> Unit = {}): NettyWebSocketServerTransport = + invoke(context, NettyWebSocketServerTarget(null, null), block) + + override fun invoke( + context: CoroutineContext, + target: NettyWebSocketServerTarget, + block: Builder.() -> Unit, + ): NettyWebSocketServerTransport { + return NettyWebSocketServerTransportBuilderImpl().apply(block).build(context).buildTransport(target) + } + + override fun Engine(context: CoroutineContext, block: Builder.() -> Unit): Engine { + return NettyWebSocketServerTransportBuilderImpl().apply(block).build(context).buildEngine() + } + } +} + +private class NettyWebSocketServerTransportBuilderImpl : NettyWebSocketServerTransport.Builder { + private var bufferPool: ObjectPool = ChunkBuffer.Pool + private var channelFactory: ChannelFactory? = null + private var parentEventLoopGroup: EventLoopGroup? = null + private var childEventLoopGroup: EventLoopGroup? = null + private var manageEventLoopGroup: Boolean = false + private var bootstrap: (ServerBootstrap.() -> Unit)? = null + private var ssl: (SslContextBuilder.() -> Unit)? = null + private var webSockets: (WebSocketServerProtocolConfig.Builder.() -> Unit)? = null + + override fun bufferPool(pool: ObjectPool) { + this.bufferPool = pool + } + + override fun channel(cls: KClass) { + this.channelFactory = ReflectiveChannelFactory(cls.java) + } + + override fun channelFactory(factory: ChannelFactory) { + this.channelFactory = factory + } + + override fun eventLoopGroup(parentGroup: EventLoopGroup, childGroup: EventLoopGroup, manage: Boolean) { + this.parentEventLoopGroup = parentGroup + this.childEventLoopGroup = childGroup + this.manageEventLoopGroup = manage + } + + override fun eventLoopGroup(group: EventLoopGroup, manage: Boolean) { + this.parentEventLoopGroup = group + this.childEventLoopGroup = group + this.manageEventLoopGroup = manage + } + + override fun bootstrap(block: ServerBootstrap.() -> Unit) { + bootstrap = block + } + + override fun ssl(block: SslContextBuilder.() -> Unit) { + ssl = block + } + + override fun webSockets(block: WebSocketServerProtocolConfig.Builder.() -> Unit) { + webSockets = block + } + + fun build(context: CoroutineContext): NettyWebSocketServerTransportResources { + val parentGroup = parentEventLoopGroup ?: NioEventLoopGroup() + val childGroup = childEventLoopGroup ?: NioEventLoopGroup() + val factory = channelFactory ?: ReflectiveChannelFactory(NioServerSocketChannel::class.java) + + val transportContext = parentGroup.asCoroutineDispatcher() + context.supervisorContext() + if (manageEventLoopGroup) CoroutineScope(transportContext).invokeOnCancellation { + childGroup.shutdownGracefully().awaitFuture() + parentGroup.shutdownGracefully().awaitFuture() + } + + val sslContext = ssl?.let { + SslContextBuilder + .forServer(KeyManagerFactory.getInstance(KeyManagerFactory.getDefaultAlgorithm())) + .apply(it) + .build() + } + + val bootstrap = ServerBootstrap().apply { + bootstrap?.invoke(this) + group(parentGroup, childGroup) + channelFactory(factory) + } + + return NettyWebSocketServerTransportResources( + coroutineContext = transportContext, + bufferPool = bufferPool, + bootstrap = bootstrap, + sslContext = sslContext, + webSocketConfig = webSockets + ) + } +} + +private class NettyWebSocketServerTransportResources( + private val coroutineContext: CoroutineContext, + private val bufferPool: ObjectPool, + private val bootstrap: ServerBootstrap, + private val sslContext: SslContext?, + private val webSocketConfig: (WebSocketServerProtocolConfig.Builder.() -> Unit)?, +) { + fun buildTransport(target: NettyWebSocketServerTarget): NettyWebSocketServerTransport { + return NettyWebSocketServerTransportImpl( + coroutineContext = coroutineContext, + localAddress = target.localAddress, + config = (target.config?.toBuilder() ?: WebSocketServerProtocolConfig.newBuilder()).apply { + webSocketConfig?.invoke(this) + }.build(), + bufferPool = bufferPool, + bootstrap = bootstrap, + sslContext = sslContext + ) + } + + fun buildEngine(): NettyWebSocketServerTransport.Engine { + return NettyWebSocketServerTransportEngineImpl( + coroutineContext = coroutineContext, + bufferPool = bufferPool, + bootstrap = bootstrap, + sslContext = sslContext, + webSocketConfig = webSocketConfig + ) + } +} + +private class NettyWebSocketServerTransportEngineImpl( + override val coroutineContext: CoroutineContext, + private val bufferPool: ObjectPool, + private val bootstrap: ServerBootstrap, + private val sslContext: SslContext?, + private val webSocketConfig: (WebSocketServerProtocolConfig.Builder.() -> Unit)?, +) : NettyWebSocketServerTransport.Engine { + override fun createTransport(target: NettyWebSocketServerTarget): NettyWebSocketServerTransport { + return NettyWebSocketServerTransportImpl( + coroutineContext = coroutineContext.supervisorContext(), + localAddress = target.localAddress, + config = (target.config?.toBuilder() ?: WebSocketServerProtocolConfig.newBuilder()).apply { + webSocketConfig?.invoke(this) + }.build(), + bufferPool = bufferPool, + bootstrap = bootstrap, + sslContext = sslContext + ) + } +} + +private class NettyWebSocketServerTransportImpl( + override val coroutineContext: CoroutineContext, + override val localAddress: SocketAddress?, + override val config: WebSocketServerProtocolConfig, + private val bufferPool: ObjectPool, + private val bootstrap: ServerBootstrap, + private val sslContext: SslContext?, +) : NettyWebSocketServerTransport { + + @RSocketTransportApi + override suspend fun bind(acceptor: RSocketServerAcceptor): NettyWebSocketServerInstance { + val instanceContext = coroutineContext.supervisorContext() + try { + val future = bootstrap.clone().apply { + localAddress(localAddress ?: InetSocketAddress(0)) + childHandler(AcceptorChannelHandler(instanceContext, bufferPool, sslContext, acceptor, config)) + }.bind() + + future.awaitFuture() + + return NettyWebSocketServerInstanceImpl( + coroutineContext = instanceContext, + channel = future.channel() as ServerChannel, + config = config + ) + } catch (cause: Throwable) { + instanceContext.job.cancel("Failed to bind", cause) + throw cause + } + } +} + +@RSocketTransportApi +private class AcceptorChannelHandler( + override val coroutineContext: CoroutineContext, + private val bufferPool: ObjectPool, + private val sslContext: SslContext?, + private val acceptor: RSocketServerAcceptor, + private val config: WebSocketServerProtocolConfig, +) : ChannelInitializer(), CoroutineScope { + override fun initChannel(ch: DuplexChannel) { + val handler = NettyWebSocketChannelHandler( + bufferPool = bufferPool, + sslContext = sslContext, + remoteAddress = null, + httpHandler = HttpServerCodec(), + webSocketHandler = WebSocketServerProtocolHandler(config) + ) + ch.pipeline().addLast(handler) + launch { + acceptor.accept(handler.connect(coroutineContext, ch)) + } + } +} + +private class NettyWebSocketServerInstanceImpl( + override val coroutineContext: CoroutineContext, + private val channel: ServerChannel, + override val config: WebSocketServerProtocolConfig, +) : NettyWebSocketServerInstance { + override val localAddress: SocketAddress get() = channel.localAddress() + + init { + linkCompletionWith(channel) + } +} diff --git a/rsocket-transport-netty/rsocket-transport-netty-websocket/src/jvmMain/kotlin/io/rsocket/kotlin/transport/netty/websocket/utils.kt b/rsocket-transport-netty/rsocket-transport-netty-websocket/src/jvmMain/kotlin/io/rsocket/kotlin/transport/netty/websocket/utils.kt new file mode 100644 index 00000000..fc245165 --- /dev/null +++ b/rsocket-transport-netty/rsocket-transport-netty-websocket/src/jvmMain/kotlin/io/rsocket/kotlin/transport/netty/websocket/utils.kt @@ -0,0 +1,59 @@ +/* + * Copyright 2015-2023 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.kotlin.transport.netty.websocket + +import io.netty.channel.* +import io.netty.util.concurrent.* +import kotlinx.coroutines.* +import kotlin.coroutines.* + +@Suppress("UNCHECKED_CAST") +internal suspend inline fun Future.awaitFuture(): T = suspendCancellableCoroutine { cont -> + addListener { + when { + it.isSuccess -> cont.resume(it.now as T) + else -> cont.resumeWithException(it.cause()) + } + } + cont.invokeOnCancellation { + cancel(true) + } +} + +internal fun CoroutineScope.linkCompletionWith(channel: Channel) { + channel.closeFuture().addListener { + cancel("Netty channel closed", it.cause()) + } + invokeOnCancellation { + channel.close().awaitFuture() + } +} + +internal inline fun CoroutineScope.invokeOnCancellation( + crossinline block: suspend () -> Unit, +) { + launch { + try { + awaitCancellation() + } catch (cause: Throwable) { + withContext(NonCancellable) { + runCatching { block() }.onFailure { cause.addSuppressed(it) } + } + throw cause + } + } +} diff --git a/rsocket-transport-netty/rsocket-transport-netty-websocket/src/jvmTest/kotlin/io/rsocket/kotlin/transport/netty/websocket/NettyWebSocketTransportTest.kt b/rsocket-transport-netty/rsocket-transport-netty-websocket/src/jvmTest/kotlin/io/rsocket/kotlin/transport/netty/websocket/NettyWebSocketTransportTest.kt new file mode 100644 index 00000000..481d5973 --- /dev/null +++ b/rsocket-transport-netty/rsocket-transport-netty-websocket/src/jvmTest/kotlin/io/rsocket/kotlin/transport/netty/websocket/NettyWebSocketTransportTest.kt @@ -0,0 +1,77 @@ +/* + * Copyright 2015-2023 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.kotlin.transport.netty.websocket + +import io.netty.channel.nio.* +import io.netty.handler.ssl.util.* +import io.rsocket.kotlin.test.* +import io.rsocket.kotlin.transport.tests.* +import java.net.* +import kotlin.concurrent.* + +private val eventLoop = NioEventLoopGroup().also { + Runtime.getRuntime().addShutdownHook(thread(start = false) { + it.shutdownGracefully().await(1000) + }) +} +private val certificates = SelfSignedCertificate() + +class NettyWebSocketTransportTest : TransportTest() { + override suspend fun before() { + val server = startServer( + NettyWebSocketServerTransport(testContext) { + bufferPool(InUseTrackingPool) + eventLoopGroup(eventLoop, manage = false) + } + ) + client = connectClient( + NettyWebSocketClientTransport(testContext, { + val address = server.localAddress as InetSocketAddress + webSocketUri("ws://localhost:${address.port}") + }) { + bufferPool(InUseTrackingPool) + eventLoopGroup(eventLoop, manage = false) + } + ) + } +} + +class NettyWebSocketSslTransportTest : TransportTest() { + override suspend fun before() { + val server = startServer( + NettyWebSocketServerTransport(testContext) { + bufferPool(InUseTrackingPool) + eventLoopGroup(eventLoop, manage = false) + ssl { + keyManager(certificates.certificate(), certificates.privateKey()) + } + } + ) + client = connectClient( + NettyWebSocketClientTransport(testContext, { + val address = server.localAddress as InetSocketAddress + webSocketUri("ws://localhost:${address.port}") + }) { + bufferPool(InUseTrackingPool) + eventLoopGroup(eventLoop, manage = false) + ssl { + trustManager(InsecureTrustManagerFactory.INSTANCE) + } + } + ) + } +}