Skip to content

Commit

Permalink
Add support for channel-based subscriptions
Browse files Browse the repository at this point in the history
  • Loading branch information
Karimerto committed Mar 9, 2023
1 parent 38fa4f2 commit c355b62
Show file tree
Hide file tree
Showing 5 changed files with 183 additions and 1 deletion.
81 changes: 81 additions & 0 deletions natsrouter.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ package natsrouter
import (
"context"
"encoding/json"
"sync"

"github.com/nats-io/nats.go"
)
Expand All @@ -32,6 +33,8 @@ type NatsRouter struct {
nc *nats.Conn
mw []NatsMiddlewareFunc
options *RouterOptions
quit chan struct{}
chanWg sync.WaitGroup
}

// Defines a struct for the router options, which currently only contains
Expand Down Expand Up @@ -79,6 +82,7 @@ func NewRouter(nc *nats.Conn, options ...RouterOption) *NatsRouter {
&ErrorConfig{"error", "json"},
"request_id",
},
quit: make(chan struct{}),
}

for _, opt := range options {
Expand All @@ -104,6 +108,8 @@ func NewRouterWithAddress(addr string, options ...RouterOption) (*NatsRouter, er

// Close connection to NATS server
func (n *NatsRouter) Close() {
close(n.quit)
n.chanWg.Wait()
n.nc.Close()
}

Expand All @@ -118,6 +124,7 @@ func (n *NatsRouter) WithMiddleware(fns ...NatsMiddlewareFunc) *NatsRouter {
nc: n.nc,
mw: append(n.mw, fns...),
options: n.options,
quit: make(chan struct{}),
}
}

Expand Down Expand Up @@ -180,6 +187,80 @@ func (n *NatsRouter) msgHandler(handler NatsCtxHandler) func(*nats.Msg) {
}
}

// Same as Subscribe, except uses channels. Note that error handling is
// available only for middleware, since the message is processed first by
// middleware and then inserted into the *NatsMsg channel.
func (n *NatsRouter) ChanSubscribe(subject string, ch chan *NatsMsg) (*nats.Subscription, error) {
intCh := make(chan *nats.Msg, 64)
n.chanWg.Add(1)
go n.chanMsgHandler(ch, intCh)
return n.nc.ChanSubscribe(subject, intCh)
}

// Same as QueueSubscribe, except uses channels. Note that error handling is
// available only for middleware, since the message is processed first by
// middleware and then inserted into the *NatsMsg channel.
func (n *NatsRouter) ChanQueueSubscribe(subject, group string, ch chan *NatsMsg) (*nats.Subscription, error) {
intCh := make(chan *nats.Msg, 64)
n.chanWg.Add(1)
go n.chanMsgHandler(ch, intCh)
return n.nc.ChanQueueSubscribe(subject, group, intCh)
}

// Handler that wraps function call with any registered middleware functions in
// reverse order. On any error, an error message is automatically sent as a
// response to the request.
func (n *NatsRouter) chanMsgHandler(ch chan *NatsMsg, intCh chan *nats.Msg) {
defer n.chanWg.Done()

handler := func(natsMsg *NatsMsg) error {
ch <- natsMsg
return nil
}

chanLoop:
for {
select {
case msg := <-intCh:
natsMsg := &NatsMsg{
msg,
context.Background(),
}

var wrappedHandler NatsCtxHandler = handler
for i := len(n.mw) - 1; i >= 0; i-- {
wrappedHandler = n.mw[i](wrappedHandler)
}
// Errors are only handled for the middleware
err := wrappedHandler(natsMsg)

if err != nil {
handlerErr, ok := err.(*HandlerError)
if !ok {
handlerErr = &HandlerError{
Message: err.Error(),
Code: 500,
}
}
errData, _ := json.Marshal(handlerErr)

reply := nats.NewMsg(msg.Reply)
if len(n.options.requestIdTag) > 0 {
if reqId, ok := msg.Header[n.options.requestIdTag]; ok {
reply.Header.Add(n.options.requestIdTag, reqId[0])
}
}
reply.Header.Add(n.options.ec.Tag, n.options.ec.Format)
reply.Data = errData

msg.RespondMsg(reply)
}
case <-n.quit:
break chanLoop
}
}
}

// Publish is a passthrough function for the `nats` Publish function
func (n *NatsRouter) Publish(subject string, data []byte) error {
return n.nc.Publish(subject, data)
Expand Down
79 changes: 79 additions & 0 deletions natsrouter_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -450,3 +450,82 @@ func TestRequestId(t *testing.T) {
}
})
}

func TestChanSubscribe(t *testing.T) {
// Create test server and router
s, nr := getServerAndRouter(t)
defer s.Shutdown()

nc := nr.Conn()
defer nr.Close()

respond := func(ch chan *NatsMsg) {
msg := <-ch
err := msg.RespondWithOriginalHeaders(msg.Data)
if err != nil {
t.Errorf("unexpected error: %v", err)
}
}

t.Run("channel-based subscribe", func(t *testing.T) {
sub := nr.Subject("foo")
ch := make(chan *NatsMsg, 4)
// _, err := sub.Subscribe(emptyHandler)
_, err := sub.ChanSubscribe(ch)
if err != nil {
t.Errorf("unexpected error: %v", err)
}

// Start a function to read the request and send a response
go respond(ch)

// Create message and send a request
msg := nats.NewMsg("foo")
msg.Data = []byte("data")
reqId := "req-1"
msg.Header.Add("request_id", reqId)

reply, err := nc.RequestMsg(msg, 1*time.Second)
if err != nil {
t.Errorf("unexpected error: %v", err)
}
got := reply.Header.Get("request_id")
if got != reqId {
t.Errorf("header request_id does not match, expected %s, received %s", reqId, got)
}
if !bytes.Equal(msg.Data, reply.Data) {
t.Errorf("responses do not match, expected %s, received %s", string(msg.Data), string(reply.Data))
}
})

t.Run("channel-based queue subscribe", func(t *testing.T) {
sub := nr.Queue("group").Subject("foo")
ch := make(chan *NatsMsg, 4)
// _, err := sub.Subscribe(emptyHandler)
_, err := sub.ChanSubscribe(ch)
if err != nil {
t.Errorf("unexpected error: %v", err)
}

// Start a function to read the request and send a response
go respond(ch)

// Create message and send a request
msg := nats.NewMsg("foo")
msg.Data = []byte("data")
reqId := "req-1"
msg.Header.Add("request_id", reqId)

reply, err := nc.RequestMsg(msg, 1*time.Second)
if err != nil {
t.Errorf("unexpected error: %v", err)
}
got := reply.Header.Get("request_id")
if got != reqId {
t.Errorf("header request_id does not match, expected %s, received %s", reqId, got)
}
if !bytes.Equal(msg.Data, reply.Data) {
t.Errorf("responses do not match, expected %s, received %s", string(msg.Data), string(reply.Data))
}
})
}
5 changes: 5 additions & 0 deletions queue.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,11 @@ func (q *Queue) Subscribe(subject string, handler NatsCtxHandler) (*nats.Subscri
return q.n.QueueSubscribe(subject, q.group, handler)
}

// Same as Subscribe, with channel support
func (q *Queue) ChanSubscribe(subject string, ch chan *NatsMsg) (*nats.Subscription, error) {
return q.n.ChanQueueSubscribe(subject, q.group, ch)
}

// Create a new `Subject` object that is part of this `Queue` group
func (q *Queue) Subject(subjects ...string) *Subject {
return &Subject{
Expand Down
17 changes: 17 additions & 0 deletions subject.go
Original file line number Diff line number Diff line change
Expand Up @@ -109,3 +109,20 @@ func (s *Subject) Subscribe(handler NatsCtxHandler) (*nats.Subscription, error)
return s.n.Subscribe(subject, handler)
}
}

// Same as Subscribe, with channel support
func (s *Subject) ChanSubscribe(ch chan *NatsMsg) (*nats.Subscription, error) {
if s.queue != nil {
subject, err := s.getSubject()
if err != nil {
return nil, err
}
return s.queue.ChanSubscribe(subject, ch)
} else {
subject, err := s.getSubject()
if err != nil {
return nil, err
}
return s.n.ChanSubscribe(subject, ch)
}
}
2 changes: 1 addition & 1 deletion version.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,5 +2,5 @@ package natsrouter

// Version is the current release version.
func Version() string {
return "0.0.4"
return "0.0.5"
}

0 comments on commit c355b62

Please sign in to comment.