-
Notifications
You must be signed in to change notification settings - Fork 36
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
7 changed files
with
852 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
109 changes: 109 additions & 0 deletions
109
...vmMain/kotlin/io/rsocket/kotlin/transport/netty/websocket/NettyWebSocketChannelHandler.kt
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,109 @@ | ||
/* | ||
* Copyright 2015-2023 the original author or authors. | ||
* | ||
* Licensed under the Apache License, Version 2.0 (the "License"); | ||
* you may not use this file except in compliance with the License. | ||
* You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
|
||
package io.rsocket.kotlin.transport.netty.websocket | ||
|
||
import io.ktor.utils.io.core.* | ||
import io.ktor.utils.io.core.internal.* | ||
import io.ktor.utils.io.pool.* | ||
import io.netty.channel.* | ||
import io.netty.channel.socket.* | ||
import io.netty.handler.codec.http.* | ||
import io.netty.handler.codec.http.websocketx.* | ||
import io.netty.handler.ssl.* | ||
import io.rsocket.kotlin.internal.io.* | ||
import io.rsocket.kotlin.transport.* | ||
import kotlinx.coroutines.* | ||
import kotlinx.coroutines.channels.Channel | ||
import java.net.* | ||
import kotlin.coroutines.* | ||
|
||
|
||
internal class NettyWebSocketChannelHandler( | ||
private val bufferPool: ObjectPool<ChunkBuffer>, | ||
private val sslContext: SslContext?, | ||
private val remoteAddress: SocketAddress?, | ||
private val httpHandler: ChannelHandler, | ||
private val webSocketHandler: ChannelHandler, | ||
) : ChannelInitializer<DuplexChannel>() { | ||
private val frames = channelForCloseable<ByteReadPacket>(Channel.UNLIMITED) | ||
private val handshakeDeferred = CompletableDeferred<Unit>() | ||
|
||
@RSocketTransportApi | ||
suspend fun connect( | ||
context: CoroutineContext, | ||
channel: DuplexChannel, | ||
): NettyWebSocketConnection { | ||
handshakeDeferred.await() | ||
|
||
return NettyWebSocketConnection( | ||
coroutineContext = context.childContext(), | ||
bufferPool = bufferPool, | ||
channel = channel, | ||
frames = frames | ||
) | ||
} | ||
|
||
override fun initChannel(ch: DuplexChannel): Unit = with(ch.pipeline()) { | ||
// addFirst(LoggingHandler(LogLevel.INFO)) | ||
if (sslContext != null) { | ||
val sslHandler = if ( | ||
remoteAddress is InetSocketAddress && | ||
ch.parent() == null // not server | ||
) { | ||
sslContext.newHandler(ch.alloc(), remoteAddress.hostName, remoteAddress.port) | ||
} else { | ||
sslContext.newHandler(ch.alloc()) | ||
} | ||
addLast("ssl", sslHandler) | ||
} | ||
addLast("http", httpHandler) | ||
addLast(HttpObjectAggregator(65536)) //TODO size? | ||
addLast("websocket", webSocketHandler) | ||
|
||
addLast( | ||
"rsocket-frame-receiver", | ||
IncomingFramesChannelHandler() | ||
) | ||
} | ||
|
||
private inner class IncomingFramesChannelHandler : SimpleChannelInboundHandler<WebSocketFrame>() { | ||
override fun channelInactive(ctx: ChannelHandlerContext) { | ||
frames.close() //TODO? | ||
super.channelInactive(ctx) | ||
} | ||
|
||
override fun channelRead0(ctx: ChannelHandlerContext, msg: WebSocketFrame) { | ||
if (msg !is BinaryWebSocketFrame && msg !is TextWebSocketFrame) { | ||
error("wrong frame type") | ||
} | ||
|
||
frames.trySend(buildPacket { | ||
writeFully(msg.content().nioBuffer()) | ||
}).getOrThrow() | ||
} | ||
|
||
override fun userEventTriggered(ctx: ChannelHandlerContext, evt: Any) { | ||
if ( | ||
evt is WebSocketServerProtocolHandler.HandshakeComplete || | ||
evt == WebSocketClientProtocolHandler.ClientHandshakeStateEvent.HANDSHAKE_COMPLETE | ||
) { | ||
handshakeDeferred.complete(Unit) | ||
} | ||
//TODO: handle timeout | ||
} | ||
} | ||
} |
230 changes: 230 additions & 0 deletions
230
...mMain/kotlin/io/rsocket/kotlin/transport/netty/websocket/NettyWebSocketClientTransport.kt
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,230 @@ | ||
/* | ||
* Copyright 2015-2023 the original author or authors. | ||
* | ||
* Licensed under the Apache License, Version 2.0 (the "License"); | ||
* you may not use this file except in compliance with the License. | ||
* You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
|
||
package io.rsocket.kotlin.transport.netty.websocket | ||
|
||
import io.ktor.utils.io.core.internal.* | ||
import io.ktor.utils.io.pool.* | ||
import io.netty.bootstrap.* | ||
import io.netty.channel.* | ||
import io.netty.channel.ChannelFactory | ||
import io.netty.channel.nio.* | ||
import io.netty.channel.socket.* | ||
import io.netty.channel.socket.nio.* | ||
import io.netty.handler.codec.http.* | ||
import io.netty.handler.codec.http.websocketx.* | ||
import io.netty.handler.ssl.* | ||
import io.rsocket.kotlin.internal.io.* | ||
import io.rsocket.kotlin.transport.* | ||
import kotlinx.coroutines.* | ||
import java.net.* | ||
import kotlin.coroutines.* | ||
import kotlin.reflect.* | ||
|
||
public sealed interface NettyWebSocketClientTransport : RSocketClientTransport { | ||
public val config: WebSocketClientProtocolConfig | ||
|
||
public sealed interface Engine : RSocketTransportEngine<WebSocketClientProtocolConfig, NettyWebSocketClientTransport> { | ||
public fun createTransport( | ||
target: WebSocketClientProtocolConfig.Builder.() -> Unit, | ||
): NettyWebSocketClientTransport { | ||
return createTransport(WebSocketClientProtocolConfig.newBuilder().apply(target).build()) | ||
} | ||
} | ||
|
||
public sealed interface Builder { | ||
public fun bufferPool(pool: ObjectPool<ChunkBuffer>) | ||
|
||
public fun channel(cls: KClass<out DuplexChannel>) | ||
public fun channelFactory(factory: ChannelFactory<out DuplexChannel>) | ||
public fun eventLoopGroup(group: EventLoopGroup, manage: Boolean) | ||
|
||
public fun bootstrap(block: Bootstrap.() -> Unit) | ||
public fun ssl(block: SslContextBuilder.() -> Unit) | ||
public fun webSockets(block: WebSocketClientProtocolConfig.Builder.() -> Unit) | ||
} | ||
|
||
public companion object Factory : | ||
RSocketTransportFactory<WebSocketClientProtocolConfig, NettyWebSocketClientTransport, Engine, Builder> { | ||
public operator fun invoke( | ||
context: CoroutineContext, | ||
target: WebSocketClientProtocolConfig.Builder.() -> Unit, | ||
block: Builder.() -> Unit, | ||
): NettyWebSocketClientTransport { | ||
return invoke(context, WebSocketClientProtocolConfig.newBuilder().apply(target).build(), block) | ||
} | ||
|
||
override fun invoke( | ||
context: CoroutineContext, | ||
target: WebSocketClientProtocolConfig, | ||
block: Builder.() -> Unit, | ||
): NettyWebSocketClientTransport { | ||
return NettyWebSocketClientTransportBuilderImpl().apply(block).build(context).buildTransport(target) | ||
} | ||
|
||
override fun Engine(context: CoroutineContext, block: Builder.() -> Unit): Engine { | ||
return NettyWebSocketClientTransportBuilderImpl().apply(block).build(context).buildEngine() | ||
} | ||
} | ||
} | ||
|
||
private class NettyWebSocketClientTransportBuilderImpl : NettyWebSocketClientTransport.Builder { | ||
private var bufferPool: ObjectPool<ChunkBuffer> = ChunkBuffer.Pool | ||
private var channelFactory: ChannelFactory<out DuplexChannel>? = null | ||
private var eventLoopGroup: EventLoopGroup? = null | ||
private var manageEventLoopGroup: Boolean = false | ||
private var bootstrap: (Bootstrap.() -> Unit)? = null | ||
private var ssl: (SslContextBuilder.() -> Unit)? = null | ||
private var webSockets: (WebSocketClientProtocolConfig.Builder.() -> Unit)? = null | ||
|
||
override fun bufferPool(pool: ObjectPool<ChunkBuffer>) { | ||
this.bufferPool = pool | ||
} | ||
|
||
override fun channel(cls: KClass<out DuplexChannel>) { | ||
this.channelFactory = ReflectiveChannelFactory(cls.java) | ||
} | ||
|
||
override fun channelFactory(factory: ChannelFactory<out DuplexChannel>) { | ||
this.channelFactory = factory | ||
} | ||
|
||
override fun eventLoopGroup(group: EventLoopGroup, manage: Boolean) { | ||
this.eventLoopGroup = group | ||
this.manageEventLoopGroup = manage | ||
} | ||
|
||
override fun bootstrap(block: Bootstrap.() -> Unit) { | ||
bootstrap = block | ||
} | ||
|
||
override fun ssl(block: SslContextBuilder.() -> Unit) { | ||
ssl = block | ||
} | ||
|
||
override fun webSockets(block: WebSocketClientProtocolConfig.Builder.() -> Unit) { | ||
webSockets = block | ||
} | ||
|
||
fun build(context: CoroutineContext): NettyWebSocketClientTransportResources { | ||
val group = eventLoopGroup ?: NioEventLoopGroup() | ||
val factory = channelFactory ?: ReflectiveChannelFactory(NioSocketChannel::class.java) | ||
|
||
val transportContext = group.asCoroutineDispatcher() + context.supervisorContext() | ||
if (manageEventLoopGroup) CoroutineScope(transportContext).invokeOnCancellation { | ||
group.shutdownGracefully().awaitFuture() | ||
} | ||
|
||
val sslContext = ssl?.let { | ||
SslContextBuilder | ||
.forClient() | ||
.apply(it) | ||
.build() | ||
} | ||
|
||
val bootstrap = Bootstrap().apply { | ||
bootstrap?.invoke(this) | ||
group(group) | ||
channelFactory(factory) | ||
} | ||
|
||
return NettyWebSocketClientTransportResources( | ||
coroutineContext = transportContext, | ||
bufferPool = bufferPool, | ||
sslContext = sslContext, | ||
bootstrap = bootstrap, | ||
webSocketConfig = webSockets | ||
) | ||
} | ||
} | ||
|
||
private class NettyWebSocketClientTransportResources( | ||
private val coroutineContext: CoroutineContext, | ||
private val bufferPool: ObjectPool<ChunkBuffer>, | ||
private val sslContext: SslContext?, | ||
private val bootstrap: Bootstrap, | ||
private val webSocketConfig: (WebSocketClientProtocolConfig.Builder.() -> Unit)?, | ||
) { | ||
fun buildTransport(config: WebSocketClientProtocolConfig): NettyWebSocketClientTransport { | ||
return NettyWebSocketClientTransportImpl( | ||
coroutineContext = coroutineContext, | ||
config = config.toBuilder().apply { | ||
webSocketConfig?.invoke(this) | ||
}.build(), | ||
bufferPool = bufferPool, | ||
sslContext = sslContext, | ||
bootstrap = bootstrap, | ||
) | ||
} | ||
|
||
fun buildEngine(): NettyWebSocketClientTransport.Engine { | ||
return NettyWebSocketClientTransportEngineImpl( | ||
coroutineContext = coroutineContext, | ||
bufferPool = bufferPool, | ||
sslContext = sslContext, | ||
bootstrap = bootstrap, | ||
webSocketConfig = webSocketConfig, | ||
) | ||
} | ||
} | ||
|
||
private class NettyWebSocketClientTransportEngineImpl( | ||
override val coroutineContext: CoroutineContext, | ||
private val bufferPool: ObjectPool<ChunkBuffer>, | ||
private val sslContext: SslContext?, | ||
private val bootstrap: Bootstrap, | ||
private val webSocketConfig: (WebSocketClientProtocolConfig.Builder.() -> Unit)?, | ||
) : NettyWebSocketClientTransport.Engine { | ||
override fun createTransport(target: WebSocketClientProtocolConfig): NettyWebSocketClientTransport { | ||
return NettyWebSocketClientTransportImpl( | ||
coroutineContext = coroutineContext.supervisorContext(), | ||
config = target.toBuilder().apply { | ||
webSocketConfig?.invoke(this) | ||
}.build(), | ||
bufferPool = bufferPool, | ||
sslContext = sslContext, | ||
bootstrap = bootstrap, | ||
) | ||
} | ||
} | ||
|
||
private class NettyWebSocketClientTransportImpl( | ||
override val coroutineContext: CoroutineContext, | ||
override val config: WebSocketClientProtocolConfig, | ||
private val bufferPool: ObjectPool<ChunkBuffer>, | ||
private val sslContext: SslContext?, | ||
private val bootstrap: Bootstrap, | ||
) : NettyWebSocketClientTransport { | ||
@RSocketTransportApi | ||
override suspend fun connect(): RSocketTransportConnection { | ||
val remoteAddress = InetSocketAddress(config.webSocketUri().host, config.webSocketUri().port) | ||
val handler = NettyWebSocketChannelHandler( | ||
bufferPool = bufferPool, | ||
sslContext = sslContext, | ||
remoteAddress = remoteAddress, | ||
httpHandler = HttpClientCodec(), | ||
webSocketHandler = WebSocketClientProtocolHandler(config), | ||
) | ||
val future = bootstrap.clone().apply { | ||
remoteAddress(remoteAddress) | ||
handler(handler) | ||
}.connect() | ||
|
||
future.awaitFuture() | ||
|
||
return handler.connect(coroutineContext, future.channel() as DuplexChannel) | ||
} | ||
} |
Oops, something went wrong.