Skip to content

Commit

Permalink
Add RabbitConnection for RabbitMQ (#12)
Browse files Browse the repository at this point in the history
  • Loading branch information
wasdennnoch authored Jul 7, 2024
1 parent 852c121 commit 5c40c43
Show file tree
Hide file tree
Showing 9 changed files with 214 additions and 44 deletions.
4 changes: 4 additions & 0 deletions latte/build.gradle.kts
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,10 @@ dependencies {
implementation("org.apache.kafka:kafka-clients:$kafkaVersion")
implementation("org.apache.kafka:kafka-streams:$kafkaVersion")

// RabbitMQ
val rabbitVersion = "5.20.0"
implementation("com.rabbitmq:amqp-client:$rabbitVersion")

// JSON
val moshiVersion = "1.14.0"
implementation("com.squareup.moshi:moshi:$moshiVersion")
Expand Down
27 changes: 24 additions & 3 deletions latte/src/main/java/gg/beemo/latte/broker/BrokerConnection.kt
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,9 @@ package gg.beemo.latte.broker

import gg.beemo.latte.logging.Log
import java.util.*
import java.util.concurrent.atomic.AtomicBoolean
import kotlin.collections.HashSet
import kotlin.math.abs

fun interface TopicListener {
fun onMessage(topic: String, key: String, value: String, headers: BrokerMessageHeaders)
Expand All @@ -14,12 +17,25 @@ abstract class BrokerConnection {
abstract val serviceName: String
abstract val instanceId: String
abstract val supportsTopicHotSwap: Boolean
abstract val deferInitialTopicCreation: Boolean

protected val topicListeners: MutableMap<String, MutableSet<TopicListener>> = Collections.synchronizedMap(HashMap())
private val deferredTopicsToCreate: MutableSet<String> = Collections.synchronizedSet(HashSet())
private val hasStarted = AtomicBoolean(false)

private val log by Log

abstract suspend fun start()
suspend fun start() {
abstractStart()
hasStarted.set(true)
for (topic in deferredTopicsToCreate) {
createTopic(topic)
}
deferredTopicsToCreate.clear()
}

internal abstract suspend fun abstractStart()

open fun destroy() {
log.debug("Destroying BrokerConnection")
topicListeners.clear()
Expand Down Expand Up @@ -53,8 +69,13 @@ abstract class BrokerConnection {

internal fun on(topic: String, cb: TopicListener) {
topicListeners.computeIfAbsent(topic) {
log.debug("Creating new topic '{}'", topic)
createTopic(topic)
if (!hasStarted.get() && deferInitialTopicCreation) {
log.debug("Deferring creation of topic '{}' to after connected", topic)
deferredTopicsToCreate.add(topic)
} else {
log.debug("Creating new topic '{}'", topic)
createTopic(topic)
}
Collections.synchronizedSet(HashSet())
}.add(cb)
}
Expand Down
3 changes: 2 additions & 1 deletion latte/src/main/java/gg/beemo/latte/broker/LocalConnection.kt
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ class LocalConnection(
) : BrokerConnection() {

override val supportsTopicHotSwap = true
override val deferInitialTopicCreation = false

override suspend fun abstractSend(
topic: String,
Expand All @@ -30,7 +31,7 @@ class LocalConnection(
return headers.messageId
}

override suspend fun start() {
override suspend fun abstractStart() {
// Nothing to start :)
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ class KafkaConnection(
) : BrokerConnection() {

override val supportsTopicHotSwap = false
override val deferInitialTopicCreation = false
private val kafkaHostsString = kafkaHosts.joinToString(",")
private val log by Log

Expand Down Expand Up @@ -81,7 +82,7 @@ class KafkaConnection(
return headers.messageId
}

override suspend fun start() {
override suspend fun abstractStart() {
check(!isRunning) { "KafkaConnection is already running!" }
log.debug("Starting Kafka Connection")
createTopics()
Expand Down
139 changes: 139 additions & 0 deletions latte/src/main/java/gg/beemo/latte/broker/rabbitmq/RabbitConnection.kt
Original file line number Diff line number Diff line change
@@ -0,0 +1,139 @@
package gg.beemo.latte.broker.rabbitmq

import com.rabbitmq.client.*
import gg.beemo.latte.broker.BrokerConnection
import gg.beemo.latte.broker.BrokerMessageHeaders
import gg.beemo.latte.broker.MessageId
import gg.beemo.latte.logging.Log
import kotlinx.coroutines.sync.Mutex
import kotlinx.coroutines.sync.withLock
import java.util.*
import java.util.concurrent.atomic.AtomicBoolean
import javax.net.ssl.SSLContext


private class ChannelData(
val channel: Channel,
val sendMutex: Mutex = Mutex(),
var isConsuming: AtomicBoolean = AtomicBoolean(false),
var consumerTag: String? = null,
)

class RabbitConnection(
rabbitHosts: Array<String>,
override val serviceName: String,
override val instanceId: String,
private val useTls: Boolean = false,
private val username: String = "guest",
private val password: String = "guest",
) : BrokerConnection() {

override val supportsTopicHotSwap = true
override val deferInitialTopicCreation = true
private val rabbitAddresses = rabbitHosts.map(Address::parseAddress)
private val log by Log

private var connection: Connection? = null
private val channels = Collections.synchronizedMap(HashMap<String, ChannelData>())

override suspend fun abstractStart() {
connection = ConnectionFactory().also {
if (useTls) {
it.useSslProtocol(SSLContext.getDefault())
it.enableHostnameVerification()
}
it.useNio()
it.username = username
it.password = password
it.isAutomaticRecoveryEnabled = true
it.isTopologyRecoveryEnabled = true
}.newConnection(rabbitAddresses, instanceId)
}

override suspend fun abstractSend(
topic: String,
key: String,
value: String,
headers: BrokerMessageHeaders
): MessageId {
if (shouldDispatchExternallyAfterShortCircuit(topic, key, value, headers)) {

val channelData = getOrCreateChannel(topic)
// RabbitMQ's channels are not thread-safe for sending. Consuming and sending
// through the same channel at the same time is fine though.
channelData.sendMutex.withLock {
val properties = AMQP.BasicProperties.Builder().apply {
// https://www.rabbitmq.com/docs/publishers#message-properties
deliveryMode(2) // Persistent
headers(headers.headers) // lol
}.build()
channelData.channel.basicPublish(topic, key, properties, value.toByteArray())
}

}

return headers.messageId
}

override fun destroy() {
log.debug("Destroying RabbitConnection")
connection?.close()
connection = null
super.destroy()
}

override fun createTopic(topic: String) {
val channelData = getOrCreateChannel(topic)
if (channelData.isConsuming.getAndSet(true)) {
return
}
val consumer = object : DefaultConsumer(channelData.channel) {

override fun handleDelivery(
consumerTag: String,
envelope: Envelope,
properties: AMQP.BasicProperties,
body: ByteArray
) {
val key = envelope.routingKey ?: ""
val value = String(body)
val headers = BrokerMessageHeaders(properties.headers.mapValues { it.value.toString() })
dispatchIncomingMessage(topic, key, value, headers)
channel.basicAck(envelope.deliveryTag, false)
}

override fun handleShutdownSignal(consumerTag: String, sig: ShutdownSignalException) {
if (sig.isInitiatedByApplication) {
return
}
log.error("RabbitMQ consumer for topic $topic has shut down unexpectedly", sig)
// The client _should_ automatically recover the connection
}
}
channelData.consumerTag = channelData.channel.basicConsume(createQueueName(topic), false, consumer)
}

override fun removeTopic(topic: String) {
val channel = channels.remove(topic)
channel?.channel?.queueDelete(createQueueName(topic))
channel?.channel?.close()
}

private fun getOrCreateChannel(topic: String): ChannelData {
return channels.computeIfAbsent(topic) {
val connection = checkNotNull(connection) { "Connection not open" }
val channel = connection.createChannel().apply {
val exchangeName = topic
val queueName = createQueueName(topic)
val routingKey = "#"
exchangeDeclare(exchangeName, BuiltinExchangeType.TOPIC, true)
queueDeclare(queueName, true, false, false, null)
queueBind(queueName, exchangeName, routingKey)
}
ChannelData(channel)
}
}

private fun createQueueName(topic: String) = "$serviceName.$instanceId.$topic"

}
76 changes: 40 additions & 36 deletions latte/src/main/java/gg/beemo/latte/broker/rpc/RpcClient.kt
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,10 @@ import kotlinx.coroutines.CoroutineScope
import kotlinx.coroutines.TimeoutCancellationException
import kotlinx.coroutines.delay
import kotlinx.coroutines.flow.Flow
import kotlinx.coroutines.flow.MutableSharedFlow
import kotlinx.coroutines.flow.callbackFlow
import kotlinx.coroutines.flow.single
import kotlinx.coroutines.launch
import java.util.concurrent.atomic.AtomicInteger
import java.util.concurrent.atomic.AtomicReference
import kotlin.time.Duration
Expand Down Expand Up @@ -37,14 +39,6 @@ class RpcClient<RequestT, ResponseT>(

private val requestProducer = client.producer(topic, key, options, requestType, requestIsNullable)
private val requestConsumer = client.consumer(topic, key, options, requestType, requestIsNullable) { msg ->
val responseProducer = client.producer(
client.toResponseTopic(topic),
client.toResponseKey(key),
options,
responseType,
responseIsNullable,
)

suspend fun sendResponse(response: ResponseT?, status: RpcStatus, isException: Boolean, isUpdate: Boolean) {
val responseMsg = RpcResponseMessage(
client.toResponseTopic(topic),
Expand Down Expand Up @@ -77,15 +71,30 @@ class RpcClient<RequestT, ResponseT>(
return@consumer
} catch (ex: Exception) {
log.error(
"Uncaught RPC callback error while processing message ${msg.headers.messageId} " +
"Uncaught RPC callbac#k error while processing message ${msg.headers.messageId} " +
"with key '$key' in topic '$topic'",
ex,
)
return@consumer
} finally {
responseProducer.destroy()
}
}
private val responseProducer = client.producer(
client.toResponseTopic(topic),
client.toResponseKey(key),
options,
responseType,
responseIsNullable,
)
private val responseFlow = MutableSharedFlow<BaseBrokerMessage<ResponseT>>()
private val responseConsumer = client.consumer(
client.toResponseTopic(topic),
client.toResponseKey(key),
options,
responseType,
responseIsNullable,
) {
responseFlow.emit(it)
}

suspend fun call(
request: RequestT,
Expand All @@ -110,36 +119,29 @@ class RpcClient<RequestT, ResponseT>(
require(maxResponses > 0) { "maxResponses must be at least 1" }
}
return callbackFlow {
val cbFlow = this
val responseCounter = AtomicInteger(0)
val timeoutLatch = maxResponses?.let { SuspendingCountDownLatch(it) }
val messageId = AtomicReference<String?>(null)

val responseConsumer = client.consumer(
client.toResponseTopic(topic),
client.toResponseKey(key),
options,
responseType,
responseIsNullable,
) {
val msg = it.toRpcResponseMessage()
if (msg.headers.inReplyTo != messageId.get()) {
return@consumer
launch { // Asynchronously consume responses; gets cancelled with callbackFlow
responseFlow.collect {
val msg = it.toRpcResponseMessage()
if (msg.headers.inReplyTo != messageId.get()) {
return@collect
}
// Close the callbackFlow if we receive an exception
if (msg.headers.isException) {
cbFlow.close(RpcException(msg.headers.status))
return@collect
}
cbFlow.send(msg)
timeoutLatch?.countDown()
val count = responseCounter.incrementAndGet()
if (maxResponses != null && count >= maxResponses) {
cbFlow.close()
}
}
// Close the flow if we receive an exception
if (msg.headers.isException) {
close(RpcException(msg.headers.status))
return@consumer
}
send(msg)
timeoutLatch?.countDown()
val count = responseCounter.incrementAndGet()
if (maxResponses != null && count >= maxResponses) {
close()
}
}

invokeOnClose {
responseConsumer.destroy()
}

messageId.set(requestProducer.send(request, services, instances))
Expand All @@ -161,6 +163,8 @@ class RpcClient<RequestT, ResponseT>(
override fun doDestroy() {
requestProducer.destroy()
requestConsumer.destroy()
responseProducer.destroy()
responseConsumer.destroy()
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ object SharedRatelimitData {

enum class RatelimitType {
@Json(name = "global")
GLBOAL,
GLOBAL,

@Json(name = "identify")
IDENTIFY,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ class TestBrokerClient(

val greetingRpc = rpc<GreetingRequest, GreetingResponse>(
topic = "rpc.greetings",
key = "greeting.requests",
key = "greet",
) {
log.info("greetingRpc received request: ${it.value}")
return@rpc RpcStatus.OK to GreetingResponse("Hello, ${it.value.name}")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ class RatelimitClient(connection: BrokerConnection) : BrokerClient(connection) {

log.debug("Incoming {} quota request from service {}", type, service)
val provider = when (msg.type) {
SharedRatelimitData.RatelimitType.GLBOAL -> globalRatelimitProvider
SharedRatelimitData.RatelimitType.GLOBAL -> globalRatelimitProvider
SharedRatelimitData.RatelimitType.IDENTIFY -> identifyRatelimitProvider
else -> throw IllegalArgumentException("Unknown ratelimit type ${msg.type}")
}
Expand Down

0 comments on commit 5c40c43

Please sign in to comment.