Skip to content

Commit

Permalink
Add support for nats Options
Browse files Browse the repository at this point in the history
Deprecate NewRouter and NewRouterWithAddress functions, replaced
with generic Connect().
  • Loading branch information
Karimerto committed Mar 17, 2023
1 parent 1b06fda commit 290d912
Show file tree
Hide file tree
Showing 5 changed files with 181 additions and 42 deletions.
8 changes: 4 additions & 4 deletions middleware/auth_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ func TestAuthMiddleware(t *testing.T) {

t.Run("accept login", func(t *testing.T) {
// Create router and connect to test server
nr, err := natsrouter.NewRouterWithAddress(s.Addr().String())
nr, err := natsrouter.Connect(s.Addr().String())
am := NewAuthMiddleware(func(token string) bool { return true })
nr = nr.Use(am.Auth)
if err != nil {
Expand Down Expand Up @@ -94,7 +94,7 @@ func TestAuthMiddleware(t *testing.T) {

t.Run("reject login", func(t *testing.T) {
// Create router and connect to test server
nr, err := natsrouter.NewRouterWithAddress(s.Addr().String())
nr, err := natsrouter.Connect(s.Addr().String())
am := NewAuthMiddleware(func(token string) bool { return false })
nr = nr.Use(am.Auth)
if err != nil {
Expand Down Expand Up @@ -145,7 +145,7 @@ func TestAuthMiddleware(t *testing.T) {
tag := "err"
format := "proto"

nr, err := natsrouter.NewRouterWithAddress(s.Addr().String(), natsrouter.WithErrorConfigString(tag, format))
nr, err := natsrouter.Connect(s.Addr().String(), natsrouter.WithErrorConfigString(tag, format))
am := NewAuthMiddleware(func(token string) bool { return false })
nr = nr.Use(am.Auth)
if err != nil {
Expand Down Expand Up @@ -185,7 +185,7 @@ func TestAuthMiddleware(t *testing.T) {

t.Run("missing login", func(t *testing.T) {
// Create router and connect to test server
nr, err := natsrouter.NewRouterWithAddress(s.Addr().String())
nr, err := natsrouter.Connect(s.Addr().String())
am := NewAuthMiddleware(func(token string) bool { return false })
nr = nr.Use(am.Auth)
if err != nil {
Expand Down
6 changes: 3 additions & 3 deletions middleware/requestid_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ func TestRequestIdMiddleware(t *testing.T) {

t.Run("default request_id header", func(t *testing.T) {
// Create router and connect to test server
nr, err := natsrouter.NewRouterWithAddress(s.Addr().String())
nr, err := natsrouter.Connect(s.Addr().String())
nr = nr.Use(RequestIdMiddleware())
if err != nil {
t.Fatalf("Could not connect to NATS server: %v", err)
Expand Down Expand Up @@ -68,7 +68,7 @@ func TestRequestIdMiddleware(t *testing.T) {
headerTag := "reqid"

// Create router and connect to test server
nr, err := natsrouter.NewRouterWithAddress(s.Addr().String())
nr, err := natsrouter.Connect(s.Addr().String())
nr = nr.Use(RequestIdMiddleware(headerTag))
if err != nil {
t.Fatalf("Could not connect to NATS server: %v", err)
Expand Down Expand Up @@ -116,7 +116,7 @@ func TestRequestIdMiddleware(t *testing.T) {

t.Run("missing request_id header", func(t *testing.T) {
// Create router and connect to test server
nr, err := natsrouter.NewRouterWithAddress(s.Addr().String())
nr, err := natsrouter.Connect(s.Addr().String())
nr = nr.Use(RequestIdMiddleware())
if err != nil {
t.Fatalf("Could not connect to NATS server: %v", err)
Expand Down
156 changes: 133 additions & 23 deletions natsrouter.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ package natsrouter
import (
"context"
"encoding/json"
"strings"
"sync"

"github.com/nats-io/nats.go"
Expand All @@ -32,32 +33,35 @@ type NatsMiddlewareFunc func(NatsCtxHandler) NatsCtxHandler
type NatsRouter struct {
nc *nats.Conn
mw []NatsMiddlewareFunc
options *RouterOptions
options RouterOptions
quit chan struct{}
chanWg sync.WaitGroup
closed chan struct{}
}

// Defines a struct for the router options, which currently only contains
// error config.
// Defines a struct for the router options, which currently contains error
// config, default request id tag (for error reporting) and optional list of
// NATS connection options.
type RouterOptions struct {
ec *ErrorConfig
requestIdTag string
ErrorConfig *ErrorConfig
RequestIdTag string
NatsOptions nats.Options
}

// Defines a function type that will be used to define options for the router.
type RouterOption func(options *RouterOptions)
type RouterOption func(*RouterOptions)

// Define error config in the router options.
func WithErrorConfig(ec *ErrorConfig) RouterOption {
return func(options *RouterOptions) {
options.ec = ec
return func(o *RouterOptions) {
o.ErrorConfig = ec
}
}

// Define error config as strings in the router options.
func WithErrorConfigString(tag, format string) RouterOption {
return func(options *RouterOptions) {
options.ec = &ErrorConfig{
return func(o *RouterOptions) {
o.ErrorConfig = &ErrorConfig{
Tag: tag,
Format: format,
}
Expand All @@ -66,27 +70,47 @@ func WithErrorConfigString(tag, format string) RouterOption {

// Define new request id header tag
func WithRequestIdTag(tag string) RouterOption {
return func(options *RouterOptions) {
options.requestIdTag = tag
return func(o *RouterOptions) {
o.RequestIdTag = tag
}
}

// Append one or more nats.Option to the connection, before connecting
func WithNatsOptions(nopts nats.Options) RouterOption {
return func(o *RouterOptions) {
o.NatsOptions = nopts
}
}

func GetDefaultRouterOptions() RouterOptions {
return RouterOptions{
&ErrorConfig{"error", "json"},
"request_id",
nats.GetDefaultOptions(),
}
}

// Create a new NatsRouter with a *nats.Conn and an optional list of
// RouterOptions functions. It sets the default RouterOptions to use a default
// ErrorConfig, and then iterates through each option function, calling it with
// the RouterOptions struct pointer to set any additional options.
//
// Deprecated: Use Connect instead. This does not support properly draining
// publications and subscriptions.
func NewRouter(nc *nats.Conn, options ...RouterOption) *NatsRouter {
router := &NatsRouter{
nc: nc,
options: &RouterOptions{
options: RouterOptions{
&ErrorConfig{"error", "json"},
"request_id",
nats.GetDefaultOptions(),
},
quit: make(chan struct{}),
quit: make(chan struct{}),
closed: make(chan struct{}),
}

for _, opt := range options {
opt(router.options)
opt(&router.options)
}

return router
Expand All @@ -97,6 +121,9 @@ func NewRouter(nc *nats.Conn, options ...RouterOption) *NatsRouter {
// address, and then calls NewRouter to create a new NatsRouter with the
// resulting *nats.Conn and optional RouterOptions. If there was an error
// connecting to the server, it returns nil and the error.
//
// Deprecated: Use Connect instead. This does not support properly draining
// publications and subscriptions.
func NewRouterWithAddress(addr string, options ...RouterOption) (*NatsRouter, error) {
nc, err := nats.Connect(addr)
if err != nil {
Expand All @@ -106,6 +133,89 @@ func NewRouterWithAddress(addr string, options ...RouterOption) (*NatsRouter, er
return NewRouter(nc, options...), nil
}

// Process the url string argument to Connect.
// Return an array of urls, even if only one.
func processUrlString(url string) []string {
urls := strings.Split(url, ",")
var j int
for _, s := range urls {
u := strings.TrimSpace(s)
if len(u) > 0 {
urls[j] = u
j++
}
}
return urls[:j]
}

func Connect(url string, options ...RouterOption) (*NatsRouter, error) {
opts := GetDefaultRouterOptions()
opts.NatsOptions.Servers = processUrlString(url)
for _, opt := range options {
if opt != nil {
opt(&opts)
}
}
return opts.Connect()
}

// Connect will attempt to connect to a NATS server with multiple options.
func (r RouterOptions) Connect() (*NatsRouter, error) {
// Check options, set defaults if necessary
if r.ErrorConfig == nil {
r.ErrorConfig = &ErrorConfig{"error", "json"}
}
if r.RequestIdTag == "" {
r.RequestIdTag = "request_id"
}

// Create router instance
router := &NatsRouter{
options: r,
quit: make(chan struct{}),
closed: make(chan struct{}),
}

// Set custom closed callback
if r.NatsOptions.ClosedCB != nil {
// Preserve original CB as well (via a closure)
r.NatsOptions.ClosedCB = func(orig nats.ConnHandler) nats.ConnHandler {
original := orig
return func(nc *nats.Conn) {
// First notify own channel
close(router.closed)
// And then call the original handler
original(nc)
}
}(r.NatsOptions.ClosedCB)
} else {
r.NatsOptions.ClosedCB = func(_ *nats.Conn) {
close(router.closed)
}
}

// Perform actual connection
nc, err := r.NatsOptions.Connect()
if err != nil {
return nil, err
}
router.nc = nc
return router, nil
}

// Drain pubs/subs and close connection to NATS server
func (n *NatsRouter) Drain() {
// First close any channel-based subscriptions
close(n.quit)
n.chanWg.Wait()

// Then start draining the connection
n.nc.Drain()

// Wait until it is done
<- n.closed
}

// Close connection to NATS server
func (n *NatsRouter) Close() {
close(n.quit)
Expand Down Expand Up @@ -174,12 +284,12 @@ func (n *NatsRouter) msgHandler(handler NatsCtxHandler) func(*nats.Msg) {
errData, _ := json.Marshal(handlerErr)

reply := nats.NewMsg(msg.Reply)
if len(n.options.requestIdTag) > 0 {
if reqId, ok := msg.Header[n.options.requestIdTag]; ok {
reply.Header.Add(n.options.requestIdTag, reqId[0])
if len(n.options.RequestIdTag) > 0 {
if reqId, ok := msg.Header[n.options.RequestIdTag]; ok {
reply.Header.Add(n.options.RequestIdTag, reqId[0])
}
}
reply.Header.Add(n.options.ec.Tag, n.options.ec.Format)
reply.Header.Add(n.options.ErrorConfig.Tag, n.options.ErrorConfig.Format)
reply.Data = errData

msg.RespondMsg(reply)
Expand Down Expand Up @@ -245,12 +355,12 @@ chanLoop:
errData, _ := json.Marshal(handlerErr)

reply := nats.NewMsg(msg.Reply)
if len(n.options.requestIdTag) > 0 {
if reqId, ok := msg.Header[n.options.requestIdTag]; ok {
reply.Header.Add(n.options.requestIdTag, reqId[0])
if len(n.options.RequestIdTag) > 0 {
if reqId, ok := msg.Header[n.options.RequestIdTag]; ok {
reply.Header.Add(n.options.RequestIdTag, reqId[0])
}
}
reply.Header.Add(n.options.ec.Tag, n.options.ec.Format)
reply.Header.Add(n.options.ErrorConfig.Tag, n.options.ErrorConfig.Format)
reply.Data = errData

msg.RespondMsg(reply)
Expand Down
Loading

0 comments on commit 290d912

Please sign in to comment.