Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement support for callbacks on individual messages #121

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion src/Kafka/Internal/RdKafka.chs
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,7 @@ data RdKafkaMessageT = RdKafkaMessageT
, offset'RdKafkaMessageT :: Int64
, payload'RdKafkaMessageT :: Word8Ptr
, key'RdKafkaMessageT :: Word8Ptr
, opaque'RdKafkaMessageT :: Ptr ()
}
deriving (Show, Eq)

Expand All @@ -162,6 +163,7 @@ instance Storable RdKafkaMessageT where
<*> liftM fromIntegral ({#get rd_kafka_message_t->offset #} p)
<*> liftM castPtr ({#get rd_kafka_message_t->payload #} p)
<*> liftM castPtr ({#get rd_kafka_message_t->key #} p)
<*> liftM castPtr ({#get rd_kafka_message_t->_private #} p)
poke p x = do
{#set rd_kafka_message_t.err#} p (enumToCInt $ err'RdKafkaMessageT x)
{#set rd_kafka_message_t.rkt#} p (castPtr $ topic'RdKafkaMessageT x)
Expand All @@ -171,6 +173,7 @@ instance Storable RdKafkaMessageT where
{#set rd_kafka_message_t.offset#} p (fromIntegral $ offset'RdKafkaMessageT x)
{#set rd_kafka_message_t.payload#} p (castPtr $ payload'RdKafkaMessageT x)
{#set rd_kafka_message_t.key#} p (castPtr $ key'RdKafkaMessageT x)
{#set rd_kafka_message_t._private#} p (castPtr $ opaque'RdKafkaMessageT x)

{#pointer *rd_kafka_message_t as RdKafkaMessageTPtr foreign -> RdKafkaMessageT #}

Expand Down Expand Up @@ -893,7 +896,7 @@ rdKafkaConsumeStop topicPtr partition = do

{#fun rd_kafka_produce as ^
{`RdKafkaTopicTPtr', cIntConv `CInt32T', `Int', castPtr `Word8Ptr',
cIntConv `CSize', castPtr `Word8Ptr', cIntConv `CSize', castPtr `Word8Ptr'}
cIntConv `CSize', castPtr `Word8Ptr', cIntConv `CSize', castPtr `Ptr ()'}
-> `Int' #}

{#fun rd_kafka_produce_batch as ^
Expand Down
75 changes: 55 additions & 20 deletions src/Kafka/Producer.hs
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
{-# LANGUAGE TupleSections #-}
{-# LANGUAGE TupleSections #-}
{-# LANGUAGE LambdaCase #-}
module Kafka.Producer
( module X
, runProducer
, newProducer
, produceMessage, produceMessageBatch
, produceMessage'
, flushProducer
, closeProducer
, KafkaProducer
Expand All @@ -25,11 +27,12 @@ import Foreign.ForeignPtr (newForeignPtr_, withForeignPtr)
import Foreign.Marshal.Array (withArrayLen)
import Foreign.Ptr (Ptr, nullPtr, plusPtr)
import Foreign.Storable (Storable (..))
import Foreign.StablePtr (newStablePtr, castStablePtrToPtr)
import Kafka.Internal.RdKafka (RdKafkaMessageT (..), RdKafkaRespErrT (..), RdKafkaTypeT (..), destroyUnmanagedRdKafkaTopic, newRdKafkaT, newUnmanagedRdKafkaTopicT, rdKafkaOutqLen, rdKafkaProduce, rdKafkaProduceBatch, rdKafkaSetLogLevel)
import Kafka.Internal.Setup (Kafka (..), KafkaConf (..), KafkaProps (..), TopicConf (..), TopicProps (..), kafkaConf, topicConf)
import Kafka.Internal.Shared (pollEvents)
import Kafka.Producer.Convert (copyMsgFlags, handleProduceErr, producePartitionCInt, producePartitionInt)
import Kafka.Producer.Types (KafkaProducer (..))
import Kafka.Producer.Convert (copyMsgFlags, handleProduceErr', producePartitionCInt, producePartitionInt)
import Kafka.Producer.Types (KafkaProducer (..), ImmediateError(..))

import Kafka.Producer.ProducerProperties as X
import Kafka.Producer.Types as X hiding (KafkaProducer)
Expand Down Expand Up @@ -60,6 +63,9 @@ newProducer pps = liftIO $ do
kc@(KafkaConf kc' _ _) <- kafkaConf (KafkaProps $ (ppKafkaProps pps))
tc <- topicConf (TopicProps $ (ppTopicProps pps))

-- add default delivery report callback
deliveryCallback (const mempty) kc

-- set callbacks
forM_ (ppCallbacks pps) (\setCb -> setCb kc)

Expand All @@ -78,23 +84,51 @@ produceMessage :: MonadIO m
=> KafkaProducer
-> ProducerRecord
-> m (Maybe KafkaError)
produceMessage kp@(KafkaProducer (Kafka k) _ (TopicConf tc)) m = liftIO $ do
pollEvents kp (Just $ Timeout 0) -- fire callbacks if any exist (handle delivery reports)
bracket (mkTopic $ prTopic m) clTopic withTopic
where
mkTopic (TopicName tn) = newUnmanagedRdKafkaTopicT k (Text.unpack tn) (Just tc)

clTopic = either (return . const ()) destroyUnmanagedRdKafkaTopic

withTopic (Left err) = return . Just . KafkaError $ Text.pack err
withTopic (Right t) =
withBS (prValue m) $ \payloadPtr payloadLength ->
withBS (prKey m) $ \keyPtr keyLength ->
handleProduceErr =<<
rdKafkaProduce t (producePartitionCInt (prPartition m))
copyMsgFlags payloadPtr (fromIntegral payloadLength)
keyPtr (fromIntegral keyLength) nullPtr

produceMessage kp m = produceMessage' kp m (pure . mempty) >>= adjustRes
where
adjustRes = \case
Right () -> pure Nothing
Left (ImmediateError err) -> pure (Just err)

-- | Sends a single message with a registered callback.
--
-- The callback can be a long running process, as it is forked by the thread
-- that handles the delivery reports.
--
produceMessage' :: MonadIO m
=> KafkaProducer
-> ProducerRecord
-> (DeliveryReport -> IO ())
-> m (Either ImmediateError ())
produceMessage' kp@(KafkaProducer (Kafka k) _ (TopicConf tc)) msg cb = liftIO $
fireCallbacks >> bracket (mkTopic . prTopic $ msg) closeTopic withTopic
where
fireCallbacks =
pollEvents kp . Just . Timeout $ 0

mkTopic (TopicName tn) =
newUnmanagedRdKafkaTopicT k (Text.unpack tn) (Just tc)

closeTopic = either mempty destroyUnmanagedRdKafkaTopic

withTopic (Left err) = return . Left . ImmediateError . KafkaError . Text.pack $ err
withTopic (Right topic) =
withBS (prValue msg) $ \payloadPtr payloadLength ->
withBS (prKey msg) $ \keyPtr keyLength -> do
callbackPtr <- newStablePtr cb
res <- handleProduceErr' =<< rdKafkaProduce
topic
(producePartitionCInt (prPartition msg))
copyMsgFlags
payloadPtr
(fromIntegral payloadLength)
keyPtr
(fromIntegral keyLength)
(castStablePtrToPtr callbackPtr)

pure $ case res of
Left err -> Left . ImmediateError $ err
Right () -> Right ()

-- | Sends a batch of messages.
-- Returns a list of messages which it was unable to send with corresponding errors.
Expand Down Expand Up @@ -146,6 +180,7 @@ produceMessageBatch kp@(KafkaProducer (Kafka k) _ (TopicConf tc)) messages = lif
, offset'RdKafkaMessageT = 0
, keyLen'RdKafkaMessageT = keyLength
, key'RdKafkaMessageT = keyPtr
, opaque'RdKafkaMessageT = nullPtr
}

-- | Closes the producer.
Expand Down
25 changes: 23 additions & 2 deletions src/Kafka/Producer/Callbacks.hs
Original file line number Diff line number Diff line change
@@ -1,12 +1,16 @@
{-# LANGUAGE TypeApplications #-}
module Kafka.Producer.Callbacks
( deliveryCallback
, module X
)
where

import Control.Monad (void)
import Control.Concurrent (forkIO)
import Foreign.C.Error (getErrno)
import Foreign.Ptr (Ptr, nullPtr)
import Foreign.Storable (Storable(peek))
import Foreign.StablePtr (castPtrToStablePtr, deRefStablePtr)
import Kafka.Callbacks as X
import Kafka.Consumer.Types (Offset(..))
import Kafka.Internal.RdKafka (RdKafkaMessageT(..), RdKafkaRespErrT(..), rdKafkaConfSetDrMsgCb)
Expand All @@ -16,6 +20,12 @@ import Kafka.Producer.Types (ProducerRecord(..), DeliveryReport(..),
import Kafka.Types (KafkaError(..), TopicName(..))

-- | Sets the callback for delivery reports.
--
-- /Note: A callback should not be a long-running process as it blocks
-- librdkafka from continuing on the thread that handles the delivery
-- callbacks. For callbacks to individual messsages see
-- 'Kafka.Producer.produceMessage\''./
--
deliveryCallback :: (DeliveryReport -> IO ()) -> KafkaConf -> IO ()
deliveryCallback callback kc = rdKafkaConfSetDrMsgCb (getRdKafkaConf kc) realCb
where
Expand All @@ -25,9 +35,20 @@ deliveryCallback callback kc = rdKafkaConfSetDrMsgCb (getRdKafkaConf kc) realCb
then getErrno >>= (callback . NoMessageError . kafkaRespErr)
else do
s <- peek mptr
let cbPtr = opaque'RdKafkaMessageT s
if err'RdKafkaMessageT s /= RdKafkaRespErrNoError
then mkErrorReport s >>= callback
else mkSuccessReport s >>= callback
then mkErrorReport s >>= callbacks cbPtr
else mkSuccessReport s >>= callbacks cbPtr

callbacks cbPtr rep = do
callback rep
if cbPtr == nullPtr then
pure ()
else do
msgCb <- deRefStablePtr @(DeliveryReport -> IO ()) $ castPtrToStablePtr $ cbPtr
-- Here we fork the callback since it might be a longer action and
-- blocking here would block librdkafka from continuing its execution
void . forkIO $ msgCb rep

mkErrorReport :: RdKafkaMessageT -> IO DeliveryReport
mkErrorReport msg = do
Expand Down
7 changes: 7 additions & 0 deletions src/Kafka/Producer/Convert.hs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ module Kafka.Producer.Convert
, producePartitionInt
, producePartitionCInt
, handleProduceErr
, handleProduceErr'
)
where

Expand Down Expand Up @@ -31,3 +32,9 @@ handleProduceErr (- 1) = (Just . kafkaRespErr) <$> getErrno
handleProduceErr 0 = return Nothing
handleProduceErr _ = return $ Just KafkaInvalidReturnValue
{-# INLINE handleProduceErr #-}

handleProduceErr' :: Int -> IO (Either KafkaError ())
handleProduceErr' (- 1) = (Left . kafkaRespErr) <$> getErrno
handleProduceErr' 0 = return (Right ())
handleProduceErr' _ = return $ Left KafkaInvalidReturnValue
{-# INLINE handleProduceErr' #-}
11 changes: 9 additions & 2 deletions src/Kafka/Producer/Types.hs
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
{-# LANGUAGE DeriveDataTypeable #-}
{-# LANGUAGE DeriveGeneric #-}
{-# LANGUAGE DeriveDataTypeable #-}
{-# LANGUAGE DeriveGeneric #-}
{-# LANGUAGE DerivingStrategies #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
module Kafka.Producer.Types
( KafkaProducer(..)
, ProducerRecord(..)
, ProducePartition(..)
, DeliveryReport(..)
, ImmediateError(..)
)
where

Expand Down Expand Up @@ -47,6 +50,10 @@ data ProducePartition =
| UnassignedPartition
deriving (Show, Eq, Ord, Typeable, Generic)

-- | Data type representing an error that is caused by pre-flight conditions not being met
newtype ImmediateError = ImmediateError KafkaError
deriving newtype (Eq, Show)

data DeliveryReport
= DeliverySuccess ProducerRecord Offset
| DeliveryFailure ProducerRecord KafkaError
Expand Down
20 changes: 20 additions & 0 deletions tests-it/Kafka/IntegrationSpec.hs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE ScopedTypeVariables #-}

Expand All @@ -6,6 +7,7 @@ where

import Control.Monad (forM, forM_)
import Control.Monad.Loops
import Control.Concurrent.MVar (newEmptyMVar, putMVar, takeMVar)
import qualified Data.ByteString as BS
import Data.Either
import Data.Map (fromList)
Expand Down Expand Up @@ -113,6 +115,24 @@ spec = do
res <- sendMessages (testMessages testTopic) prod
res `shouldBe` Right ()

it "sends messages with callback to test topic" $ \prod -> do
var <- newEmptyMVar
let
msg = ProducerRecord
{ prTopic = TopicName "callback-topic"
, prPartition = UnassignedPartition
, prKey = Nothing
, prValue = Just "test from producer"
}

res <- produceMessage' prod msg (putMVar var)
res `shouldBe` Right ()
callbackRes <- flushProducer prod *> takeMVar var
callbackRes `shouldSatisfy` \case
DeliverySuccess _ _ -> True
DeliveryFailure _ _ -> False
NoMessageError _ -> False

specWithConsumer "Run consumer with async polling" (consumerProps <> groupId (makeGroupId "async")) runConsumerSpec
specWithConsumer "Run consumer with sync polling" (consumerProps <> groupId (makeGroupId "sync") <> callbackPollMode CallbackPollModeSync) runConsumerSpec

Expand Down