From 290d912e14476fee37b0ad6ae2a3265022795bbe Mon Sep 17 00:00:00 2001 From: Teemu Karimerto Date: Fri, 17 Mar 2023 20:00:57 +0200 Subject: [PATCH] Add support for nats Options Deprecate NewRouter and NewRouterWithAddress functions, replaced with generic Connect(). --- middleware/auth_test.go | 8 +- middleware/requestid_test.go | 6 +- natsrouter.go | 156 +++++++++++++++++++++++++++++------ natsrouter_test.go | 51 +++++++++--- version.go | 2 +- 5 files changed, 181 insertions(+), 42 deletions(-) diff --git a/middleware/auth_test.go b/middleware/auth_test.go index 73ce2cc..050eef6 100644 --- a/middleware/auth_test.go +++ b/middleware/auth_test.go @@ -48,7 +48,7 @@ func TestAuthMiddleware(t *testing.T) { t.Run("accept login", func(t *testing.T) { // Create router and connect to test server - nr, err := natsrouter.NewRouterWithAddress(s.Addr().String()) + nr, err := natsrouter.Connect(s.Addr().String()) am := NewAuthMiddleware(func(token string) bool { return true }) nr = nr.Use(am.Auth) if err != nil { @@ -94,7 +94,7 @@ func TestAuthMiddleware(t *testing.T) { t.Run("reject login", func(t *testing.T) { // Create router and connect to test server - nr, err := natsrouter.NewRouterWithAddress(s.Addr().String()) + nr, err := natsrouter.Connect(s.Addr().String()) am := NewAuthMiddleware(func(token string) bool { return false }) nr = nr.Use(am.Auth) if err != nil { @@ -145,7 +145,7 @@ func TestAuthMiddleware(t *testing.T) { tag := "err" format := "proto" - nr, err := natsrouter.NewRouterWithAddress(s.Addr().String(), natsrouter.WithErrorConfigString(tag, format)) + nr, err := natsrouter.Connect(s.Addr().String(), natsrouter.WithErrorConfigString(tag, format)) am := NewAuthMiddleware(func(token string) bool { return false }) nr = nr.Use(am.Auth) if err != nil { @@ -185,7 +185,7 @@ func TestAuthMiddleware(t *testing.T) { t.Run("missing login", func(t *testing.T) { // Create router and connect to test server - nr, err := natsrouter.NewRouterWithAddress(s.Addr().String()) + nr, err := natsrouter.Connect(s.Addr().String()) am := NewAuthMiddleware(func(token string) bool { return false }) nr = nr.Use(am.Auth) if err != nil { diff --git a/middleware/requestid_test.go b/middleware/requestid_test.go index 9b3bb43..d78f092 100644 --- a/middleware/requestid_test.go +++ b/middleware/requestid_test.go @@ -17,7 +17,7 @@ func TestRequestIdMiddleware(t *testing.T) { t.Run("default request_id header", func(t *testing.T) { // Create router and connect to test server - nr, err := natsrouter.NewRouterWithAddress(s.Addr().String()) + nr, err := natsrouter.Connect(s.Addr().String()) nr = nr.Use(RequestIdMiddleware()) if err != nil { t.Fatalf("Could not connect to NATS server: %v", err) @@ -68,7 +68,7 @@ func TestRequestIdMiddleware(t *testing.T) { headerTag := "reqid" // Create router and connect to test server - nr, err := natsrouter.NewRouterWithAddress(s.Addr().String()) + nr, err := natsrouter.Connect(s.Addr().String()) nr = nr.Use(RequestIdMiddleware(headerTag)) if err != nil { t.Fatalf("Could not connect to NATS server: %v", err) @@ -116,7 +116,7 @@ func TestRequestIdMiddleware(t *testing.T) { t.Run("missing request_id header", func(t *testing.T) { // Create router and connect to test server - nr, err := natsrouter.NewRouterWithAddress(s.Addr().String()) + nr, err := natsrouter.Connect(s.Addr().String()) nr = nr.Use(RequestIdMiddleware()) if err != nil { t.Fatalf("Could not connect to NATS server: %v", err) diff --git a/natsrouter.go b/natsrouter.go index 265024d..fd3217e 100644 --- a/natsrouter.go +++ b/natsrouter.go @@ -16,6 +16,7 @@ package natsrouter import ( "context" "encoding/json" + "strings" "sync" "github.com/nats-io/nats.go" @@ -32,32 +33,35 @@ type NatsMiddlewareFunc func(NatsCtxHandler) NatsCtxHandler type NatsRouter struct { nc *nats.Conn mw []NatsMiddlewareFunc - options *RouterOptions + options RouterOptions quit chan struct{} chanWg sync.WaitGroup + closed chan struct{} } -// Defines a struct for the router options, which currently only contains -// error config. +// Defines a struct for the router options, which currently contains error +// config, default request id tag (for error reporting) and optional list of +// NATS connection options. type RouterOptions struct { - ec *ErrorConfig - requestIdTag string + ErrorConfig *ErrorConfig + RequestIdTag string + NatsOptions nats.Options } // Defines a function type that will be used to define options for the router. -type RouterOption func(options *RouterOptions) +type RouterOption func(*RouterOptions) // Define error config in the router options. func WithErrorConfig(ec *ErrorConfig) RouterOption { - return func(options *RouterOptions) { - options.ec = ec + return func(o *RouterOptions) { + o.ErrorConfig = ec } } // Define error config as strings in the router options. func WithErrorConfigString(tag, format string) RouterOption { - return func(options *RouterOptions) { - options.ec = &ErrorConfig{ + return func(o *RouterOptions) { + o.ErrorConfig = &ErrorConfig{ Tag: tag, Format: format, } @@ -66,8 +70,23 @@ func WithErrorConfigString(tag, format string) RouterOption { // Define new request id header tag func WithRequestIdTag(tag string) RouterOption { - return func(options *RouterOptions) { - options.requestIdTag = tag + return func(o *RouterOptions) { + o.RequestIdTag = tag + } +} + +// Append one or more nats.Option to the connection, before connecting +func WithNatsOptions(nopts nats.Options) RouterOption { + return func(o *RouterOptions) { + o.NatsOptions = nopts + } +} + +func GetDefaultRouterOptions() RouterOptions { + return RouterOptions{ + &ErrorConfig{"error", "json"}, + "request_id", + nats.GetDefaultOptions(), } } @@ -75,18 +94,23 @@ func WithRequestIdTag(tag string) RouterOption { // RouterOptions functions. It sets the default RouterOptions to use a default // ErrorConfig, and then iterates through each option function, calling it with // the RouterOptions struct pointer to set any additional options. +// +// Deprecated: Use Connect instead. This does not support properly draining +// publications and subscriptions. func NewRouter(nc *nats.Conn, options ...RouterOption) *NatsRouter { router := &NatsRouter{ nc: nc, - options: &RouterOptions{ + options: RouterOptions{ &ErrorConfig{"error", "json"}, "request_id", + nats.GetDefaultOptions(), }, - quit: make(chan struct{}), + quit: make(chan struct{}), + closed: make(chan struct{}), } for _, opt := range options { - opt(router.options) + opt(&router.options) } return router @@ -97,6 +121,9 @@ func NewRouter(nc *nats.Conn, options ...RouterOption) *NatsRouter { // address, and then calls NewRouter to create a new NatsRouter with the // resulting *nats.Conn and optional RouterOptions. If there was an error // connecting to the server, it returns nil and the error. +// +// Deprecated: Use Connect instead. This does not support properly draining +// publications and subscriptions. func NewRouterWithAddress(addr string, options ...RouterOption) (*NatsRouter, error) { nc, err := nats.Connect(addr) if err != nil { @@ -106,6 +133,89 @@ func NewRouterWithAddress(addr string, options ...RouterOption) (*NatsRouter, er return NewRouter(nc, options...), nil } +// Process the url string argument to Connect. +// Return an array of urls, even if only one. +func processUrlString(url string) []string { + urls := strings.Split(url, ",") + var j int + for _, s := range urls { + u := strings.TrimSpace(s) + if len(u) > 0 { + urls[j] = u + j++ + } + } + return urls[:j] +} + +func Connect(url string, options ...RouterOption) (*NatsRouter, error) { + opts := GetDefaultRouterOptions() + opts.NatsOptions.Servers = processUrlString(url) + for _, opt := range options { + if opt != nil { + opt(&opts) + } + } + return opts.Connect() +} + +// Connect will attempt to connect to a NATS server with multiple options. +func (r RouterOptions) Connect() (*NatsRouter, error) { + // Check options, set defaults if necessary + if r.ErrorConfig == nil { + r.ErrorConfig = &ErrorConfig{"error", "json"} + } + if r.RequestIdTag == "" { + r.RequestIdTag = "request_id" + } + + // Create router instance + router := &NatsRouter{ + options: r, + quit: make(chan struct{}), + closed: make(chan struct{}), + } + + // Set custom closed callback + if r.NatsOptions.ClosedCB != nil { + // Preserve original CB as well (via a closure) + r.NatsOptions.ClosedCB = func(orig nats.ConnHandler) nats.ConnHandler { + original := orig + return func(nc *nats.Conn) { + // First notify own channel + close(router.closed) + // And then call the original handler + original(nc) + } + }(r.NatsOptions.ClosedCB) + } else { + r.NatsOptions.ClosedCB = func(_ *nats.Conn) { + close(router.closed) + } + } + + // Perform actual connection + nc, err := r.NatsOptions.Connect() + if err != nil { + return nil, err + } + router.nc = nc + return router, nil +} + +// Drain pubs/subs and close connection to NATS server +func (n *NatsRouter) Drain() { + // First close any channel-based subscriptions + close(n.quit) + n.chanWg.Wait() + + // Then start draining the connection + n.nc.Drain() + + // Wait until it is done + <- n.closed +} + // Close connection to NATS server func (n *NatsRouter) Close() { close(n.quit) @@ -174,12 +284,12 @@ func (n *NatsRouter) msgHandler(handler NatsCtxHandler) func(*nats.Msg) { errData, _ := json.Marshal(handlerErr) reply := nats.NewMsg(msg.Reply) - if len(n.options.requestIdTag) > 0 { - if reqId, ok := msg.Header[n.options.requestIdTag]; ok { - reply.Header.Add(n.options.requestIdTag, reqId[0]) + if len(n.options.RequestIdTag) > 0 { + if reqId, ok := msg.Header[n.options.RequestIdTag]; ok { + reply.Header.Add(n.options.RequestIdTag, reqId[0]) } } - reply.Header.Add(n.options.ec.Tag, n.options.ec.Format) + reply.Header.Add(n.options.ErrorConfig.Tag, n.options.ErrorConfig.Format) reply.Data = errData msg.RespondMsg(reply) @@ -245,12 +355,12 @@ chanLoop: errData, _ := json.Marshal(handlerErr) reply := nats.NewMsg(msg.Reply) - if len(n.options.requestIdTag) > 0 { - if reqId, ok := msg.Header[n.options.requestIdTag]; ok { - reply.Header.Add(n.options.requestIdTag, reqId[0]) + if len(n.options.RequestIdTag) > 0 { + if reqId, ok := msg.Header[n.options.RequestIdTag]; ok { + reply.Header.Add(n.options.RequestIdTag, reqId[0]) } } - reply.Header.Add(n.options.ec.Tag, n.options.ec.Format) + reply.Header.Add(n.options.ErrorConfig.Tag, n.options.ErrorConfig.Format) reply.Data = errData msg.RespondMsg(reply) diff --git a/natsrouter_test.go b/natsrouter_test.go index fea726b..128ccf5 100644 --- a/natsrouter_test.go +++ b/natsrouter_test.go @@ -48,7 +48,7 @@ func TestRunServer(t *testing.T) { defer nc.Close() } -func TestNewNatsRouter(t *testing.T) { +func TestConnect(t *testing.T) { // Create test server opts := &server.Options{Host: "localhost", Port: server.RANDOM_PORT, NoSigs: true} s, err := runServer(opts) @@ -58,14 +58,14 @@ func TestNewNatsRouter(t *testing.T) { defer s.Shutdown() // Create router and connect to test server - nr, err := NewRouterWithAddress(s.Addr().String()) + nr, err := Connect(s.Addr().String()) if err != nil { t.Fatalf("Could not connect to NATS server: %v", err) } defer nr.Close() } -func TestNew(t *testing.T) { +func TestOptionsConnect(t *testing.T){ // Create test server opts := &server.Options{Host: "localhost", Port: server.RANDOM_PORT, NoSigs: true} s, err := runServer(opts) @@ -74,17 +74,46 @@ func TestNew(t *testing.T) { } defer s.Shutdown() - // Connect client - nc, err := nats.Connect(s.Addr().String()) + // Create router and connect to test server + rOpts := GetDefaultRouterOptions() + nr, err := rOpts.Connect() if err != nil { t.Fatalf("Could not connect to NATS server: %v", err) } - - // Create router - nr := NewRouter(nc) defer nr.Close() } +func TestDrain(t *testing.T){ + // Create test server + opts := &server.Options{Host: "localhost", Port: server.RANDOM_PORT, NoSigs: true} + s, err := runServer(opts) + if err != nil { + t.Fatalf("Could not start NATS server: %v", err) + } + defer s.Shutdown() + + // Create router and connect to test server + ch := make(chan struct{}) + rOpts := RouterOptions{ + NatsOptions: nats.Options{ + ClosedCB: func(_ *nats.Conn) { + close (ch) + }, + }, + } + nr, err := rOpts.Connect() + if err != nil { + t.Fatalf("Could not connect to NATS server: %v", err) + } + nr.Drain() + + select { + case <-ch: + default: + t.Error("Channel is not closed") + } +} + func getServer(t *testing.T) *server.Server { // Create test server opts := &server.Options{Host: "localhost", Port: server.RANDOM_PORT, NoSigs: true} @@ -100,7 +129,7 @@ func getServerAndRouter(t *testing.T) (*server.Server, *NatsRouter) { s := getServer(t) // Create router and connect to test server - nr, err := NewRouterWithAddress(s.Addr().String()) + nr, err := Connect(s.Addr().String()) if err != nil { t.Fatalf("Could not connect to NATS server: %v", err) } @@ -339,7 +368,7 @@ func TestError(t *testing.T) { // Create router and connect to test server tag := "err" format := "proto" - nr, err := NewRouterWithAddress(s.Addr().String(), WithErrorConfigString(tag, format)) + nr, err := Connect(s.Addr().String(), WithErrorConfigString(tag, format)) if err != nil { t.Fatalf("Could not connect to NATS server: %v", err) } @@ -419,7 +448,7 @@ func TestRequestId(t *testing.T) { // Create router and connect to test server tag := "reqid" - nr, err := NewRouterWithAddress(s.Addr().String(), WithRequestIdTag(tag)) + nr, err := Connect(s.Addr().String(), WithRequestIdTag(tag)) if err != nil { t.Fatalf("Could not connect to NATS server: %v", err) } diff --git a/version.go b/version.go index aba8155..6b29ccd 100644 --- a/version.go +++ b/version.go @@ -2,5 +2,5 @@ package natsrouter // Version is the current release version. func Version() string { - return "0.0.5" + return "0.1.0" }