Skip to content

Commit

Permalink
first version of netty WS transport
Browse files Browse the repository at this point in the history
  • Loading branch information
whyoleg committed Jun 21, 2023
1 parent 55e09d0 commit 8da1a8b
Show file tree
Hide file tree
Showing 7 changed files with 852 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,11 @@ kotlin {
api(libs.netty.codec.http)
}
}
jvmTest {
dependencies {
implementation(libs.bouncycastle)
}
}
}
}

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
/*
* Copyright 2015-2023 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package io.rsocket.kotlin.transport.netty.websocket

import io.ktor.utils.io.core.*
import io.ktor.utils.io.core.internal.*
import io.ktor.utils.io.pool.*
import io.netty.channel.*
import io.netty.channel.socket.*
import io.netty.handler.codec.http.*
import io.netty.handler.codec.http.websocketx.*
import io.netty.handler.ssl.*
import io.rsocket.kotlin.internal.io.*
import io.rsocket.kotlin.transport.*
import kotlinx.coroutines.*
import kotlinx.coroutines.channels.Channel
import java.net.*
import kotlin.coroutines.*


internal class NettyWebSocketChannelHandler(
private val bufferPool: ObjectPool<ChunkBuffer>,
private val sslContext: SslContext?,
private val remoteAddress: SocketAddress?,
private val httpHandler: ChannelHandler,
private val webSocketHandler: ChannelHandler,
) : ChannelInitializer<DuplexChannel>() {
private val frames = channelForCloseable<ByteReadPacket>(Channel.UNLIMITED)
private val handshakeDeferred = CompletableDeferred<Unit>()

@RSocketTransportApi
suspend fun connect(
context: CoroutineContext,
channel: DuplexChannel,
): NettyWebSocketConnection {
handshakeDeferred.await()

return NettyWebSocketConnection(
coroutineContext = context.childContext(),
bufferPool = bufferPool,
channel = channel,
frames = frames
)
}

override fun initChannel(ch: DuplexChannel): Unit = with(ch.pipeline()) {
// addFirst(LoggingHandler(LogLevel.INFO))
if (sslContext != null) {
val sslHandler = if (
remoteAddress is InetSocketAddress &&
ch.parent() == null // not server
) {
sslContext.newHandler(ch.alloc(), remoteAddress.hostName, remoteAddress.port)
} else {
sslContext.newHandler(ch.alloc())
}
addLast("ssl", sslHandler)
}
addLast("http", httpHandler)
addLast(HttpObjectAggregator(65536)) //TODO size?
addLast("websocket", webSocketHandler)

addLast(
"rsocket-frame-receiver",
IncomingFramesChannelHandler()
)
}

private inner class IncomingFramesChannelHandler : SimpleChannelInboundHandler<WebSocketFrame>() {
override fun channelInactive(ctx: ChannelHandlerContext) {
frames.close() //TODO?
super.channelInactive(ctx)
}

override fun channelRead0(ctx: ChannelHandlerContext, msg: WebSocketFrame) {
if (msg !is BinaryWebSocketFrame && msg !is TextWebSocketFrame) {
error("wrong frame type")
}

frames.trySend(buildPacket {
writeFully(msg.content().nioBuffer())
}).getOrThrow()
}

override fun userEventTriggered(ctx: ChannelHandlerContext, evt: Any) {
if (
evt is WebSocketServerProtocolHandler.HandshakeComplete ||
evt == WebSocketClientProtocolHandler.ClientHandshakeStateEvent.HANDSHAKE_COMPLETE
) {
handshakeDeferred.complete(Unit)
}
//TODO: handle timeout
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,230 @@
/*
* Copyright 2015-2023 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package io.rsocket.kotlin.transport.netty.websocket

import io.ktor.utils.io.core.internal.*
import io.ktor.utils.io.pool.*
import io.netty.bootstrap.*
import io.netty.channel.*
import io.netty.channel.ChannelFactory
import io.netty.channel.nio.*
import io.netty.channel.socket.*
import io.netty.channel.socket.nio.*
import io.netty.handler.codec.http.*
import io.netty.handler.codec.http.websocketx.*
import io.netty.handler.ssl.*
import io.rsocket.kotlin.internal.io.*
import io.rsocket.kotlin.transport.*
import kotlinx.coroutines.*
import java.net.*
import kotlin.coroutines.*
import kotlin.reflect.*

public sealed interface NettyWebSocketClientTransport : RSocketClientTransport {
public val config: WebSocketClientProtocolConfig

public sealed interface Engine : RSocketTransportEngine<WebSocketClientProtocolConfig, NettyWebSocketClientTransport> {
public fun createTransport(
target: WebSocketClientProtocolConfig.Builder.() -> Unit,
): NettyWebSocketClientTransport {
return createTransport(WebSocketClientProtocolConfig.newBuilder().apply(target).build())
}
}

public sealed interface Builder {
public fun bufferPool(pool: ObjectPool<ChunkBuffer>)

public fun channel(cls: KClass<out DuplexChannel>)
public fun channelFactory(factory: ChannelFactory<out DuplexChannel>)
public fun eventLoopGroup(group: EventLoopGroup, manage: Boolean)

public fun bootstrap(block: Bootstrap.() -> Unit)
public fun ssl(block: SslContextBuilder.() -> Unit)
public fun webSockets(block: WebSocketClientProtocolConfig.Builder.() -> Unit)
}

public companion object Factory :
RSocketTransportFactory<WebSocketClientProtocolConfig, NettyWebSocketClientTransport, Engine, Builder> {
public operator fun invoke(
context: CoroutineContext,
target: WebSocketClientProtocolConfig.Builder.() -> Unit,
block: Builder.() -> Unit,
): NettyWebSocketClientTransport {
return invoke(context, WebSocketClientProtocolConfig.newBuilder().apply(target).build(), block)
}

override fun invoke(
context: CoroutineContext,
target: WebSocketClientProtocolConfig,
block: Builder.() -> Unit,
): NettyWebSocketClientTransport {
return NettyWebSocketClientTransportBuilderImpl().apply(block).build(context).buildTransport(target)
}

override fun Engine(context: CoroutineContext, block: Builder.() -> Unit): Engine {
return NettyWebSocketClientTransportBuilderImpl().apply(block).build(context).buildEngine()
}
}
}

private class NettyWebSocketClientTransportBuilderImpl : NettyWebSocketClientTransport.Builder {
private var bufferPool: ObjectPool<ChunkBuffer> = ChunkBuffer.Pool
private var channelFactory: ChannelFactory<out DuplexChannel>? = null
private var eventLoopGroup: EventLoopGroup? = null
private var manageEventLoopGroup: Boolean = false
private var bootstrap: (Bootstrap.() -> Unit)? = null
private var ssl: (SslContextBuilder.() -> Unit)? = null
private var webSockets: (WebSocketClientProtocolConfig.Builder.() -> Unit)? = null

override fun bufferPool(pool: ObjectPool<ChunkBuffer>) {
this.bufferPool = pool
}

override fun channel(cls: KClass<out DuplexChannel>) {
this.channelFactory = ReflectiveChannelFactory(cls.java)
}

override fun channelFactory(factory: ChannelFactory<out DuplexChannel>) {
this.channelFactory = factory
}

override fun eventLoopGroup(group: EventLoopGroup, manage: Boolean) {
this.eventLoopGroup = group
this.manageEventLoopGroup = manage
}

override fun bootstrap(block: Bootstrap.() -> Unit) {
bootstrap = block
}

override fun ssl(block: SslContextBuilder.() -> Unit) {
ssl = block
}

override fun webSockets(block: WebSocketClientProtocolConfig.Builder.() -> Unit) {
webSockets = block
}

fun build(context: CoroutineContext): NettyWebSocketClientTransportResources {
val group = eventLoopGroup ?: NioEventLoopGroup()
val factory = channelFactory ?: ReflectiveChannelFactory(NioSocketChannel::class.java)

val transportContext = group.asCoroutineDispatcher() + context.supervisorContext()
if (manageEventLoopGroup) CoroutineScope(transportContext).invokeOnCancellation {
group.shutdownGracefully().awaitFuture()
}

val sslContext = ssl?.let {
SslContextBuilder
.forClient()
.apply(it)
.build()
}

val bootstrap = Bootstrap().apply {
bootstrap?.invoke(this)
group(group)
channelFactory(factory)
}

return NettyWebSocketClientTransportResources(
coroutineContext = transportContext,
bufferPool = bufferPool,
sslContext = sslContext,
bootstrap = bootstrap,
webSocketConfig = webSockets
)
}
}

private class NettyWebSocketClientTransportResources(
private val coroutineContext: CoroutineContext,
private val bufferPool: ObjectPool<ChunkBuffer>,
private val sslContext: SslContext?,
private val bootstrap: Bootstrap,
private val webSocketConfig: (WebSocketClientProtocolConfig.Builder.() -> Unit)?,
) {
fun buildTransport(config: WebSocketClientProtocolConfig): NettyWebSocketClientTransport {
return NettyWebSocketClientTransportImpl(
coroutineContext = coroutineContext,
config = config.toBuilder().apply {
webSocketConfig?.invoke(this)
}.build(),
bufferPool = bufferPool,
sslContext = sslContext,
bootstrap = bootstrap,
)
}

fun buildEngine(): NettyWebSocketClientTransport.Engine {
return NettyWebSocketClientTransportEngineImpl(
coroutineContext = coroutineContext,
bufferPool = bufferPool,
sslContext = sslContext,
bootstrap = bootstrap,
webSocketConfig = webSocketConfig,
)
}
}

private class NettyWebSocketClientTransportEngineImpl(
override val coroutineContext: CoroutineContext,
private val bufferPool: ObjectPool<ChunkBuffer>,
private val sslContext: SslContext?,
private val bootstrap: Bootstrap,
private val webSocketConfig: (WebSocketClientProtocolConfig.Builder.() -> Unit)?,
) : NettyWebSocketClientTransport.Engine {
override fun createTransport(target: WebSocketClientProtocolConfig): NettyWebSocketClientTransport {
return NettyWebSocketClientTransportImpl(
coroutineContext = coroutineContext.supervisorContext(),
config = target.toBuilder().apply {
webSocketConfig?.invoke(this)
}.build(),
bufferPool = bufferPool,
sslContext = sslContext,
bootstrap = bootstrap,
)
}
}

private class NettyWebSocketClientTransportImpl(
override val coroutineContext: CoroutineContext,
override val config: WebSocketClientProtocolConfig,
private val bufferPool: ObjectPool<ChunkBuffer>,
private val sslContext: SslContext?,
private val bootstrap: Bootstrap,
) : NettyWebSocketClientTransport {
@RSocketTransportApi
override suspend fun connect(): RSocketTransportConnection {
val remoteAddress = InetSocketAddress(config.webSocketUri().host, config.webSocketUri().port)
val handler = NettyWebSocketChannelHandler(
bufferPool = bufferPool,
sslContext = sslContext,
remoteAddress = remoteAddress,
httpHandler = HttpClientCodec(),
webSocketHandler = WebSocketClientProtocolHandler(config),
)
val future = bootstrap.clone().apply {
remoteAddress(remoteAddress)
handler(handler)
}.connect()

future.awaitFuture()

return handler.connect(coroutineContext, future.channel() as DuplexChannel)
}
}
Loading

0 comments on commit 8da1a8b

Please sign in to comment.