diff --git a/server/accounts.go b/server/accounts.go index c8e7e248b1a..962b54d5d4c 100644 --- a/server/accounts.go +++ b/server/accounts.go @@ -1,4 +1,4 @@ -// Copyright 2018-2023 The NATS Authors +// Copyright 2018-2024 The NATS Authors // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at @@ -128,6 +128,10 @@ type streamImport struct { claim *jwt.Import usePub bool invalid bool + // This is `allow_trace` and when true and message tracing is happening, + // we will trace egresses past the account boundary, if `false`, we stop + // at the account boundary. + atrc bool } const ClientInfoHdr = "Nats-Request-Info" @@ -209,6 +213,11 @@ type serviceExport struct { latency *serviceLatency rtmr *time.Timer respThresh time.Duration + // This is `allow_trace` and when true and message tracing is happening, + // when processing a service import we will go through account boundary + // and trace egresses on that other account. If `false`, we stop at the + // account boundary. + atrc bool } // Used to track service latency. @@ -2367,6 +2376,18 @@ func (a *Account) SetServiceExportResponseThreshold(export string, maxTime time. return nil } +func (a *Account) SetServiceExportAllowTrace(export string, allowTrace bool) error { + a.mu.Lock() + se := a.getServiceExport(export) + if se == nil { + a.mu.Unlock() + return fmt.Errorf("no export defined for %q", export) + } + se.atrc = allowTrace + a.mu.Unlock() + return nil +} + // This is for internal service import responses. func (a *Account) addRespServiceImport(dest *Account, to string, osi *serviceImport, tracking bool, header http.Header) *serviceImport { nrr := string(osi.acc.newServiceReply(tracking)) @@ -2405,6 +2426,10 @@ func (a *Account) addRespServiceImport(dest *Account, to string, osi *serviceImp // AddStreamImportWithClaim will add in the stream import from a specific account with optional token. func (a *Account) AddStreamImportWithClaim(account *Account, from, prefix string, imClaim *jwt.Import) error { + return a.addStreamImportWithClaim(account, from, prefix, false, imClaim) +} + +func (a *Account) addStreamImportWithClaim(account *Account, from, prefix string, allowTrace bool, imClaim *jwt.Import) error { if account == nil { return ErrMissingAccount } @@ -2427,7 +2452,7 @@ func (a *Account) AddStreamImportWithClaim(account *Account, from, prefix string } } - return a.AddMappedStreamImportWithClaim(account, from, prefix+from, imClaim) + return a.addMappedStreamImportWithClaim(account, from, prefix+from, allowTrace, imClaim) } // AddMappedStreamImport helper for AddMappedStreamImportWithClaim @@ -2437,6 +2462,10 @@ func (a *Account) AddMappedStreamImport(account *Account, from, to string) error // AddMappedStreamImportWithClaim will add in the stream import from a specific account with optional token. 
func (a *Account) AddMappedStreamImportWithClaim(account *Account, from, to string, imClaim *jwt.Import) error { + return a.addMappedStreamImportWithClaim(account, from, to, false, imClaim) +} + +func (a *Account) addMappedStreamImportWithClaim(account *Account, from, to string, allowTrace bool, imClaim *jwt.Import) error { if account == nil { return ErrMissingAccount } @@ -2478,7 +2507,11 @@ func (a *Account) AddMappedStreamImportWithClaim(account *Account, from, to stri a.mu.Unlock() return ErrStreamImportDuplicate } - a.imports.streams = append(a.imports.streams, &streamImport{account, from, to, tr, nil, imClaim, usePub, false}) + // TODO(ik): When AllowTrace is added to JWT, uncomment those lines: + // if imClaim != nil { + // allowTrace = imClaim.AllowTrace + // } + a.imports.streams = append(a.imports.streams, &streamImport{account, from, to, tr, nil, imClaim, usePub, false, allowTrace}) a.mu.Unlock() return nil } @@ -2496,7 +2529,7 @@ func (a *Account) isStreamImportDuplicate(acc *Account, from string) bool { // AddStreamImport will add in the stream import from a specific account. func (a *Account) AddStreamImport(account *Account, from, prefix string) error { - return a.AddStreamImportWithClaim(account, from, prefix, nil) + return a.addStreamImportWithClaim(account, from, prefix, false, nil) } // IsPublicExport is a placeholder to denote a public export. diff --git a/server/accounts_test.go b/server/accounts_test.go index 92778c66629..e48b92b169f 100644 --- a/server/accounts_test.go +++ b/server/accounts_test.go @@ -622,6 +622,13 @@ func TestAccountParseConfigImportsExports(t *testing.T) { if lis := len(natsAcc.imports.streams); lis != 2 { t.Fatalf("Expected 2 imported streams, got %d\n", lis) } + for _, si := range natsAcc.imports.streams { + if si.from == "public.synadia" { + require_True(t, si.atrc) + } else { + require_False(t, si.atrc) + } + } if lis := len(natsAcc.imports.services); lis != 1 { t.Fatalf("Expected 1 imported service, got %d\n", lis) } @@ -639,6 +646,7 @@ func TestAccountParseConfigImportsExports(t *testing.T) { if ea.respType != Streamed { t.Fatalf("Expected to get a Streamed response type, got %q", ea.respType) } + require_True(t, ea.atrc) ea = natsAcc.exports.services["nats.photo"] if ea == nil { t.Fatalf("Expected to get a non-nil exportAuth for service") @@ -646,6 +654,7 @@ func TestAccountParseConfigImportsExports(t *testing.T) { if ea.respType != Chunked { t.Fatalf("Expected to get a Chunked response type, got %q", ea.respType) } + require_False(t, ea.atrc) ea = natsAcc.exports.services["nats.add"] if ea == nil { t.Fatalf("Expected to get a non-nil exportAuth for service") @@ -653,6 +662,7 @@ func TestAccountParseConfigImportsExports(t *testing.T) { if ea.respType != Singleton { t.Fatalf("Expected to get a Singleton response type, got %q", ea.respType) } + require_True(t, ea.atrc) if synAcc == nil { t.Fatalf("Error retrieving account for 'synadia'") diff --git a/server/client.go b/server/client.go index 348c0478fec..d3466ae820c 100644 --- a/server/client.go +++ b/server/client.go @@ -1,4 +1,4 @@ -// Copyright 2012-2023 The NATS Authors +// Copyright 2012-2024 The NATS Authors // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. 
// You may obtain a copy of the License at @@ -2476,7 +2476,7 @@ func (c *client) msgParts(data []byte) (hdr []byte, msg []byte) { } // Header pubs take form HPUB [reply] \r\n -func (c *client) processHeaderPub(arg []byte) error { +func (c *client) processHeaderPub(arg, remaining []byte) error { if !c.headers { return ErrMsgHeadersNotSupported } @@ -2534,6 +2534,16 @@ func (c *client) processHeaderPub(arg []byte) error { maxPayload := atomic.LoadInt32(&c.mpay) // Use int64() to avoid int32 overrun... if maxPayload != jwt.NoLimit && int64(c.pa.size) > int64(maxPayload) { + // If we are given the remaining read buffer (since we do blind reads + // we may have the beginning of the message header/payload), we will + // look for the tracing header and if found, we will generate a + // trace event with the max payload ingress error. + // Do this only for CLIENT connections. + if c.kind == CLIENT && len(remaining) > 0 { + if td := getHeader(MsgTraceSendTo, remaining); len(td) > 0 { + c.initAndSendIngressErrEvent(remaining, string(td), ErrMaxPayload) + } + } c.maxPayloadViolation(c.pa.size, maxPayload) return ErrMaxPayload } @@ -3324,23 +3334,33 @@ var needFlush = struct{}{} // deliverMsg will deliver a message to a matching subscription and its underlying client. // We process all connection/client types. mh is the part that will be protocol/client specific. func (c *client) deliverMsg(prodIsMQTT bool, sub *subscription, acc *Account, subject, reply, mh, msg []byte, gwrply bool) bool { + // Check if message tracing is enabled. + mt, traceOnly := c.isMsgTraceEnabled() + + client := sub.client // Check sub client and check echo. Only do this if not a service import. - if sub.client == nil || (c == sub.client && !sub.client.echo && !sub.si) { + if client == nil || (c == client && !client.echo && !sub.si) { + if client != nil && mt != nil { + client.mu.Lock() + mt.addEgressEvent(client, sub, errMsgTraceNoEcho) + client.mu.Unlock() + } return false } - client := sub.client client.mu.Lock() // Check if we have a subscribe deny clause. This will trigger us to check the subject // for a match against the denied subjects. if client.mperms != nil && client.checkDenySub(string(subject)) { + mt.addEgressEvent(client, sub, errMsgTraceSubDeny) client.mu.Unlock() return false } // New race detector forces this now. if sub.isClosed() { + mt.addEgressEvent(client, sub, errMsgTraceSubClosed) client.mu.Unlock() return false } @@ -3348,15 +3368,56 @@ func (c *client) deliverMsg(prodIsMQTT bool, sub *subscription, acc *Account, su // Check if we are a leafnode and have perms to check. if client.kind == LEAF && client.perms != nil { if !client.pubAllowedFullCheck(string(subject), true, true) { + mt.addEgressEvent(client, sub, errMsgTracePubViolation) client.mu.Unlock() client.Debugf("Not permitted to deliver to %q", subject) return false } } + var mtErr string + if mt != nil { + // For non internal subscription, and if the remote does not support + // the tracing feature... + if sub.icb == nil && !client.msgTraceSupport() { + if traceOnly { + // We are not sending the message at all because the user + // expects a trace-only and the remote does not support + // tracing, which means that it would process/deliver this + // message, which may break applications. + // Add the Egress with the no-support error message. 
+ mt.addEgressEvent(client, sub, errMsgTraceOnlyNoSupport) + client.mu.Unlock() + return false + } + // If we are doing delivery, we will still forward the message, + // but we add an error to the Egress event to hint that one should + // not expect a tracing event from that remote. + mtErr = errMsgTraceNoSupport + } + // For ROUTER, GATEWAY and LEAF, even if we intend to do tracing only, + // we will still deliver the message. The remote side will + // generate an event based on what happened on that server. + if traceOnly && (client.kind == ROUTER || client.kind == GATEWAY || client.kind == LEAF) { + traceOnly = false + } + // If we skip delivery and this is not for a service import, we are done. + if traceOnly && (sub.icb == nil || c.noIcb) { + mt.addEgressEvent(client, sub, _EMPTY_) + client.mu.Unlock() + // Although the message is not actually delivered, for the + // purpose of "didDeliver", we need to return "true" here. + return true + } + } + srv := client.srv - sub.nm++ + // We don't want to bump the number of delivered messages to the subscription + // if we are doing trace-only (since really we are not sending it to the sub). + if !traceOnly { + sub.nm++ + } // Check if we should auto-unsubscribe. if sub.max > 0 { @@ -3380,6 +3441,7 @@ func (c *client) deliverMsg(prodIsMQTT bool, sub *subscription, acc *Account, su defer client.unsubscribe(client.acc, sub, true, true) } else if sub.nm > sub.max { client.Debugf("Auto-unsubscribe limit [%d] exceeded", sub.max) + mt.addEgressEvent(client, sub, errMsgTraceAutoSubExceeded) client.mu.Unlock() client.unsubscribe(client.acc, sub, true, true) if shouldForward { @@ -3407,10 +3469,14 @@ func (c *client) deliverMsg(prodIsMQTT bool, sub *subscription, acc *Account, su msgSize -= int64(LEN_CR_LF) } - // No atomic needed since accessed under client lock. - // Monitor is reading those also under client's lock. - client.outMsgs++ - client.outBytes += msgSize + // We do not update the outbound stats if we are doing trace only since + // this message will not be sent out. + if !traceOnly { + // No atomic needed since accessed under client lock. + // Monitor is reading those also under client's lock. + client.outMsgs++ + client.outBytes += msgSize + } // Check for internal subscriptions. if sub.icb != nil && !c.noIcb { @@ -3443,10 +3509,17 @@ func (c *client) deliverMsg(prodIsMQTT bool, sub *subscription, acc *Account, su // Check for closed connection if client.isClosed() { + mt.addEgressEvent(client, sub, errMsgTraceClientClosed) client.mu.Unlock() return false } + // We have passed cases where we could possibly fail to deliver. + // Do not call for service-import. + if mt != nil && sub.icb == nil { + mt.addEgressEvent(client, sub, mtErr) + } + // Do a fast check here to see if we should be tracking this from a latency // perspective. This will be for a request being received for an exported service. // This needs to be from a non-client (otherwise tracking happens at requestor). @@ -4121,6 +4194,7 @@ func (c *client) processServiceImport(si *serviceImport, acc *Account, msg []byt } } siAcc := si.acc + allowTrace := si.se != nil && si.se.atrc acc.mu.RUnlock() // We have a special case where JetStream pulls in all service imports through one export. 
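For orientation, here is a minimal client-side sketch of the flow the hunks above implement. It is not part of the diff; it mirrors the pattern used in msgtrace_test.go further down, assuming the nats.go client, the exported event types and header names from msgtrace.go (Nats-Trace-Dest / Nats-Trace-Only), and arbitrary subject names.

package main

import (
	"encoding/json"
	"fmt"
	"log"
	"time"

	"github.com/nats-io/nats-server/v2/server"
	"github.com/nats-io/nats.go"
)

func main() {
	nc, err := nats.Connect(nats.DefaultURL)
	if err != nil {
		log.Fatal(err)
	}
	defer nc.Close()

	// Subscription that will receive the MsgTraceEvent JSON.
	traceSub, err := nc.SubscribeSync("my.trace.subj")
	if err != nil {
		log.Fatal(err)
	}

	// Request tracing for this message. "Nats-Trace-Only: true" asks the
	// server to trace without delivering to application subscribers.
	msg := nats.NewMsg("foo")
	msg.Header.Set("Nats-Trace-Dest", "my.trace.subj") // server.MsgTraceSendTo
	msg.Header.Set("Nats-Trace-Only", "true")          // server.MsgTraceOnly
	msg.Data = []byte("hello")
	if err := nc.PublishMsg(msg); err != nil {
		log.Fatal(err)
	}

	// The trace arrives as JSON that unmarshals into server.MsgTraceEvent.
	tm, err := traceSub.NextMsg(time.Second)
	if err != nil {
		log.Fatal(err)
	}
	var e server.MsgTraceEvent
	if err := json.Unmarshal(tm.Data, &e); err != nil {
		log.Fatal(err)
	}
	if in := e.Ingress(); in != nil {
		fmt.Printf("ingress on %q, %d egress(es)\n", in.Subject, len(e.Egresses()))
	}
}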
@@ -4131,6 +4205,8 @@ func (c *client) processServiceImport(si *serviceImport, acc *Account, msg []byt return } + mt, traceOnly := c.isMsgTraceEnabled() + var nrr []byte var rsi *serviceImport @@ -4259,23 +4335,54 @@ func (c *client) processServiceImport(si *serviceImport, acc *Account, msg []byt var lrts [routeTargetInit]routeTarget c.in.rts = lrts[:0] + var skipProcessing bool + // If message tracing enabled, add the service import trace. + if mt != nil { + mt.addServiceImportEvent(siAcc.GetName(), string(pacopy.subject), to) + // If we are not allowing tracing and doing trace only, we stop at this level. + if !allowTrace { + if traceOnly { + skipProcessing = true + } else { + // We are going to do normal processing, and possibly chainning + // with other server imports, but the rest won't be traced. + // We do so by setting the c.pa.trace to nil (it will be restored + // with c.pa = pacopy). + c.pa.trace = nil + // We also need to disable the trace destination header so that + // if message is routed, it does not initialize tracing in the + // remote. + pos := mt.disableTraceHeader(c, msg) + defer mt.enableTraceHeader(c, msg, pos) + } + } + } + var didDeliver bool - // If this is not a gateway connection but gateway is enabled, - // try to send this converted message to all gateways. - if c.srv.gateway.enabled { - flags |= pmrCollectQueueNames - var queues [][]byte - didDeliver, queues = c.processMsgResults(siAcc, rr, msg, c.pa.deliver, []byte(to), nrr, flags) - didDeliver = c.sendMsgToGateways(siAcc, msg, []byte(to), nrr, queues) || didDeliver - } else { - didDeliver, _ = c.processMsgResults(siAcc, rr, msg, c.pa.deliver, []byte(to), nrr, flags) + if !skipProcessing { + // If this is not a gateway connection but gateway is enabled, + // try to send this converted message to all gateways. + if c.srv.gateway.enabled { + flags |= pmrCollectQueueNames + var queues [][]byte + didDeliver, queues = c.processMsgResults(siAcc, rr, msg, c.pa.deliver, []byte(to), nrr, flags) + didDeliver = c.sendMsgToGateways(siAcc, msg, []byte(to), nrr, queues) || didDeliver + } else { + didDeliver, _ = c.processMsgResults(siAcc, rr, msg, c.pa.deliver, []byte(to), nrr, flags) + } } // Restore to original values. c.in.rts = orts c.pa = pacopy + // If this was a message trace but we skip last-mile delivery, we need to + // do the remove, so: + if mt != nil && traceOnly && didDeliver { + didDeliver = false + } + // Determine if we should remove this service import. This is for response service imports. // We will remove if we did not deliver, or if we are a response service import and we are // a singleton, or we have an EOF message. @@ -4422,6 +4529,8 @@ func (c *client) processMsgResults(acc *Account, r *SublistResult, msg, deliver, } } + mt, traceOnly := c.isMsgTraceEnabled() + // Loop over all normal subscriptions that match. for _, sub := range r.psubs { // Check if this is a send to a ROUTER. We now process @@ -4450,6 +4559,11 @@ func (c *client) processMsgResults(acc *Account, r *SublistResult, msg, deliver, // Assume delivery subject is the normal subject to this point. dsubj = subj + // We may need to disable tracing, by setting c.pa.trace to `nil` + // before the call to deliverMsg, if so, this will indicate that + // we need to put it back. + var restorePaTrace bool + // Check for stream import mapped subs (shadow subs). These apply to local subs only. if sub.im != nil { // If this message was a service import do not re-export to an exported stream. 
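As a usage note for the allow_trace gating above: the flag lives on service exports (serviceExport.atrc) and stream imports (streamImport.atrc). Below is a minimal sketch, written as if inside the server package like the tests in this PR, of enabling it programmatically with the SetServiceExportAllowTrace helper added in accounts.go. The subjects and account wiring are illustrative only, and accounts would normally be registered with a running server; config files use the allow_trace option shown in server/configs/accounts.conf later in this diff.

// Illustrative helper, in-package style.
func enableServiceTraceExample(accA, accB *Account) error {
	// Account A exports a service and opts in to tracing across the
	// account boundary (this sets serviceExport.atrc).
	if err := accA.AddServiceExport("nats.time", nil); err != nil {
		return err
	}
	if err := accA.SetServiceExportAllowTrace("nats.time", true); err != nil {
		return err
	}
	// Account B imports it under a local subject. With allow_trace set on
	// the export, egresses that occur in A appear in B's trace events;
	// without it, the trace stops at the account boundary.
	return accB.AddServiceImport(accA, "time.req", "nats.time")
}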
@@ -4465,6 +4579,25 @@ func (c *client) processMsgResults(acc *Account, r *SublistResult, msg, deliver, dsubj = append(_dsubj[:0], sub.im.to...) } + if mt != nil { + mt.addStreamExportEvent(sub.client, dsubj) + // If allow_trace is false... + if !sub.im.atrc { + // If we are doing only message tracing, we can move to the + // next sub. + if traceOnly { + // Although the message was not delivered, for the purpose + // of didDeliver, we need to set to true (to avoid possible + // no responders). + didDeliver = true + continue + } + // If we are delivering the message, we need to disable tracing + // before calling deliverMsg(). + c.pa.trace, restorePaTrace = nil, true + } + } + // Make sure deliver is set if inbound from a route. if remapped && (c.kind == GATEWAY || c.kind == ROUTER || c.kind == LEAF) { deliver = subj @@ -4491,6 +4624,9 @@ func (c *client) processMsgResults(acc *Account, r *SublistResult, msg, deliver, } didDeliver = true } + if restorePaTrace { + c.pa.trace = mt + } } // Set these up to optionally filter based on the queue lists. @@ -4597,6 +4733,13 @@ func (c *client) processMsgResults(acc *Account, r *SublistResult, msg, deliver, // Assume delivery subject is normal subject to this point. dsubj = subj + + // We may need to disable tracing, by setting c.pa.trace to `nil` + // before the call to deliverMsg, if so, this will indicate that + // we need to put it back. + var restorePaTrace bool + var skipDelivery bool + // Check for stream import mapped subs. These apply to local subs only. if sub.im != nil { // If this message was a service import do not re-export to an exported stream. @@ -4611,6 +4754,23 @@ func (c *client) processMsgResults(acc *Account, r *SublistResult, msg, deliver, } else { dsubj = append(_dsubj[:0], sub.im.to...) } + + if mt != nil { + mt.addStreamExportEvent(sub.client, dsubj) + // If allow_trace is false... + if !sub.im.atrc { + // If we are doing only message tracing, we are done + // with this queue group. + if traceOnly { + skipDelivery = true + } else { + // If we are delivering, we need to disable tracing + // before the call to deliverMsg() + c.pa.trace, restorePaTrace = nil, true + } + } + } + // Make sure deliver is set if inbound from a route. if remapped && (c.kind == GATEWAY || c.kind == ROUTER || c.kind == LEAF) { deliver = subj @@ -4623,11 +4783,20 @@ func (c *client) processMsgResults(acc *Account, r *SublistResult, msg, deliver, } } - mh := c.msgHeader(dsubj, creply, sub) - if c.deliverMsg(prodIsMQTT, sub, acc, subject, creply, mh, msg, rplyHasGWPrefix) { - if sub.icb == nil { + var delivered bool + if !skipDelivery { + mh := c.msgHeader(dsubj, creply, sub) + delivered = c.deliverMsg(prodIsMQTT, sub, acc, subject, creply, mh, msg, rplyHasGWPrefix) + if restorePaTrace { + c.pa.trace = mt + } + } + if skipDelivery || delivered { + // Update only if not skipped. + if !skipDelivery && sub.icb == nil { dlvMsgs++ } + // Do the rest even when message delivery was skipped. didDeliver = true // Clear rsub rsub = nil @@ -4668,6 +4837,16 @@ sendToRoutesOrLeafs: // Copy off original pa in case it changes. pa := c.pa + if mt != nil { + // We are going to replace "pa" with our copy of c.pa, but to restore + // to the original copy of c.pa, we need to save it again. + cpa := pa + msg = mt.setOriginAccountHeaderIfNeeded(c, acc, msg) + defer func() { c.pa = cpa }() + // Update pa with our current c.pa state. + pa = c.pa + } + // We address by index to avoid struct copy. // We have inline structs for memory layout and cache coherency. 
for i := range c.in.rts { @@ -4704,6 +4883,11 @@ sendToRoutesOrLeafs: } } + if mt != nil { + dmsg = mt.setHopHeader(c, dmsg) + hset = true + } + mh := c.msgHeaderForRouteOrLeaf(subject, reply, rt, acc) if c.deliverMsg(prodIsMQTT, rt.sub, acc, subject, reply, mh, dmsg, false) { if rt.sub.icb == nil { @@ -4750,7 +4934,11 @@ func (c *client) checkLeafClientInfoHeader(msg []byte) (dmsg []byte, setHdr bool } func (c *client) pubPermissionViolation(subject []byte) { - c.sendErr(fmt.Sprintf("Permissions Violation for Publish to %q", subject)) + errTxt := fmt.Sprintf("Permissions Violation for Publish to %q", subject) + if mt, _ := c.isMsgTraceEnabled(); mt != nil { + mt.setIngressError(errTxt) + } + c.sendErr(errTxt) c.Errorf("Publish Violation - %s, Subject %q", c.getAuthUser(), subject) } @@ -4770,7 +4958,11 @@ func (c *client) subPermissionViolation(sub *subscription) { } func (c *client) replySubjectViolation(reply []byte) { - c.sendErr(fmt.Sprintf("Permissions Violation for Publish with Reply of %q", reply)) + errTxt := fmt.Sprintf("Permissions Violation for Publish with Reply of %q", reply) + if mt, _ := c.isMsgTraceEnabled(); mt != nil { + mt.setIngressError(errTxt) + } + c.sendErr(errTxt) c.Errorf("Publish Violation - %s, Reply %q", c.getAuthUser(), reply) } diff --git a/server/config_check_test.go b/server/config_check_test.go index a9ec00cf1ae..f1cf644fb9b 100644 --- a/server/config_check_test.go +++ b/server/config_check_test.go @@ -1344,6 +1344,76 @@ func TestConfigCheck(t *testing.T) { errorLine: 11, errorPos: 25, }, + { + name: "when setting allow_trace on a stream export (after)", + config: ` + system_account = sys + accounts { + sys { users = [ {user: sys, pass: "" } ] } + + nats.io: { + users = [ { user : bar, pass: "" } ] + exports = [ { stream: "nats.add", allow_trace: true } ] + } + } + `, + err: errors.New(`Detected allow_trace directive on non-service`), + errorLine: 8, + errorPos: 55, + }, + { + name: "when setting allow_trace on a stream export (before)", + config: ` + system_account = sys + accounts { + sys { users = [ {user: sys, pass: "" } ] } + + nats.io: { + users = [ { user : bar, pass: "" } ] + exports = [ { allow_trace: true, stream: "nats.add" } ] + } + } + `, + err: errors.New(`Detected allow_trace directive on non-service`), + errorLine: 8, + errorPos: 35, + }, + { + name: "when setting allow_trace on a service import (after)", + config: ` + accounts { + A: { + users = [ {user: user1, pass: ""} ] + exports = [{service: "foo"}] + } + B: { + users = [ {user: user2, pass: ""} ] + imports = [ { service: {account: "A", subject: "foo"}, allow_trace: true } ] + } + } + `, + err: errors.New(`Detected allow_trace directive on a non-stream`), + errorLine: 9, + errorPos: 76, + }, + { + name: "when setting allow_trace on a service import (before)", + config: ` + accounts { + A: { + users = [ {user: user1, pass: ""} ] + exports = [{service: "foo"}] + } + B: { + users = [ {user: user2, pass: ""} ] + imports = [ { allow_trace: true, service: {account: "A", subject: "foo"} } ] + } + } + `, + err: errors.New(`Detected allow_trace directive on a non-stream`), + errorLine: 9, + errorPos: 35, + }, { name: "when using duplicate service import subject", config: ` diff --git a/server/configs/accounts.conf b/server/configs/accounts.conf index f9586a19668..f7c97cf5ffd 100644 --- a/server/configs/accounts.conf +++ b/server/configs/accounts.conf @@ -33,15 +33,15 @@ accounts: { ] imports = [ - {stream: {account: "synadia", subject:"public.synadia"}, prefix: "imports.synadia"} + {stream: 
{account: "synadia", subject:"public.synadia"}, prefix: "imports.synadia", allow_trace: true} {stream: {account: "synadia", subject:"synadia.private.*"}} {service: {account: "synadia", subject: "pub.special.request"}, to: "synadia.request"} ] exports = [ - {service: "nats.time", response: stream} + {service: "nats.time", response: stream, allow_trace: true} {service: "nats.photo", response: chunked} - {service: "nats.add", response: singleton, accounts: [cncf]} + {service: "nats.add", response: singleton, accounts: [cncf], allow_trace: true} {service: "nats.sub"} ] } diff --git a/server/events.go b/server/events.go index 391e677cad8..dad167c8564 100644 --- a/server/events.go +++ b/server/events.go @@ -1,4 +1,4 @@ -// Copyright 2018-2023 The NATS Authors +// Copyright 2018-2024 The NATS Authors // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at @@ -647,7 +647,7 @@ func (s *Server) sendInternalAccountMsgWithReply(a *Account, subject, reply stri } // Send system style message to an account scope. -func (s *Server) sendInternalAccountSysMsg(a *Account, subj string, si *ServerInfo, msg interface{}) { +func (s *Server) sendInternalAccountSysMsg(a *Account, subj string, si *ServerInfo, msg any, ct compressionType) { s.mu.RLock() if s.sys == nil || s.sys.sendq == nil || a == nil { s.mu.RUnlock() @@ -660,7 +660,7 @@ func (s *Server) sendInternalAccountSysMsg(a *Account, subj string, si *ServerIn c := a.internalClient() a.mu.Unlock() - sendq.push(newPubMsg(c, subj, _EMPTY_, si, nil, msg, noCompression, false, false)) + sendq.push(newPubMsg(c, subj, _EMPTY_, si, nil, msg, ct, false, false)) } // This will queue up a message to be sent. @@ -2356,7 +2356,7 @@ func (s *Server) sendAccountAuthErrorEvent(c *client, acc *Account, reason strin } c.mu.Unlock() - s.sendInternalAccountSysMsg(acc, authErrorAccountEventSubj, &m.Server, &m) + s.sendInternalAccountSysMsg(acc, authErrorAccountEventSubj, &m.Server, &m, noCompression) } // Internal message callback. diff --git a/server/gateway.go b/server/gateway.go index f5f154700c9..0d11715c60f 100644 --- a/server/gateway.go +++ b/server/gateway.go @@ -1,4 +1,4 @@ -// Copyright 2018-2023 The NATS Authors +// Copyright 2018-2024 The NATS Authors // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at @@ -216,6 +216,8 @@ type gateway struct { // interest-only mode "immediately", so the outbound should disregard // the optimistic mode when checking for interest. interestOnlyMode bool + // Name of the remote server + remoteName string } // Outbound subject interest entry. @@ -511,6 +513,7 @@ func (s *Server) startGatewayAcceptLoop() { Gateway: opts.Gateway.Name, GatewayNRP: true, Headers: s.supportsHeaders(), + Proto: s.getServerProto(), } // Unless in some tests we want to keep the old behavior, we are now // (since v2.9.0) indicate that this server will switch all accounts @@ -983,6 +986,10 @@ func (c *client) processGatewayInfo(info *Info) { } if isFirstINFO { c.opts.Name = info.ID + // Get the protocol version from the INFO protocol. This will be checked + // to see if this connection supports message tracing for instance. 
+ c.opts.Protocol = info.Proto + c.gw.remoteName = info.Name } c.mu.Unlock() @@ -2454,6 +2461,14 @@ func (c *client) sendMsgToGateways(acc *Account, msg, subject, reply []byte, qgr if len(gws) == 0 { return false } + + mt, _ := c.isMsgTraceEnabled() + if mt != nil { + pa := c.pa + msg = mt.setOriginAccountHeaderIfNeeded(c, acc, msg) + defer func() { c.pa = pa }() + } + var ( queuesa = [512]byte{} queues = queuesa[:0] @@ -2546,6 +2561,11 @@ func (c *client) sendMsgToGateways(acc *Account, msg, subject, reply []byte, qgr mreply = append(mreply, reply...) } } + + if mt != nil { + msg = mt.setHopHeader(c, msg) + } + // Setup the message header. // Make sure we are an 'R' proto by default c.msgb[0] = 'R' diff --git a/server/jetstream_cluster.go b/server/jetstream_cluster.go index e7756cd1144..97437ee175f 100644 --- a/server/jetstream_cluster.go +++ b/server/jetstream_cluster.go @@ -2872,8 +2872,15 @@ func (js *jetStream) applyStreamEntries(mset *stream, ce *CommittedEntry, isReco continue } + var mt *msgTrace + // If not recovering, see if we find a message trace object for this + // sequence. Only the leader that has proposed this entry will have + // stored the trace info. + if !isRecovering { + mt = mset.getAndDeleteMsgTrace(lseq) + } // Process the actual message here. - if err := mset.processJetStreamMsg(subject, reply, hdr, msg, lseq, ts); err != nil { + if err := mset.processJetStreamMsg(subject, reply, hdr, msg, lseq, ts, mt); err != nil { if err == errLastSeqMismatch { var state StreamState mset.store.FastState(&state) @@ -2883,7 +2890,7 @@ func (js *jetStream) applyStreamEntries(mset *stream, ce *CommittedEntry, isReco if state.Msgs == 0 { mset.store.Compact(lseq + 1) // Retry - err = mset.processJetStreamMsg(subject, reply, hdr, msg, lseq, ts) + err = mset.processJetStreamMsg(subject, reply, hdr, msg, lseq, ts, mt) } } @@ -7461,7 +7468,7 @@ func (mset *stream) checkAllowMsgCompress(peers []string) { const streamLagWarnThreshold = 10_000 // processClusteredMsg will propose the inbound message to the underlying raft group. -func (mset *stream) processClusteredInboundMsg(subject, reply string, hdr, msg []byte) error { +func (mset *stream) processClusteredInboundMsg(subject, reply string, hdr, msg []byte, mt *msgTrace) (retErr error) { // For possible error response. var response []byte @@ -7474,8 +7481,23 @@ func (mset *stream) processClusteredInboundMsg(subject, reply string, hdr, msg [ mset.mu.RUnlock() // This should not happen but possible now that we allow scale up, and scale down where this could trigger. - if node == nil { - return mset.processJetStreamMsg(subject, reply, hdr, msg, 0, 0) + // + // We also invoke this in clustering mode for message tracing when not + // performing message delivery. + if node == nil || mt.traceOnly() { + return mset.processJetStreamMsg(subject, reply, hdr, msg, 0, 0, mt) + } + + // If message tracing (with message delivery), we will need to send the + // event on exit in case there was an error (if message was not proposed). + // Otherwise, the event will be sent from processJetStreamMsg when + // invoked by the leader (from applyStreamEntries). + if mt != nil { + defer func() { + if retErr != nil { + mt.sendEventFromJetStream(retErr) + } + }() } // Check that we are the leader. This can be false if we have scaled up from an R1 that had inbound queued messages. 
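To show what a consumer sees once a trace has crossed gateways or reached JetStream, here is a small sketch that walks a decoded event using the accessors defined in msgtrace.go later in this diff. The import path and print format are assumptions; the field names come from the trace types in this PR.

package main

import (
	"fmt"

	"github.com/nats-io/nats-server/v2/server"
)

// summarize prints the typed sections of a trace event: ingress, service
// imports, stream exports, the JetStream step (if any), and egresses.
func summarize(e *server.MsgTraceEvent) {
	if in := e.Ingress(); in != nil {
		fmt.Printf("ingress: acc=%s subj=%s err=%q\n", in.Account, in.Subject, in.Error)
	}
	for _, si := range e.ServiceImports() {
		fmt.Printf("service import: acc=%s %s -> %s\n", si.Account, si.From, si.To)
	}
	for _, se := range e.StreamExports() {
		fmt.Printf("stream export: acc=%s to=%s\n", se.Account, se.To)
	}
	if js := e.JetStream(); js != nil {
		fmt.Printf("jetstream: stream=%s subj=%s nointerest=%v err=%q\n",
			js.Stream, js.Subject, js.NoInterest, js.Error)
	}
	for _, eg := range e.Egresses() {
		fmt.Printf("egress: cid=%d sub=%s queue=%s hop=%s err=%q\n",
			eg.CID, eg.Subscription, eg.Queue, eg.Hop, eg.Error)
	}
}

func main() {} // placeholder so the sketch compiles standalone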
@@ -7629,6 +7651,14 @@ func (mset *stream) processClusteredInboundMsg(subject, reply string, hdr, msg [ mset.clseq = lseq + clfs } esm := encodeStreamMsgAllowCompress(subject, reply, hdr, msg, mset.clseq, time.Now().UnixNano(), mset.compressOK) + var mtKey uint64 + if mt != nil { + mtKey = mset.clseq + if mset.mt == nil { + mset.mt = make(map[uint64]*msgTrace) + } + mset.mt[mtKey] = mt + } mset.clseq++ // Do proposal. @@ -7646,6 +7676,9 @@ func (mset *stream) processClusteredInboundMsg(subject, reply string, hdr, msg [ mset.clMu.Unlock() if err != nil { + if mt != nil { + mset.getAndDeleteMsgTrace(mtKey) + } if canRespond { var resp = &JSPubAckResponse{PubAck: &PubAck{Stream: mset.cfg.Name}} resp.Error = &ApiError{Code: 503, Description: err.Error()} @@ -7662,6 +7695,19 @@ func (mset *stream) processClusteredInboundMsg(subject, reply string, hdr, msg [ return err } +func (mset *stream) getAndDeleteMsgTrace(lseq uint64) *msgTrace { + if mset == nil { + return nil + } + mset.clMu.Lock() + mt, ok := mset.mt[lseq] + if ok { + delete(mset.mt, lseq) + } + mset.clMu.Unlock() + return mt +} + // For requesting messages post raft snapshot to catch up streams post server restart. // Any deleted msgs etc will be handled inline on catchup. type streamSyncRequest struct { diff --git a/server/leafnode.go b/server/leafnode.go index 02bf4bd873a..1be72160cad 100644 --- a/server/leafnode.go +++ b/server/leafnode.go @@ -719,7 +719,7 @@ func (s *Server) startLeafNodeAcceptLoop() { Headers: s.supportsHeaders(), JetStream: opts.JetStream, Domain: opts.JetStreamDomain, - Proto: 1, // Fixed for now. + Proto: s.getServerProto(), InfoOnConnect: true, } // If we have selected a random port... @@ -783,6 +783,7 @@ func (c *client) sendLeafConnect(clusterName string, headers bool) error { DenyPub: c.leaf.remote.DenyImports, Compression: c.leaf.compression, RemoteAccount: c.acc.GetName(), + Proto: c.srv.getServerProto(), } // If a signature callback is specified, this takes precedence over anything else. @@ -1296,6 +1297,10 @@ func (c *client) processLeafnodeInfo(info *Info) { } c.leaf.remoteDomain = info.Domain c.leaf.remoteCluster = info.Cluster + // We send the protocol version in the INFO protocol. + // Keep track of it, so we know if this connection supports message + // tracing for instance. + c.opts.Protocol = info.Proto } // For both initial INFO and async INFO protocols, Possibly @@ -1729,6 +1734,14 @@ type leafConnectInfo struct { // Tells the accept side which account the remote is binding to. RemoteAccount string `json:"remote_account,omitempty"` + + // The accept side of a LEAF connection, unlike ROUTER and GATEWAY, receives + // only the CONNECT protocol, and no INFO. So we need to send the protocol + // version as part of the CONNECT. It will indicate if a connection supports + // some features, such as message tracing. + // We use `protocol` as the JSON tag, so this is automatically unmarshal'ed + // in the low level process CONNECT. + Proto int `json:"protocol,omitempty"` } // processLeafNodeConnect will process the inbound connect args. diff --git a/server/msgtrace.go b/server/msgtrace.go new file mode 100644 index 00000000000..2db2db28966 --- /dev/null +++ b/server/msgtrace.go @@ -0,0 +1,683 @@ +// Copyright 2024 The NATS Authors +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package server + +import ( + "bytes" + "encoding/json" + "fmt" + "net/http" + "strings" + "sync/atomic" + "time" +) + +const ( + MsgTraceSendTo = "Nats-Trace-Dest" + MsgTraceHop = "Nats-Trace-Hop" + MsgTraceOriginAccount = "Nats-Trace-Origin-Account" + MsgTraceOnly = "Nats-Trace-Only" +) + +type MsgTraceType string + +// Type of message trace events in the MsgTraceEvents list. +// This is needed to unmarshal the list. +const ( + MsgTraceIngressType = "in" + MsgTraceSubjectMappingType = "sm" + MsgTraceStreamExportType = "se" + MsgTraceServiceImportType = "si" + MsgTraceJetStreamType = "js" + MsgTraceEgressType = "eg" +) + +type MsgTraceEvent struct { + Server ServerInfo `json:"server"` + Request MsgTraceRequest `json:"request"` + Hops int `json:"hops,omitempty"` + Events MsgTraceEvents `json:"events"` +} + +type MsgTraceRequest struct { + Header http.Header `json:"header,omitempty"` + MsgSize int `json:"msgsize,omitempty"` +} + +type MsgTraceEvents []MsgTrace + +type MsgTrace interface { + new() MsgTrace + typ() MsgTraceType +} + +type MsgTraceBase struct { + Type MsgTraceType `json:"type"` + Timestamp time.Time `json:"ts"` +} + +type MsgTraceIngress struct { + MsgTraceBase + Kind int `json:"kind"` + CID uint64 `json:"cid"` + Name string `json:"name,omitempty"` + Account string `json:"acc"` + Subject string `json:"subj"` + Error string `json:"error,omitempty"` +} + +type MsgTraceSubjectMapping struct { + MsgTraceBase + MappedTo string `json:"to"` +} + +type MsgTraceStreamExport struct { + MsgTraceBase + Account string `json:"acc"` + To string `json:"to"` +} + +type MsgTraceServiceImport struct { + MsgTraceBase + Account string `json:"acc"` + From string `json:"from"` + To string `json:"to"` +} + +type MsgTraceJetStream struct { + MsgTraceBase + Stream string `json:"stream"` + Subject string `json:"subject,omitempty"` + NoInterest bool `json:"nointerest,omitempty"` + Error string `json:"error,omitempty"` +} + +type MsgTraceEgress struct { + MsgTraceBase + Kind int `json:"kind"` + CID uint64 `json:"cid"` + Name string `json:"name,omitempty"` + Hop string `json:"hop,omitempty"` + Account string `json:"acc,omitempty"` + Subscription string `json:"sub,omitempty"` + Queue string `json:"queue,omitempty"` + Error string `json:"error,omitempty"` + + // This is for applications that unmarshal the trace events + // and want to link an egress to route/leaf/gateway with + // the MsgTraceEvent from that server. 
+ Link *MsgTraceEvent `json:"-"` +} + +// ------------------------------------------------------------- + +func (t MsgTraceBase) typ() MsgTraceType { return t.Type } +func (_ MsgTraceIngress) new() MsgTrace { return &MsgTraceIngress{} } +func (_ MsgTraceSubjectMapping) new() MsgTrace { return &MsgTraceSubjectMapping{} } +func (_ MsgTraceStreamExport) new() MsgTrace { return &MsgTraceStreamExport{} } +func (_ MsgTraceServiceImport) new() MsgTrace { return &MsgTraceServiceImport{} } +func (_ MsgTraceJetStream) new() MsgTrace { return &MsgTraceJetStream{} } +func (_ MsgTraceEgress) new() MsgTrace { return &MsgTraceEgress{} } + +var msgTraceInterfaces = map[MsgTraceType]MsgTrace{ + MsgTraceIngressType: MsgTraceIngress{}, + MsgTraceSubjectMappingType: MsgTraceSubjectMapping{}, + MsgTraceStreamExportType: MsgTraceStreamExport{}, + MsgTraceServiceImportType: MsgTraceServiceImport{}, + MsgTraceJetStreamType: MsgTraceJetStream{}, + MsgTraceEgressType: MsgTraceEgress{}, +} + +func (t *MsgTraceEvents) UnmarshalJSON(data []byte) error { + var raw []json.RawMessage + err := json.Unmarshal(data, &raw) + if err != nil { + return err + } + *t = make(MsgTraceEvents, len(raw)) + var tt MsgTraceBase + for i, r := range raw { + if err = json.Unmarshal(r, &tt); err != nil { + return err + } + tr, ok := msgTraceInterfaces[tt.Type] + if !ok { + return fmt.Errorf("Unknown trace type %v", tt.Type) + } + te := tr.new() + if err := json.Unmarshal(r, te); err != nil { + return err + } + (*t)[i] = te + } + return nil +} + +func getTraceAs[T MsgTrace](e any) *T { + v, ok := e.(*T) + if ok { + return v + } + return nil +} + +func (t *MsgTraceEvent) Ingress() *MsgTraceIngress { + if len(t.Events) < 1 { + return nil + } + return getTraceAs[MsgTraceIngress](t.Events[0]) +} + +func (t *MsgTraceEvent) SubjectMapping() *MsgTraceSubjectMapping { + for _, e := range t.Events { + if e.typ() == MsgTraceSubjectMappingType { + return getTraceAs[MsgTraceSubjectMapping](e) + } + } + return nil +} + +func (t *MsgTraceEvent) StreamExports() []*MsgTraceStreamExport { + var se []*MsgTraceStreamExport + for _, e := range t.Events { + if e.typ() == MsgTraceStreamExportType { + se = append(se, getTraceAs[MsgTraceStreamExport](e)) + } + } + return se +} + +func (t *MsgTraceEvent) ServiceImports() []*MsgTraceServiceImport { + var si []*MsgTraceServiceImport + for _, e := range t.Events { + if e.typ() == MsgTraceServiceImportType { + si = append(si, getTraceAs[MsgTraceServiceImport](e)) + } + } + return si +} + +func (t *MsgTraceEvent) JetStream() *MsgTraceJetStream { + for _, e := range t.Events { + if e.typ() == MsgTraceJetStreamType { + return getTraceAs[MsgTraceJetStream](e) + } + } + return nil +} + +func (t *MsgTraceEvent) Egresses() []*MsgTraceEgress { + var eg []*MsgTraceEgress + for _, e := range t.Events { + if e.typ() == MsgTraceEgressType { + eg = append(eg, getTraceAs[MsgTraceEgress](e)) + } + } + return eg +} + +const ( + errMsgTraceOnlyNoSupport = "Not delivered because remote does not support message tracing" + errMsgTraceNoSupport = "Message delivered but remote does not support message tracing so no trace event generated from there" + errMsgTraceNoEcho = "Not delivered because of no echo" + errMsgTracePubViolation = "Not delivered because publish denied for this subject" + errMsgTraceSubDeny = "Not delivered because subscription denies this subject" + errMsgTraceSubClosed = "Not delivered because subscription is closed" + errMsgTraceClientClosed = "Not delivered because client is closed" + errMsgTraceAutoSubExceeded = "Not 
delivered because auto-unsubscribe exceeded" +) + +type msgTrace struct { + ready int32 + srv *Server + acc *Account + // Origin account name, set only if acc is nil when acc lookup failed. + oan string + dest string + event *MsgTraceEvent + js *MsgTraceJetStream + hop string + nhop string + tonly bool // Will only trace the message, not do delivery. + ct compressionType +} + +// This will be false outside of the tests, so when building the server binary, +// any code where you see `if msgTraceRunInTests` statement will be compiled +// out, so this will have no performance penalty. +var ( + msgTraceRunInTests bool + msgTraceCheckSupport bool +) + +// Returns the message trace object, if message is being traced, +// and `true` if we want to only trace, not actually deliver the message. +func (c *client) isMsgTraceEnabled() (*msgTrace, bool) { + t := c.pa.trace + if t == nil { + return nil, false + } + return t, t.tonly +} + +// For LEAF/ROUTER/GATEWAY, return false if the remote does not support +// message tracing (important if the tracing requests trace-only). +func (c *client) msgTraceSupport() bool { + // Exclude client connection from the protocol check. + return c.kind == CLIENT || c.opts.Protocol >= MsgTraceProto +} + +func getConnName(c *client) string { + switch c.kind { + case ROUTER: + if n := c.route.remoteName; n != _EMPTY_ { + return n + } + case GATEWAY: + if n := c.gw.remoteName; n != _EMPTY_ { + return n + } + case LEAF: + if n := c.leaf.remoteServer; n != _EMPTY_ { + return n + } + } + return c.opts.Name +} + +func getCompressionType(cts string) compressionType { + if cts == _EMPTY_ { + return noCompression + } + cts = strings.ToLower(cts) + if strings.Contains(cts, "snappy") || strings.Contains(cts, "s2") { + return snappyCompression + } + if strings.Contains(cts, "gzip") { + return gzipCompression + } + return unsupportedCompression +} + +func (c *client) initMsgTrace() *msgTrace { + // The code in the "if" statement is only running in test mode. + if msgTraceRunInTests { + // Check the type of client that tries to initialize a trace struct. + if !(c.kind == CLIENT || c.kind == ROUTER || c.kind == GATEWAY || c.kind == LEAF) { + panic(fmt.Errorf("Unexpected client type %q trying to initialize msgTrace", c.kindString())) + } + // In some tests, we want to make a server behave like an old server + // and so even if a trace header is received, we want the server to + // simply ignore it. + if msgTraceCheckSupport { + if c.srv == nil || c.srv.getServerProto() < MsgTraceProto { + return nil + } + } + } + if c.pa.hdr <= 0 { + return nil + } + hdr := c.msgBuf[:c.pa.hdr] + // Do not call c.parseState.getHeader() yet for performance reasons. + // We first do a "manual" search of the "send-to" destination's header. + // If not present, no need to lift the message headers. + td := getHeader(MsgTraceSendTo, hdr) + if len(td) <= 0 { + return nil + } + // Now we know that this is a message that requested tracing, we + // will lift the headers since we also need to transfer them to + // the produced trace message. 
+ headers := c.parseState.getHeader() + if headers == nil { + return nil + } + ct := getCompressionType(headers.Get(acceptEncodingHeader)) + var traceOnly bool + if to := headers.Get(MsgTraceOnly); to != _EMPTY_ { + tos := strings.ToLower(to) + switch tos { + case "1", "true", "on": + traceOnly = true + } + } + var ( + // Account to use when sending the trace event + acc *Account + // Ingress' account name + ian string + // Origin account name + oan string + // The hop "id", taken from headers only when not from CLIENT + hop string + ) + if c.kind == ROUTER || c.kind == GATEWAY || c.kind == LEAF { + // The ingress account name will always be c.pa.account, but `acc` may + // be different if we have an origin account header. + if c.kind == LEAF { + ian = c.acc.GetName() + } else { + ian = string(c.pa.account) + } + // The remote will have set the origin account header only if the + // message changed account (think of service imports). + oan = headers.Get(MsgTraceOriginAccount) + if oan == _EMPTY_ { + // For LEAF or ROUTER with pinned-account, we can use the c.acc. + if c.kind == LEAF || (c.kind == ROUTER && len(c.route.accName) > 0) { + acc = c.acc + } else { + // We will lookup account with c.pa.account (or ian). + oan = ian + } + } + // Unless we already got the account, we need to look it up. + if acc == nil { + // We don't want to do account resolving here, and we have to return + // a msgTrace object because if we don't and if the user wanted to do + // trace-only, the message would end-up being delivered. + if acci, ok := c.srv.accounts.Load(oan); ok { + acc = acci.(*Account) + // Since we have looked-up the account, we don't need oan, so + // clear it in case it was set. + oan = _EMPTY_ + } else { + c.Errorf("Account %q was not found, won't be able to trace events", oan) + } + } + // Check the hop header + hop = headers.Get(MsgTraceHop) + } else { + acc = c.acc + ian = acc.GetName() + } + c.pa.trace = &msgTrace{ + srv: c.srv, + acc: acc, + oan: oan, + dest: string(td), + ct: ct, + hop: hop, + event: &MsgTraceEvent{ + Request: MsgTraceRequest{ + Header: headers, + MsgSize: c.pa.size, + }, + Events: append(MsgTraceEvents(nil), &MsgTraceIngress{ + MsgTraceBase: MsgTraceBase{ + Type: MsgTraceIngressType, + Timestamp: time.Now(), + }, + Kind: c.kind, + CID: c.cid, + Name: getConnName(c), + Account: ian, + Subject: string(c.pa.subject), + }), + }, + tonly: traceOnly, + } + return c.pa.trace +} + +// Special case where we create a trace event before parsing the message. +// This is for cases where the connection will be closed when detecting +// an error during early message processing (for instance max payload). +func (c *client) initAndSendIngressErrEvent(hdr []byte, dest string, ingressError error) { + if ingressError == nil { + return + } + ct := getAcceptEncoding(hdr) + t := &msgTrace{ + srv: c.srv, + acc: c.acc, + dest: dest, + ct: ct, + event: &MsgTraceEvent{ + Request: MsgTraceRequest{MsgSize: c.pa.size}, + Events: append(MsgTraceEvents(nil), &MsgTraceIngress{ + MsgTraceBase: MsgTraceBase{ + Type: MsgTraceIngressType, + Timestamp: time.Now(), + }, + Kind: c.kind, + CID: c.cid, + Name: getConnName(c), + Error: ingressError.Error(), + }), + }, + } + t.sendEvent() +} + +// Returns `true` if message tracing is enabled and we are tracing only, +// that is, we are not going to deliver the inbound message, returns +// `false` otherwise (no tracing, or tracing and message delivery). 
+func (t *msgTrace) traceOnly() bool { + return t != nil && t.tonly +} + +func (t *msgTrace) setOriginAccountHeaderIfNeeded(c *client, acc *Account, msg []byte) []byte { + var oan string + // If t.acc is set, only check that, not t.oan. + if t.acc != nil { + if t.acc != acc { + oan = t.acc.GetName() + } + } else if t.oan != acc.GetName() { + oan = t.oan + } + if oan != _EMPTY_ { + msg = c.setHeader(MsgTraceOriginAccount, oan, msg) + } + return msg +} + +func (t *msgTrace) setHopHeader(c *client, msg []byte) []byte { + e := t.event + e.Hops++ + if len(t.hop) > 0 { + t.nhop = fmt.Sprintf("%s.%d", t.hop, e.Hops) + } else { + t.nhop = fmt.Sprintf("%d", e.Hops) + } + return c.setHeader(MsgTraceHop, t.nhop, msg) +} + +// Will look for the MsgTraceSendTo header and change the first character +// to an 'X' so that if this message is sent to a remote, the remote will +// not initialize tracing since it won't find the actual MsgTraceSendTo +// header. The function returns the position of the header so it can +// efficiently be re-enabled by calling enableTraceHeader. +// Note that if `msg` can be either the header alone or the full message +// (header and payload). This function will use c.pa.hdr to limit the +// search to the header section alone. +func (t *msgTrace) disableTraceHeader(c *client, msg []byte) int { + // Code largely copied from getHeader(), except that we don't need the value + if c.pa.hdr <= 0 { + return -1 + } + hdr := msg[:c.pa.hdr] + key := stringToBytes(MsgTraceSendTo) + pos := bytes.Index(hdr, key) + if pos < 0 { + return -1 + } + // Make sure this key does not have additional prefix. + if pos < 2 || hdr[pos-1] != '\n' || hdr[pos-2] != '\r' { + return -1 + } + index := pos + len(key) + if index >= len(hdr) { + return -1 + } + if hdr[index] != ':' { + return -1 + } + // Disable the trace by altering the first character of the header + hdr[pos] = 'X' + // Return the position of that character so we can re-enable it. + return pos +} + +// Changes back the character at the given position `pos` in the `msg` +// byte slice to the first character of the MsgTraceSendTo header. +func (t *msgTrace) enableTraceHeader(c *client, msg []byte, pos int) { + if pos <= 0 { + return + } + msg[pos] = MsgTraceSendTo[0] +} + +func (t *msgTrace) setIngressError(err string) { + if i := t.event.Ingress(); i != nil { + i.Error = err + } +} + +func (t *msgTrace) addSubjectMappingEvent(subj []byte) { + if t == nil { + return + } + t.event.Events = append(t.event.Events, &MsgTraceSubjectMapping{ + MsgTraceBase: MsgTraceBase{ + Type: MsgTraceSubjectMappingType, + Timestamp: time.Now(), + }, + MappedTo: string(subj), + }) +} + +func (t *msgTrace) addEgressEvent(dc *client, sub *subscription, err string) { + if t == nil { + return + } + e := &MsgTraceEgress{ + MsgTraceBase: MsgTraceBase{ + Type: MsgTraceEgressType, + Timestamp: time.Now(), + }, + Kind: dc.kind, + CID: dc.cid, + Name: getConnName(dc), + Hop: t.nhop, + Error: err, + } + t.nhop = _EMPTY_ + // Specific to CLIENT connections... + if dc.kind == CLIENT { + // Set the subscription's subject and possibly queue name. + e.Subscription = string(sub.subject) + if len(sub.queue) > 0 { + e.Queue = string(sub.queue) + } + } + if dc.kind == CLIENT || dc.kind == LEAF { + if i := t.event.Ingress(); i != nil { + // If the Ingress' account is different from the destination's + // account, add the account name into the Egress trace event. + // This would happen with service imports. 
+ if dcAccName := dc.acc.GetName(); dcAccName != i.Account { + e.Account = dcAccName + } + } + } + t.event.Events = append(t.event.Events, e) +} + +func (t *msgTrace) addStreamExportEvent(dc *client, to []byte) { + if t == nil { + return + } + dc.mu.Lock() + accName := dc.acc.GetName() + dc.mu.Unlock() + t.event.Events = append(t.event.Events, &MsgTraceStreamExport{ + MsgTraceBase: MsgTraceBase{ + Type: MsgTraceStreamExportType, + Timestamp: time.Now(), + }, + Account: accName, + To: string(to), + }) +} + +func (t *msgTrace) addServiceImportEvent(accName, from, to string) { + if t == nil { + return + } + t.event.Events = append(t.event.Events, &MsgTraceServiceImport{ + MsgTraceBase: MsgTraceBase{ + Type: MsgTraceServiceImportType, + Timestamp: time.Now(), + }, + Account: accName, + From: from, + To: to, + }) +} + +func (t *msgTrace) addJetStreamEvent(streamName string) { + if t == nil { + return + } + t.js = &MsgTraceJetStream{ + MsgTraceBase: MsgTraceBase{ + Type: MsgTraceJetStreamType, + Timestamp: time.Now(), + }, + Stream: streamName, + } + t.event.Events = append(t.event.Events, t.js) +} + +func (t *msgTrace) updateJetStreamEvent(subject string, noInterest bool) { + if t == nil { + return + } + // JetStream event should have been created in addJetStreamEvent + if t.js == nil { + return + } + t.js.Subject = subject + t.js.NoInterest = noInterest +} + +func (t *msgTrace) sendEventFromJetStream(err error) { + if t == nil { + return + } + // JetStream event should have been created in addJetStreamEvent + if t.js == nil { + return + } + if err != nil { + t.js.Error = err.Error() + } + t.sendEvent() +} + +func (t *msgTrace) sendEvent() { + if t == nil { + return + } + if t.js != nil { + ready := atomic.AddInt32(&t.ready, 1) == 2 + if !ready { + return + } + } + t.srv.sendInternalAccountSysMsg(t.acc, t.dest, &t.event.Server, t.event, t.ct) +} diff --git a/server/msgtrace_test.go b/server/msgtrace_test.go new file mode 100644 index 00000000000..222ecb23a49 --- /dev/null +++ b/server/msgtrace_test.go @@ -0,0 +1,4201 @@ +// Copyright 2024 The NATS Authors +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package server + +import ( + "bytes" + "compress/gzip" + "encoding/json" + "fmt" + "io" + "net" + "strings" + "sync/atomic" + "testing" + "time" + + "github.com/klauspost/compress/s2" + "github.com/nats-io/nats.go" +) + +func init() { + msgTraceRunInTests = true +} + +func TestMsgTraceConnName(t *testing.T) { + c := &client{kind: ROUTER, route: &route{remoteName: "somename"}} + c.opts.Name = "someid" + + // If route.remoteName is set, it will take precedence. + val := getConnName(c) + require_Equal[string](t, val, "somename") + // When not set, we revert to c.opts.Name + c.route.remoteName = _EMPTY_ + val = getConnName(c) + require_Equal[string](t, val, "someid") + + // Now same for GW. 
+ c.route = nil + c.gw = &gateway{remoteName: "somename"} + c.kind = GATEWAY + val = getConnName(c) + require_Equal[string](t, val, "somename") + // Revert to c.opts.Name + c.gw.remoteName = _EMPTY_ + val = getConnName(c) + require_Equal[string](t, val, "someid") + + // For LeafNode now + c.gw = nil + c.leaf = &leaf{remoteServer: "somename"} + c.kind = LEAF + val = getConnName(c) + require_Equal[string](t, val, "somename") + // But if not set... + c.leaf.remoteServer = _EMPTY_ + val = getConnName(c) + require_Equal[string](t, val, "someid") + + c.leaf = nil + c.kind = CLIENT + val = getConnName(c) + require_Equal[string](t, val, "someid") +} + +func TestMsgTraceBasic(t *testing.T) { + conf := createConfFile(t, []byte(` + listen: 127.0.0.1:-1 + mappings = { + foo: bar + } + `)) + s, _ := RunServerWithConfig(conf) + defer s.Shutdown() + + nc := natsConnect(t, s.ClientURL()) + defer nc.Close() + cid, err := nc.GetClientID() + require_NoError(t, err) + + traceSub := natsSubSync(t, nc, "my.trace.subj") + natsFlush(t, nc) + + // Send trace message to a dummy subject to check that resulting trace's + // SubjectMapping and Egress are nil. + msg := nats.NewMsg("dummy") + msg.Header.Set(MsgTraceSendTo, traceSub.Subject) + msg.Header.Set(MsgTraceOnly, "true") + msg.Data = []byte("hello!") + err = nc.PublishMsg(msg) + require_NoError(t, err) + + traceMsg := natsNexMsg(t, traceSub, time.Second) + var e MsgTraceEvent + json.Unmarshal(traceMsg.Data, &e) + require_Equal[string](t, e.Server.Name, s.Name()) + // We don't remove the headers, so we will find the tracing header there. + require_True(t, e.Request.Header != nil) + require_Equal[int](t, len(e.Request.Header), 2) + // The message size is 6 + whatever size for the 2 trace headers. + // Let's just make sure that size is > 20... + require_True(t, e.Request.MsgSize > 20) + ingress := e.Ingress() + require_True(t, ingress != nil) + require_True(t, ingress.Kind == CLIENT) + require_True(t, ingress.Timestamp != time.Time{}) + require_Equal[uint64](t, ingress.CID, cid) + require_Equal[string](t, ingress.Name, _EMPTY_) + require_Equal[string](t, ingress.Account, globalAccountName) + require_Equal[string](t, ingress.Subject, "dummy") + require_Equal[string](t, ingress.Error, _EMPTY_) + require_True(t, e.SubjectMapping() == nil) + require_True(t, e.StreamExports() == nil) + require_True(t, e.ServiceImports() == nil) + require_True(t, e.JetStream() == nil) + require_True(t, e.Egresses() == nil) + + // Now setup subscriptions that generate interest on the subject. 
+ nc2 := natsConnect(t, s.ClientURL(), nats.Name("sub1And2")) + defer nc2.Close() + sub1 := natsSubSync(t, nc2, "bar") + sub2 := natsSubSync(t, nc2, "bar") + natsFlush(t, nc2) + nc2CID, _ := nc2.GetClientID() + + nc3 := natsConnect(t, s.ClientURL()) + defer nc3.Close() + sub3 := natsSubSync(t, nc3, "*") + natsFlush(t, nc3) + + for _, test := range []struct { + name string + deliverMsg bool + }{ + {"just trace", false}, + {"deliver msg", true}, + } { + t.Run(test.name, func(t *testing.T) { + msg = nats.NewMsg("foo") + msg.Header.Set("Some-App-Header", "some value") + msg.Header.Set(MsgTraceSendTo, traceSub.Subject) + if !test.deliverMsg { + msg.Header.Set(MsgTraceOnly, "true") + } + msg.Data = []byte("hello!") + err = nc.PublishMsg(msg) + require_NoError(t, err) + + checkAppMsg := func(sub *nats.Subscription, expected bool) { + if expected { + appMsg := natsNexMsg(t, sub, time.Second) + require_Equal[string](t, string(appMsg.Data), "hello!") + // We don't remove message trace header, so we should have + // 2 headers (the app + trace destination) + require_True(t, len(appMsg.Header) == 2) + require_Equal[string](t, appMsg.Header.Get("Some-App-Header"), "some value") + require_Equal[string](t, appMsg.Header.Get(MsgTraceSendTo), traceSub.Subject) + } + // Check that no (more) messages are received. + if msg, err := sub.NextMsg(100 * time.Millisecond); err != nats.ErrTimeout { + t.Fatalf("Did not expect application message, got %s", msg.Data) + } + } + for _, sub := range []*nats.Subscription{sub1, sub2, sub3} { + checkAppMsg(sub, test.deliverMsg) + } + + traceMsg = natsNexMsg(t, traceSub, time.Second) + e = MsgTraceEvent{} + json.Unmarshal(traceMsg.Data, &e) + require_Equal[string](t, e.Server.Name, s.Name()) + require_True(t, e.Request.Header != nil) + // We should have the app header and the trace header(s) too. + expected := 2 + if !test.deliverMsg { + // The "trace-only" header is added. + expected++ + } + require_Equal[int](t, len(e.Request.Header), expected) + require_Equal[string](t, e.Request.Header.Get("Some-App-Header"), "some value") + // The message size is 6 + whatever size for the 2 trace headers. + // Let's just make sure that size is > 20... 
+ require_True(t, e.Request.MsgSize > 20) + ingress := e.Ingress() + require_True(t, ingress.Kind == CLIENT) + require_True(t, ingress.Timestamp != time.Time{}) + require_Equal[string](t, ingress.Account, globalAccountName) + require_Equal[string](t, ingress.Subject, "foo") + sm := e.SubjectMapping() + require_True(t, sm != nil) + require_True(t, sm.Timestamp != time.Time{}) + require_Equal[string](t, sm.MappedTo, "bar") + egress := e.Egresses() + require_Equal[int](t, len(egress), 3) + var sub1And2 int + for _, eg := range egress { + // All Egress should be clients + require_True(t, eg.Kind == CLIENT) + require_True(t, eg.Timestamp != time.Time{}) + // For nc2CID, we should have two egress + if eg.CID == nc2CID { + // Check name + require_Equal[string](t, eg.Name, "sub1And2") + require_Equal[string](t, eg.Subscription, "bar") + sub1And2++ + } else { + // No name set + require_Equal[string](t, eg.Name, _EMPTY_) + require_Equal[string](t, eg.Subscription, "*") + } + } + require_Equal[int](t, sub1And2, 2) + }) + } +} + +func TestMsgTraceIngressMaxPayloadError(t *testing.T) { + o := DefaultOptions() + o.MaxPayload = 1024 + s := RunServer(o) + defer s.Shutdown() + + nc := natsConnect(t, s.ClientURL()) + defer nc.Close() + + traceSub := natsSubSync(t, nc, "my.trace.subj") + natsSub(t, nc, "foo", func(_ *nats.Msg) {}) + natsFlush(t, nc) + + for _, test := range []struct { + name string + deliverMsg bool + }{ + {"just trace", false}, + {"deliver msg", true}, + } { + t.Run(test.name, func(t *testing.T) { + nc2, err := net.Dial("tcp", fmt.Sprintf("127.0.0.1:%d", o.Port)) + require_NoError(t, err) + defer nc2.Close() + + nc2.Write([]byte("CONNECT {\"protocol\":1,\"headers\":true,\"no_responders\":true}\r\n")) + + var traceOnlyHdr string + if !test.deliverMsg { + traceOnlyHdr = fmt.Sprintf("%s:true\r\n", MsgTraceOnly) + } + hdr := fmt.Sprintf("%s%s:%s\r\n%s\r\n", hdrLine, MsgTraceSendTo, traceSub.Subject, traceOnlyHdr) + hPub := fmt.Sprintf("HPUB foo %d 2048\r\n%sAAAAAAAAAAAAAAAAAA...", len(hdr), hdr) + nc2.Write([]byte(hPub)) + + traceMsg := natsNexMsg(t, traceSub, time.Second) + var e MsgTraceEvent + json.Unmarshal(traceMsg.Data, &e) + require_Equal[string](t, e.Server.Name, s.Name()) + require_True(t, e.Request.Header == nil) + require_True(t, e.Ingress() != nil) + require_Contains(t, e.Ingress().Error, ErrMaxPayload.Error()) + require_True(t, e.Egresses() == nil) + }) + } +} + +func TestMsgTraceIngressErrors(t *testing.T) { + conf := createConfFile(t, []byte(` + port: -1 + accounts { + A { + users: [ + { + user: a + password: pwd + permissions { + subscribe: ["my.trace.subj", "foo"] + publish { + allow: ["foo", "bar.>"] + deny: ["bar.baz"] + } + } + } + ] + } + } + `)) + s, _ := RunServerWithConfig(conf) + defer s.Shutdown() + + nc := natsConnect(t, s.ClientURL(), nats.UserInfo("a", "pwd")) + defer nc.Close() + + traceSub := natsSubSync(t, nc, "my.trace.subj") + natsSub(t, nc, "foo", func(_ *nats.Msg) {}) + natsFlush(t, nc) + + for _, test := range []struct { + name string + deliverMsg bool + }{ + {"just trace", false}, + {"deliver msg", true}, + } { + t.Run(test.name, func(t *testing.T) { + nc2 := natsConnect(t, s.ClientURL(), nats.UserInfo("a", "pwd")) + defer nc2.Close() + + sendMsg := func(subj, reply, errTxt string) { + msg := nats.NewMsg(subj) + msg.Header.Set(MsgTraceSendTo, traceSub.Subject) + if !test.deliverMsg { + msg.Header.Set(MsgTraceOnly, "true") + } + msg.Reply = reply + msg.Data = []byte("hello") + nc2.PublishMsg(msg) + + traceMsg := natsNexMsg(t, traceSub, time.Second) + var e 
MsgTraceEvent
+ json.Unmarshal(traceMsg.Data, &e)
+ require_Equal[string](t, e.Server.Name, s.Name())
+ require_True(t, e.Request.Header != nil)
+ require_Contains(t, e.Ingress().Error, errTxt)
+ require_True(t, e.Egresses() == nil)
+ }
+
+ // Send to a subject that causes permission violation
+ sendMsg("bar.baz", _EMPTY_, "Permissions Violation for Publish to")
+
+ // Send to a subject that is reserved for GW replies
+ sendMsg(gwReplyPrefix+"foo", _EMPTY_, "Permissions Violation for Publish to")
+
+ // Send with a Reply that is reserved
+ sendMsg("foo", replyPrefix+"bar", "Permissions Violation for Publish with Reply of")
+ })
+ }
+}
+
+func TestMsgTraceEgressErrors(t *testing.T) {
+ conf := createConfFile(t, []byte(`
+ port: -1
+ accounts {
+ A {
+ users: [
+ {
+ user: a
+ password: pwd
+ permissions {
+ subscribe: {
+ allow: ["my.trace.subj", "foo", "bar.>"]
+ deny: "bar.bat"
+ }
+ publish {
+ allow: ["foo", "bar.>"]
+ deny: ["bar.baz"]
+ }
+ }
+ }
+ ]
+ }
+ }
+ `))
+ s, _ := RunServerWithConfig(conf)
+ defer s.Shutdown()
+
+ nc := natsConnect(t, s.ClientURL(), nats.UserInfo("a", "pwd"))
+ defer nc.Close()
+
+ traceSub := natsSubSync(t, nc, "my.trace.subj")
+ natsFlush(t, nc)
+
+ for _, test := range []struct {
+ name string
+ deliverMsg bool
+ }{
+ {"just trace", false},
+ {"deliver msg", true},
+ } {
+ t.Run(test.name, func(t *testing.T) {
+ sendMsg := func(pubc *nats.Conn, subj, errTxt string) {
+ t.Helper()
+
+ msg := nats.NewMsg(subj)
+ msg.Header.Set(MsgTraceSendTo, traceSub.Subject)
+ if !test.deliverMsg {
+ msg.Header.Set(MsgTraceOnly, "true")
+ }
+ msg.Data = []byte("hello")
+ pubc.PublishMsg(msg)
+
+ traceMsg := natsNexMsg(t, traceSub, time.Second)
+ var e MsgTraceEvent
+ json.Unmarshal(traceMsg.Data, &e)
+ require_Equal[string](t, e.Server.Name, s.Name())
+ egress := e.Egresses()
+ require_Equal[int](t, len(egress), 1)
+ require_Contains(t, egress[0].Error, errTxt)
+ }
+
+ // Test no-echo.
+ nc2 := natsConnect(t, s.ClientURL(), nats.UserInfo("a", "pwd"), nats.NoEcho())
+ defer nc2.Close()
+ natsSubSync(t, nc2, "foo")
+ sendMsg(nc2, "foo", errMsgTraceNoEcho)
+ nc2.Close()
+
+ // Test deny sub.
+ nc2 = natsConnect(t, s.ClientURL(), nats.UserInfo("a", "pwd"))
+ defer nc2.Close()
+ natsSubSync(t, nc2, "bar.>")
+ sendMsg(nc2, "bar.bat", errMsgTraceSubDeny)
+ nc2.Close()
+
+ // Test sub closed
+ nc2 = natsConnect(t, s.ClientURL(), nats.UserInfo("a", "pwd"))
+ defer nc2.Close()
+ natsSubSync(t, nc2, "bar.>")
+ natsFlush(t, nc2)
+ // Artificially change the closed status of the subscription
+ cid, err := nc2.GetClientID()
+ require_NoError(t, err)
+ c := s.GetClient(cid)
+ c.mu.Lock()
+ for _, sub := range c.subs {
+ if string(sub.subject) == "bar.>" {
+ sub.close()
+ }
+ }
+ c.mu.Unlock()
+ sendMsg(nc2, "bar.bar", errMsgTraceSubClosed)
+ nc2.Close()
+
+ // The following applies only when doing delivery.
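+ // Exercise the auto-unsubscribe exceeded and client closed egress errors.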
+ if test.deliverMsg { + // Test auto-unsub exceeded + nc2 = natsConnect(t, s.ClientURL(), nats.UserInfo("a", "pwd")) + defer nc2.Close() + sub := natsSubSync(t, nc2, "bar.>") + err := sub.AutoUnsubscribe(10) + require_NoError(t, err) + natsFlush(t, nc2) + + // Modify sub.nm to be already over the 10 limit + cid, err := nc2.GetClientID() + require_NoError(t, err) + c := s.GetClient(cid) + c.mu.Lock() + for _, sub := range c.subs { + if string(sub.subject) == "bar.>" { + sub.nm = 20 + } + } + c.mu.Unlock() + + sendMsg(nc2, "bar.bar", errMsgTraceAutoSubExceeded) + nc2.Close() + + // Test client closed + nc2 = natsConnect(t, s.ClientURL(), nats.UserInfo("a", "pwd")) + defer nc2.Close() + natsSubSync(t, nc2, "bar.>") + cid, err = nc2.GetClientID() + require_NoError(t, err) + c = s.GetClient(cid) + c.mu.Lock() + c.out.stc = make(chan struct{}) + c.mu.Unlock() + msg := nats.NewMsg("bar.bar") + msg.Header.Set(MsgTraceSendTo, traceSub.Subject) + if !test.deliverMsg { + msg.Header.Set(MsgTraceOnly, "true") + } + msg.Data = []byte("hello") + nc2.PublishMsg(msg) + time.Sleep(10 * time.Millisecond) + cid, err = nc2.GetClientID() + require_NoError(t, err) + c = s.GetClient(cid) + c.mu.Lock() + c.flags.set(closeConnection) + c.mu.Unlock() + + traceMsg := natsNexMsg(t, traceSub, time.Second) + var e MsgTraceEvent + json.Unmarshal(traceMsg.Data, &e) + require_Equal[string](t, e.Server.Name, s.Name()) + egress := e.Egresses() + require_Equal[int](t, len(egress), 1) + require_Contains(t, egress[0].Error, errMsgTraceClientClosed) + c.mu.Lock() + c.flags.clear(closeConnection) + c.mu.Unlock() + nc2.Close() + } + }) + } +} + +func TestMsgTraceWithQueueSub(t *testing.T) { + o := DefaultOptions() + s := RunServer(o) + defer s.Shutdown() + + nc := natsConnect(t, s.ClientURL()) + defer nc.Close() + + traceSub := natsSubSync(t, nc, "my.trace.subj") + natsFlush(t, nc) + + nc2 := natsConnect(t, s.ClientURL(), nats.Name("sub1")) + defer nc2.Close() + sub1 := natsQueueSubSync(t, nc2, "foo", "bar") + natsFlush(t, nc2) + + nc3 := natsConnect(t, s.ClientURL(), nats.Name("sub2")) + defer nc3.Close() + sub2 := natsQueueSubSync(t, nc3, "foo", "bar") + sub3 := natsQueueSubSync(t, nc3, "*", "baz") + natsFlush(t, nc3) + + for _, test := range []struct { + name string + deliverMsg bool + }{ + {"just trace", false}, + {"deliver msg", true}, + } { + t.Run(test.name, func(t *testing.T) { + msg := nats.NewMsg("foo") + msg.Header.Set(MsgTraceSendTo, traceSub.Subject) + if !test.deliverMsg { + msg.Header.Set(MsgTraceOnly, "true") + } + if !test.deliverMsg { + msg.Data = []byte("hello1") + } else { + msg.Data = []byte("hello2") + } + err := nc.PublishMsg(msg) + require_NoError(t, err) + + if test.deliverMsg { + // Only one should have got the message... + msg1, err1 := sub1.NextMsg(100 * time.Millisecond) + msg2, err2 := sub2.NextMsg(100 * time.Millisecond) + if err1 == nil && err2 == nil { + t.Fatalf("Only one message should have been received") + } + var val string + if msg1 != nil { + val = string(msg1.Data) + } else { + val = string(msg2.Data) + } + require_Equal[string](t, val, "hello2") + // Queue baz should also have received the message + msg := natsNexMsg(t, sub3, time.Second) + require_Equal[string](t, string(msg.Data), "hello2") + } + // Check that no (more) messages are received. 
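+ // (each queue group delivers to only one member, so at most one of sub1/sub2
+ // got the message above; the trace below still records one egress per group)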
+ for _, sub := range []*nats.Subscription{sub1, sub2, sub3} {
+ if msg, err := sub.NextMsg(100 * time.Millisecond); err != nats.ErrTimeout {
+ t.Fatalf("Expected no message, got %s", msg.Data)
+ }
+ }
+
+ traceMsg := natsNexMsg(t, traceSub, time.Second)
+ var e MsgTraceEvent
+ json.Unmarshal(traceMsg.Data, &e)
+ require_Equal[string](t, e.Server.Name, s.Name())
+ ingress := e.Ingress()
+ require_True(t, ingress != nil)
+ require_True(t, ingress.Kind == CLIENT)
+ require_Equal[string](t, ingress.Subject, "foo")
+ egress := e.Egresses()
+ require_Equal[int](t, len(egress), 2)
+ var qbar, qbaz int
+ for _, eg := range egress {
+ switch eg.Queue {
+ case "bar":
+ require_Equal[string](t, eg.Subscription, "foo")
+ qbar++
+ case "baz":
+ require_Equal[string](t, eg.Subscription, "*")
+ qbaz++
+ default:
+ t.Fatalf("Wrong queue name: %q", eg.Queue)
+ }
+ }
+ require_Equal[int](t, qbar, 1)
+ require_Equal[int](t, qbaz, 1)
+ })
+ }
+}
+
+func TestMsgTraceWithRoutes(t *testing.T) {
+ tmpl := `
+ port: -1
+ accounts {
+ A { users: [{user:A, password: pwd}] }
+ B { users: [{user:B, password: pwd}] }
+ }
+ cluster {
+ name: "local"
+ port: -1
+ accounts: ["A"]
+ %s
+ }
+ `
+ conf1 := createConfFile(t, []byte(fmt.Sprintf(tmpl, _EMPTY_)))
+ s1, o1 := RunServerWithConfig(conf1)
+ defer s1.Shutdown()
+
+ conf2 := createConfFile(t, []byte(fmt.Sprintf(tmpl, fmt.Sprintf("routes: [\"nats://127.0.0.1:%d\"]", o1.Cluster.Port))))
+ s2, _ := RunServerWithConfig(conf2)
+ defer s2.Shutdown()
+
+ checkClusterFormed(t, s1, s2)
+
+ checkDummy := func(user string) {
+ nc := natsConnect(t, s1.ClientURL(), nats.UserInfo(user, "pwd"))
+ defer nc.Close()
+
+ traceSub := natsSubSync(t, nc, "my.trace.subj")
+ natsFlush(t, nc)
+
+ // Send trace message to a dummy subject to check that resulting trace
+ // is as expected.
+ msg := nats.NewMsg("dummy")
+ msg.Header.Set(MsgTraceSendTo, traceSub.Subject)
+ msg.Header.Set(MsgTraceOnly, "true")
+ msg.Data = []byte("hello!")
+ err := nc.PublishMsg(msg)
+ require_NoError(t, err)
+
+ traceMsg := natsNexMsg(t, traceSub, time.Second)
+ var e MsgTraceEvent
+ json.Unmarshal(traceMsg.Data, &e)
+ require_Equal[string](t, e.Server.Name, s1.Name())
+ ingress := e.Ingress()
+ require_True(t, ingress != nil)
+ require_True(t, ingress.Kind == CLIENT)
+ // "user" is the same as the account name in this test.
+ require_Equal[string](t, ingress.Account, user)
+ require_Equal[string](t, ingress.Subject, "dummy")
+ require_True(t, e.SubjectMapping() == nil)
+ require_True(t, e.Egresses() == nil)
+
+ // We should also not get an event from the remote server.
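+ // ("dummy" has no subscription interest anywhere, so the message is not
+ // forwarded over the route and only this server emits a trace event)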
+ if msg, err := traceSub.NextMsg(100 * time.Millisecond); err != nats.ErrTimeout { + t.Fatalf("Expected no message, got %s", msg.Data) + } + } + checkDummy("A") + checkDummy("B") + + for _, test := range []struct { + name string + acc string + }{ + {"pinned account", "A"}, + {"reg account", "B"}, + } { + t.Run(test.name, func(t *testing.T) { + acc := test.acc + // Now create subscriptions on both s1 and s2 + nc2 := natsConnect(t, s2.ClientURL(), nats.UserInfo(acc, "pwd"), nats.Name("sub2")) + defer nc2.Close() + sub2 := natsQueueSubSync(t, nc2, "foo.*", "my_queue") + + nc3 := natsConnect(t, s2.ClientURL(), nats.UserInfo(acc, "pwd"), nats.Name("sub3")) + defer nc3.Close() + sub3 := natsQueueSubSync(t, nc3, "*.*", "my_queue_2") + + checkSubInterest(t, s1, acc, "foo.bar", time.Second) + + nc1 := natsConnect(t, s1.ClientURL(), nats.UserInfo(acc, "pwd"), nats.Name("sub1")) + defer nc1.Close() + sub1 := natsSubSync(t, nc1, "*.bar") + + nct := natsConnect(t, s1.ClientURL(), nats.UserInfo(acc, "pwd"), nats.Name("tracer")) + defer nct.Close() + traceSub := natsSubSync(t, nct, "my.trace.subj") + + for _, test := range []struct { + name string + deliverMsg bool + }{ + {"just trace", false}, + {"deliver msg", true}, + } { + t.Run(test.name, func(t *testing.T) { + msg := nats.NewMsg("foo.bar") + msg.Header.Set(MsgTraceSendTo, traceSub.Subject) + if !test.deliverMsg { + msg.Header.Set(MsgTraceOnly, "true") + } + msg.Data = []byte("hello!") + err := nct.PublishMsg(msg) + require_NoError(t, err) + + checkAppMsg := func(sub *nats.Subscription, expected bool) { + if expected { + appMsg := natsNexMsg(t, sub, time.Second) + require_Equal[string](t, string(appMsg.Data), "hello!") + } + // Check that no (more) messages are received. + if msg, err := sub.NextMsg(100 * time.Millisecond); err != nats.ErrTimeout { + t.Fatalf("Did not expect application message, got %s", msg.Data) + } + } + for _, sub := range []*nats.Subscription{sub1, sub2, sub3} { + checkAppMsg(sub, test.deliverMsg) + } + + check := func() { + traceMsg := natsNexMsg(t, traceSub, time.Second) + var e MsgTraceEvent + json.Unmarshal(traceMsg.Data, &e) + ingress := e.Ingress() + require_True(t, ingress != nil) + switch ingress.Kind { + case CLIENT: + require_Equal[string](t, e.Server.Name, s1.Name()) + require_Equal[string](t, ingress.Account, acc) + require_Equal[string](t, ingress.Subject, "foo.bar") + egress := e.Egresses() + require_Equal[int](t, len(egress), 2) + for _, eg := range egress { + if eg.Kind == CLIENT { + require_Equal[string](t, eg.Name, "sub1") + require_Equal[string](t, eg.Subscription, "*.bar") + require_Equal[string](t, eg.Queue, _EMPTY_) + } else { + require_True(t, eg.Kind == ROUTER) + require_Equal[string](t, eg.Name, s2.Name()) + require_Equal[string](t, eg.Subscription, _EMPTY_) + require_Equal[string](t, eg.Queue, _EMPTY_) + } + } + case ROUTER: + require_Equal[string](t, e.Server.Name, s2.Name()) + require_Equal[string](t, ingress.Account, acc) + require_Equal[string](t, ingress.Subject, "foo.bar") + egress := e.Egresses() + require_Equal[int](t, len(egress), 2) + var gotSub2, gotSub3 int + for _, eg := range egress { + require_True(t, eg.Kind == CLIENT) + switch eg.Name { + case "sub2": + require_Equal[string](t, eg.Subscription, "foo.*") + require_Equal[string](t, eg.Queue, "my_queue") + gotSub2++ + case "sub3": + require_Equal[string](t, eg.Subscription, "*.*") + require_Equal[string](t, eg.Queue, "my_queue_2") + gotSub3++ + default: + t.Fatalf("Unexpected egress name: %+v", eg) + } + } + require_Equal[int](t, 
gotSub2, 1) + require_Equal[int](t, gotSub3, 1) + default: + t.Fatalf("Unexpected ingress: %+v", ingress) + } + } + // We should get 2 events. Order is not guaranteed. + check() + check() + // Make sure we are not receiving more traces + if tm, err := traceSub.NextMsg(250 * time.Millisecond); err == nil { + t.Fatalf("Should not have received trace message: %s", tm.Data) + } + }) + } + }) + } +} + +func TestMsgTraceWithRouteToOldServer(t *testing.T) { + msgTraceCheckSupport = true + defer func() { msgTraceCheckSupport = false }() + tmpl := ` + port: -1 + cluster { + name: "local" + port: -1 + pool_size: -1 + %s + } + ` + conf1 := createConfFile(t, []byte(fmt.Sprintf(tmpl, _EMPTY_))) + s1, o1 := RunServerWithConfig(conf1) + defer s1.Shutdown() + + conf2 := createConfFile(t, []byte(fmt.Sprintf(tmpl, fmt.Sprintf("routes: [\"nats://127.0.0.1:%d\"]", o1.Cluster.Port)))) + o2 := LoadConfig(conf2) + // Make this server behave like an older server + o2.overrideProto = setServerProtoForTest(MsgTraceProto - 1) + s2 := RunServer(o2) + defer s2.Shutdown() + + checkClusterFormed(t, s1, s2) + + // Now create subscriptions on both s1 and s2 + nc2 := natsConnect(t, s2.ClientURL(), nats.Name("sub2")) + defer nc2.Close() + sub2 := natsSubSync(t, nc2, "foo") + + checkSubInterest(t, s1, globalAccountName, "foo", time.Second) + + nc1 := natsConnect(t, s1.ClientURL(), nats.Name("sub1")) + defer nc1.Close() + sub1 := natsSubSync(t, nc1, "foo") + + nct := natsConnect(t, s1.ClientURL(), nats.Name("tracer")) + defer nct.Close() + traceSub := natsSubSync(t, nct, "my.trace.subj") + + for _, test := range []struct { + name string + deliverMsg bool + }{ + {"just trace", false}, + {"deliver msg", true}, + } { + t.Run(test.name, func(t *testing.T) { + msg := nats.NewMsg("foo") + msg.Header.Set(MsgTraceSendTo, traceSub.Subject) + if !test.deliverMsg { + msg.Header.Set(MsgTraceOnly, "true") + } + msg.Data = []byte("hello!") + err := nct.PublishMsg(msg) + require_NoError(t, err) + + checkAppMsg := func(sub *nats.Subscription, expected bool) { + if expected { + appMsg := natsNexMsg(t, sub, time.Second) + require_Equal[string](t, string(appMsg.Data), "hello!") + } + // Check that no (more) messages are received. + if msg, err := sub.NextMsg(100 * time.Millisecond); err != nats.ErrTimeout { + t.Fatalf("Did not expect application message, got %s", msg.Data) + } + } + // Even if a server does not support tracing, as long as the header + // TraceOnly is not set, the message should be forwarded to the remote. 
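+ // With TraceOnly set, however, the message is neither delivered to the local
+ // subscription nor forwarded to the old server, and the route egress below
+ // reports the corresponding error.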
+ for _, sub := range []*nats.Subscription{sub1, sub2} { + checkAppMsg(sub, test.deliverMsg) + } + + traceMsg := natsNexMsg(t, traceSub, time.Second) + var e MsgTraceEvent + json.Unmarshal(traceMsg.Data, &e) + ingress := e.Ingress() + require_True(t, ingress != nil) + require_True(t, ingress.Kind == CLIENT) + require_Equal[string](t, e.Server.Name, s1.Name()) + egress := e.Egresses() + require_Equal[int](t, len(egress), 2) + for _, ci := range egress { + switch ci.Kind { + case CLIENT: + require_Equal[string](t, ci.Name, "sub1") + case ROUTER: + require_Equal[string](t, ci.Name, s2.Name()) + if test.deliverMsg { + require_Contains(t, ci.Error, errMsgTraceNoSupport) + } else { + require_Contains(t, ci.Error, errMsgTraceOnlyNoSupport) + } + default: + t.Fatalf("Unexpected egress: %+v", ci) + } + } + // We should not get a second trace + if msg, err := traceSub.NextMsg(100 * time.Millisecond); err != nats.ErrTimeout { + t.Fatalf("Did not expect other trace, got %s", msg.Data) + } + }) + } +} + +func TestMsgTraceWithLeafNode(t *testing.T) { + for _, mainTest := range []struct { + name string + fromHub bool + leafUseLocalAcc bool + }{ + {"from hub", true, false}, + {"from leaf", false, false}, + {"from hub with local account", true, true}, + {"from leaf with local account", false, true}, + } { + t.Run(mainTest.name, func(t *testing.T) { + confHub := createConfFile(t, []byte(` + port: -1 + server_name: "A" + accounts { + A { users: [{user: "a", password: "pwd"}]} + B { users: [{user: "b", password: "pwd"}]} + } + leafnodes { + port: -1 + } + `)) + hub, ohub := RunServerWithConfig(confHub) + defer hub.Shutdown() + + var accs string + var lacc string + if mainTest.leafUseLocalAcc { + accs = `accounts { B { users: [{user: "b", password: "pwd"}]} }` + lacc = `account: B` + } + confLeaf := createConfFile(t, []byte(fmt.Sprintf(` + port: -1 + server_name: "B" + %s + leafnodes { + remotes [ + { + url: "nats://a:pwd@127.0.0.1:%d" + %s + } + ] + } + `, accs, ohub.LeafNode.Port, lacc))) + leaf, _ := RunServerWithConfig(confLeaf) + defer leaf.Shutdown() + + checkLeafNodeConnected(t, hub) + checkLeafNodeConnected(t, leaf) + + var s1, s2 *Server + if mainTest.fromHub { + s1, s2 = hub, leaf + } else { + s1, s2 = leaf, hub + } + // Now create subscriptions on both s1 and s2 + opts := []nats.Option{nats.Name("sub2")} + var user string + // If fromHub, then it means that s2 is the leaf. + if mainTest.fromHub { + if mainTest.leafUseLocalAcc { + user = "b" + } + } else { + // s2 is the hub, always connect with user "a'" + user = "a" + } + if user != _EMPTY_ { + opts = append(opts, nats.UserInfo(user, "pwd")) + } + nc2 := natsConnect(t, s2.ClientURL(), opts...) + defer nc2.Close() + sub2 := natsSubSync(t, nc2, "foo") + + if mainTest.fromHub { + checkSubInterest(t, s1, "A", "foo", time.Second) + } else if mainTest.leafUseLocalAcc { + checkSubInterest(t, s1, "B", "foo", time.Second) + } else { + checkSubInterest(t, s1, globalAccountName, "foo", time.Second) + } + + user = _EMPTY_ + opts = []nats.Option{nats.Name("sub1")} + if mainTest.fromHub { + // s1 is the hub, so we need user "a" + user = "a" + } else if mainTest.leafUseLocalAcc { + // s1 is the leaf, we need user "b" if leafUseLocalAcc + user = "b" + } + if user != _EMPTY_ { + opts = append(opts, nats.UserInfo(user, "pwd")) + } + nc1 := natsConnect(t, s1.ClientURL(), opts...) 
+ defer nc1.Close() + sub1 := natsSubSync(t, nc1, "foo") + + opts = []nats.Option{nats.Name("tracer")} + if user != _EMPTY_ { + opts = append(opts, nats.UserInfo(user, "pwd")) + } + nct := natsConnect(t, s1.ClientURL(), opts...) + defer nct.Close() + traceSub := natsSubSync(t, nct, "my.trace.subj") + + for _, test := range []struct { + name string + deliverMsg bool + }{ + {"just trace", false}, + {"deliver msg", true}, + } { + t.Run(test.name, func(t *testing.T) { + msg := nats.NewMsg("foo") + msg.Header.Set(MsgTraceSendTo, traceSub.Subject) + if !test.deliverMsg { + msg.Header.Set(MsgTraceOnly, "true") + } + msg.Data = []byte("hello!") + err := nct.PublishMsg(msg) + require_NoError(t, err) + + checkAppMsg := func(sub *nats.Subscription, expected bool) { + if expected { + appMsg := natsNexMsg(t, sub, time.Second) + require_Equal[string](t, string(appMsg.Data), "hello!") + } + // Check that no (more) messages are received. + if msg, err := sub.NextMsg(100 * time.Millisecond); err != nats.ErrTimeout { + t.Fatalf("Did not expect application message, got %s", msg.Data) + } + } + for _, sub := range []*nats.Subscription{sub1, sub2} { + checkAppMsg(sub, test.deliverMsg) + } + + check := func() { + traceMsg := natsNexMsg(t, traceSub, time.Second) + var e MsgTraceEvent + json.Unmarshal(traceMsg.Data, &e) + ingress := e.Ingress() + require_True(t, ingress != nil) + switch ingress.Kind { + case CLIENT: + require_Equal[string](t, e.Server.Name, s1.Name()) + egress := e.Egresses() + require_Equal[int](t, len(egress), 2) + for _, eg := range egress { + switch eg.Kind { + case CLIENT: + require_Equal[string](t, eg.Name, "sub1") + case LEAF: + require_Equal[string](t, eg.Name, s2.Name()) + require_Equal[string](t, eg.Error, _EMPTY_) + default: + t.Fatalf("Unexpected egress: %+v", eg) + } + } + case LEAF: + require_Equal[string](t, e.Server.Name, s2.Name()) + require_True(t, ingress.Kind == LEAF) + require_Equal(t, ingress.Name, s1.Name()) + egress := e.Egresses() + require_Equal[int](t, len(egress), 1) + eg := egress[0] + require_True(t, eg.Kind == CLIENT) + require_Equal[string](t, eg.Name, "sub2") + default: + t.Fatalf("Unexpected ingress: %+v", ingress) + } + } + check() + check() + // Make sure we are not receiving more traces + if tm, err := traceSub.NextMsg(250 * time.Millisecond); err == nil { + t.Fatalf("Should not have received trace message: %s", tm.Data) + } + }) + } + }) + } +} + +func TestMsgTraceWithLeafNodeToOldServer(t *testing.T) { + msgTraceCheckSupport = true + defer func() { msgTraceCheckSupport = false }() + for _, mainTest := range []struct { + name string + fromHub bool + }{ + {"from hub", true}, + {"from leaf", false}, + } { + t.Run(mainTest.name, func(t *testing.T) { + confHub := createConfFile(t, []byte(` + port: -1 + server_name: "A" + leafnodes { + port: -1 + } + `)) + ohub := LoadConfig(confHub) + if !mainTest.fromHub { + // Make this server behave like an older server + ohub.overrideProto = setServerProtoForTest(MsgTraceProto - 1) + } + hub := RunServer(ohub) + defer hub.Shutdown() + + confLeaf := createConfFile(t, []byte(fmt.Sprintf(` + port: -1 + server_name: "B" + leafnodes { + remotes [{url: "nats://127.0.0.1:%d"}] + } + `, ohub.LeafNode.Port))) + oleaf := LoadConfig(confLeaf) + if mainTest.fromHub { + // Make this server behave like an older server + oleaf.overrideProto = setServerProtoForTest(MsgTraceProto - 1) + } + leaf := RunServer(oleaf) + defer leaf.Shutdown() + + checkLeafNodeConnected(t, hub) + checkLeafNodeConnected(t, leaf) + + var s1, s2 *Server + if 
mainTest.fromHub { + s1, s2 = hub, leaf + } else { + s1, s2 = leaf, hub + } + + // Now create subscriptions on both s1 and s2 + nc2 := natsConnect(t, s2.ClientURL(), nats.Name("sub2")) + defer nc2.Close() + sub2 := natsSubSync(t, nc2, "foo") + + checkSubInterest(t, s1, globalAccountName, "foo", time.Second) + + nc1 := natsConnect(t, s1.ClientURL(), nats.Name("sub1")) + defer nc1.Close() + sub1 := natsSubSync(t, nc1, "foo") + + nct := natsConnect(t, s1.ClientURL(), nats.Name("tracer")) + defer nct.Close() + traceSub := natsSubSync(t, nct, "my.trace.subj") + + for _, test := range []struct { + name string + deliverMsg bool + }{ + {"just trace", false}, + {"deliver msg", true}, + } { + t.Run(test.name, func(t *testing.T) { + msg := nats.NewMsg("foo") + msg.Header.Set(MsgTraceSendTo, traceSub.Subject) + if !test.deliverMsg { + msg.Header.Set(MsgTraceOnly, "true") + } + msg.Data = []byte("hello!") + err := nct.PublishMsg(msg) + require_NoError(t, err) + + checkAppMsg := func(sub *nats.Subscription, expected bool) { + if expected { + appMsg := natsNexMsg(t, sub, time.Second) + require_Equal[string](t, string(appMsg.Data), "hello!") + } + // Check that no (more) messages are received. + if msg, err := sub.NextMsg(100 * time.Millisecond); err != nats.ErrTimeout { + t.Fatalf("Did not expect application message, got %s", msg.Data) + } + } + // Even if a server does not support tracing, as long as the header + // TraceOnly is not set, the message should be forwarded to the remote. + for _, sub := range []*nats.Subscription{sub1, sub2} { + checkAppMsg(sub, test.deliverMsg) + } + + traceMsg := natsNexMsg(t, traceSub, time.Second) + var e MsgTraceEvent + json.Unmarshal(traceMsg.Data, &e) + ingress := e.Ingress() + require_True(t, ingress != nil) + require_True(t, ingress.Kind == CLIENT) + require_Equal[string](t, e.Server.Name, s1.Name()) + egress := e.Egresses() + require_Equal[int](t, len(egress), 2) + for _, ci := range egress { + switch ci.Kind { + case CLIENT: + require_Equal[string](t, ci.Name, "sub1") + case LEAF: + require_Equal[string](t, ci.Name, s2.Name()) + if test.deliverMsg { + require_Contains(t, ci.Error, errMsgTraceNoSupport) + } else { + require_Contains(t, ci.Error, errMsgTraceOnlyNoSupport) + } + default: + t.Fatalf("Unexpected egress: %+v", ci) + } + } + // We should not get a second trace + if msg, err := traceSub.NextMsg(100 * time.Millisecond); err != nats.ErrTimeout { + t.Fatalf("Did not expect other trace, got %s", msg.Data) + } + }) + } + }) + } +} + +func TestMsgTraceWithLeafNodeDaisyChain(t *testing.T) { + confHub := createConfFile(t, []byte(` + port: -1 + server_name: "A" + accounts { + A { users: [{user: "a", password: "pwd"}]} + } + leafnodes { + port: -1 + } + `)) + hub, ohub := RunServerWithConfig(confHub) + defer hub.Shutdown() + + confLeaf1 := createConfFile(t, []byte(fmt.Sprintf(` + port: -1 + server_name: "B" + accounts { + B { users: [{user: "b", password: "pwd"}]} + } + leafnodes { + port: -1 + remotes [{url: "nats://a:pwd@127.0.0.1:%d", account: B}] + } + `, ohub.LeafNode.Port))) + leaf1, oleaf1 := RunServerWithConfig(confLeaf1) + defer leaf1.Shutdown() + + confLeaf2 := createConfFile(t, []byte(fmt.Sprintf(` + port: -1 + server_name: "C" + accounts { + C { users: [{user: "c", password: "pwd"}]} + } + leafnodes { + remotes [{url: "nats://b:pwd@127.0.0.1:%d", account: C}] + } + `, oleaf1.LeafNode.Port))) + leaf2, _ := RunServerWithConfig(confLeaf2) + defer leaf2.Shutdown() + + checkLeafNodeConnected(t, hub) + checkLeafNodeConnectedCount(t, leaf1, 2) + 
checkLeafNodeConnected(t, leaf2)
+
+ nct := natsConnect(t, hub.ClientURL(), nats.UserInfo("a", "pwd"), nats.Name("Tracer"))
+ defer nct.Close()
+ traceSub := natsSubSync(t, nct, "my.trace.subj")
+ natsFlush(t, nct)
+ // Make sure that subject interest travels down to leaf2
+ checkSubInterest(t, leaf2, "C", traceSub.Subject, time.Second)
+
+ nc1 := natsConnect(t, leaf1.ClientURL(), nats.UserInfo("b", "pwd"), nats.Name("sub1"))
+ defer nc1.Close()
+
+ nc2 := natsConnect(t, leaf2.ClientURL(), nats.UserInfo("c", "pwd"), nats.Name("sub2"))
+ defer nc2.Close()
+ sub2 := natsQueueSubSync(t, nc2, "foo.bar", "my_queue")
+ natsFlush(t, nc2)
+
+ // Check that the subject interest makes it to leaf1
+ checkSubInterest(t, leaf1, "B", "foo.bar", time.Second)
+
+ // Now create the sub on leaf1
+ sub1 := natsSubSync(t, nc1, "foo.*")
+ natsFlush(t, nc1)
+
+ // Check that the subject interest is registered on "hub"
+ checkSubInterest(t, hub, "A", "foo.bar", time.Second)
+
+ for _, test := range []struct {
+ name string
+ deliverMsg bool
+ }{
+ {"just trace", false},
+ {"deliver msg", true},
+ } {
+ t.Run(test.name, func(t *testing.T) {
+ msg := nats.NewMsg("foo.bar")
+ msg.Header.Set(MsgTraceSendTo, traceSub.Subject)
+ if !test.deliverMsg {
+ msg.Header.Set(MsgTraceOnly, "true")
+ }
+ msg.Data = []byte("hello!")
+ err := nct.PublishMsg(msg)
+ require_NoError(t, err)
+
+ checkAppMsg := func(sub *nats.Subscription, expected bool) {
+ if expected {
+ appMsg := natsNexMsg(t, sub, time.Second)
+ require_Equal[string](t, string(appMsg.Data), "hello!")
+ }
+ // Check that no (more) messages are received.
+ if msg, err := sub.NextMsg(100 * time.Millisecond); err != nats.ErrTimeout {
+ t.Fatalf("Did not expect application message, got %s", msg.Data)
+ }
+ }
+ for _, sub := range []*nats.Subscription{sub1, sub2} {
+ checkAppMsg(sub, test.deliverMsg)
+ }
+
+ check := func() {
+ traceMsg := natsNexMsg(t, traceSub, time.Second)
+ var e MsgTraceEvent
+ json.Unmarshal(traceMsg.Data, &e)
+
+ ingress := e.Ingress()
+ require_True(t, ingress != nil)
+ switch ingress.Kind {
+ case CLIENT:
+ require_Equal[string](t, e.Server.Name, hub.Name())
+ require_Equal[string](t, ingress.Name, "Tracer")
+ require_Equal[string](t, ingress.Account, "A")
+ require_Equal[string](t, ingress.Subject, "foo.bar")
+ egress := e.Egresses()
+ require_Equal[int](t, len(egress), 1)
+ eg := egress[0]
+ require_True(t, eg.Kind == LEAF)
+ require_Equal[string](t, eg.Name, leaf1.Name())
+ require_Equal[string](t, eg.Account, _EMPTY_)
+ require_Equal[string](t, eg.Subscription, _EMPTY_)
+ case LEAF:
+ switch e.Server.Name {
+ case leaf1.Name():
+ require_Equal(t, ingress.Name, hub.Name())
+ require_Equal(t, ingress.Account, "B")
+ require_Equal[string](t, ingress.Subject, "foo.bar")
+ egress := e.Egresses()
+ require_Equal[int](t, len(egress), 2)
+ for _, eg := range egress {
+ switch eg.Kind {
+ case CLIENT:
+ require_Equal[string](t, eg.Name, "sub1")
+ require_Equal[string](t, eg.Subscription, "foo.*")
+ require_Equal[string](t, eg.Queue, _EMPTY_)
+ case LEAF:
+ require_Equal[string](t, eg.Name, leaf2.Name())
+ require_Equal[string](t, eg.Account, _EMPTY_)
+ require_Equal[string](t, eg.Subscription, _EMPTY_)
+ require_Equal[string](t, eg.Queue, _EMPTY_)
+ default:
+ t.Fatalf("Unexpected egress: %+v", eg)
+ }
+ }
+ case leaf2.Name():
+ require_Equal(t, ingress.Name, leaf1.Name())
+ require_Equal(t, ingress.Account, "C")
+ require_Equal[string](t, ingress.Subject, "foo.bar")
+ egress := e.Egresses()
+ require_Equal[int](t, len(egress), 1)
+ eg := egress[0]
+ 
require_True(t, eg.Kind == CLIENT) + require_Equal[string](t, eg.Name, "sub2") + require_Equal[string](t, eg.Subscription, "foo.bar") + require_Equal[string](t, eg.Queue, "my_queue") + default: + t.Fatalf("Unexpected ingress: %+v", ingress) + } + default: + t.Fatalf("Unexpected ingress: %+v", ingress) + } + } + check() + check() + check() + // Make sure we are not receiving more traces + if tm, err := traceSub.NextMsg(250 * time.Millisecond); err == nil { + t.Fatalf("Should not have received trace message: %s", tm.Data) + } + }) + } +} + +func TestMsgTraceWithGateways(t *testing.T) { + o2 := testDefaultOptionsForGateway("B") + o2.NoSystemAccount = false + s2 := runGatewayServer(o2) + defer s2.Shutdown() + + o1 := testGatewayOptionsFromToWithServers(t, "A", "B", s2) + o1.NoSystemAccount = false + s1 := runGatewayServer(o1) + defer s1.Shutdown() + + waitForOutboundGateways(t, s1, 1, time.Second) + waitForInboundGateways(t, s2, 1, time.Second) + waitForOutboundGateways(t, s2, 1, time.Second) + + nc2 := natsConnect(t, s2.ClientURL(), nats.Name("sub2")) + defer nc2.Close() + sub2 := natsQueueSubSync(t, nc2, "foo.*", "my_queue") + + nc3 := natsConnect(t, s2.ClientURL(), nats.Name("sub3")) + defer nc3.Close() + sub3 := natsQueueSubSync(t, nc3, "*.*", "my_queue_2") + + nc1 := natsConnect(t, s1.ClientURL(), nats.Name("sub1")) + defer nc1.Close() + sub1 := natsSubSync(t, nc1, "*.bar") + + nct := natsConnect(t, s1.ClientURL(), nats.Name("tracer")) + defer nct.Close() + traceSub := natsSubSync(t, nct, "my.trace.subj") + + for _, test := range []struct { + name string + deliverMsg bool + }{ + {"just trace", false}, + {"deliver msg", true}, + } { + t.Run(test.name, func(t *testing.T) { + msg := nats.NewMsg("foo.bar") + msg.Header.Set(MsgTraceSendTo, traceSub.Subject) + if !test.deliverMsg { + msg.Header.Set(MsgTraceOnly, "true") + } + msg.Data = []byte("hello!") + err := nct.PublishMsg(msg) + require_NoError(t, err) + + checkAppMsg := func(sub *nats.Subscription, expected bool) { + if expected { + appMsg := natsNexMsg(t, sub, time.Second) + require_Equal[string](t, string(appMsg.Data), "hello!") + } + // Check that no (more) messages are received. 
+ if msg, err := sub.NextMsg(100 * time.Millisecond); err != nats.ErrTimeout { + t.Fatalf("Did not expect application message, got %s", msg.Data) + } + } + for _, sub := range []*nats.Subscription{sub1, sub2, sub3} { + checkAppMsg(sub, test.deliverMsg) + } + + check := func() { + traceMsg := natsNexMsg(t, traceSub, time.Second) + var e MsgTraceEvent + json.Unmarshal(traceMsg.Data, &e) + + ingress := e.Ingress() + require_True(t, ingress != nil) + switch ingress.Kind { + case CLIENT: + require_Equal[string](t, e.Server.Name, s1.Name()) + require_Equal[string](t, ingress.Account, globalAccountName) + require_Equal[string](t, ingress.Subject, "foo.bar") + egress := e.Egresses() + require_Equal[int](t, len(egress), 2) + for _, eg := range egress { + switch eg.Kind { + case CLIENT: + require_Equal[string](t, eg.Name, "sub1") + require_Equal[string](t, eg.Subscription, "*.bar") + require_Equal[string](t, eg.Queue, _EMPTY_) + case GATEWAY: + require_Equal[string](t, eg.Name, s2.Name()) + require_Equal[string](t, eg.Error, _EMPTY_) + require_Equal[string](t, eg.Subscription, _EMPTY_) + require_Equal[string](t, eg.Queue, _EMPTY_) + default: + t.Fatalf("Unexpected egress: %+v", eg) + } + } + case GATEWAY: + require_Equal[string](t, e.Server.Name, s2.Name()) + require_Equal[string](t, ingress.Account, globalAccountName) + require_Equal[string](t, ingress.Subject, "foo.bar") + egress := e.Egresses() + require_Equal[int](t, len(egress), 2) + var gotSub2, gotSub3 int + for _, eg := range egress { + require_True(t, eg.Kind == CLIENT) + switch eg.Name { + case "sub2": + require_Equal[string](t, eg.Subscription, "foo.*") + require_Equal[string](t, eg.Queue, "my_queue") + gotSub2++ + case "sub3": + require_Equal[string](t, eg.Subscription, "*.*") + require_Equal[string](t, eg.Queue, "my_queue_2") + gotSub3++ + default: + t.Fatalf("Unexpected egress name: %+v", eg) + } + } + require_Equal[int](t, gotSub2, 1) + require_Equal[int](t, gotSub3, 1) + + default: + t.Fatalf("Unexpected ingress: %+v", ingress) + } + } + // We should get 2 events + check() + check() + // Make sure we are not receiving more traces + if tm, err := traceSub.NextMsg(250 * time.Millisecond); err == nil { + t.Fatalf("Should not have received trace message: %s", tm.Data) + } + }) + } +} + +func TestMsgTraceWithGatewayToOldServer(t *testing.T) { + msgTraceCheckSupport = true + defer func() { msgTraceCheckSupport = false }() + + o2 := testDefaultOptionsForGateway("B") + o2.NoSystemAccount = false + // Make this server behave like an older server + o2.overrideProto = setServerProtoForTest(MsgTraceProto - 1) + s2 := runGatewayServer(o2) + defer s2.Shutdown() + + o1 := testGatewayOptionsFromToWithServers(t, "A", "B", s2) + o1.NoSystemAccount = false + s1 := runGatewayServer(o1) + defer s1.Shutdown() + + waitForOutboundGateways(t, s1, 1, time.Second) + waitForInboundGateways(t, s2, 1, time.Second) + waitForOutboundGateways(t, s2, 1, time.Second) + + nc2 := natsConnect(t, s2.ClientURL(), nats.Name("sub2")) + defer nc2.Close() + sub2 := natsSubSync(t, nc2, "foo") + + nc1 := natsConnect(t, s1.ClientURL(), nats.Name("sub1")) + defer nc1.Close() + sub1 := natsSubSync(t, nc1, "foo") + + nct := natsConnect(t, s1.ClientURL(), nats.Name("tracer")) + defer nct.Close() + traceSub := natsSubSync(t, nct, "my.trace.subj") + + for _, test := range []struct { + name string + deliverMsg bool + }{ + {"just trace", false}, + {"deliver msg", true}, + } { + t.Run(test.name, func(t *testing.T) { + msg := nats.NewMsg("foo") + msg.Header.Set(MsgTraceSendTo, 
traceSub.Subject) + if !test.deliverMsg { + msg.Header.Set(MsgTraceOnly, "true") + } + msg.Data = []byte("hello!") + err := nct.PublishMsg(msg) + require_NoError(t, err) + + checkAppMsg := func(sub *nats.Subscription, expected bool) { + if expected { + appMsg := natsNexMsg(t, sub, time.Second) + require_Equal[string](t, string(appMsg.Data), "hello!") + } + // Check that no (more) messages are received. + if msg, err := sub.NextMsg(100 * time.Millisecond); err != nats.ErrTimeout { + t.Fatalf("Did not expect application message, got %s", msg.Data) + } + } + // Even if a server does not support tracing, as long as the header + // TraceOnly is not set, the message should be forwarded to the remote. + for _, sub := range []*nats.Subscription{sub1, sub2} { + checkAppMsg(sub, test.deliverMsg) + } + + traceMsg := natsNexMsg(t, traceSub, time.Second) + var e MsgTraceEvent + json.Unmarshal(traceMsg.Data, &e) + ingress := e.Ingress() + require_True(t, ingress != nil) + switch ingress.Kind { + case CLIENT: + require_Equal[string](t, e.Server.Name, s1.Name()) + egress := e.Egresses() + require_Equal[int](t, len(egress), 2) + for _, ci := range egress { + switch ci.Kind { + case CLIENT: + require_Equal[string](t, ci.Name, "sub1") + case GATEWAY: + require_Equal[string](t, ci.Name, s2.Name()) + if test.deliverMsg { + require_Contains(t, ci.Error, errMsgTraceNoSupport) + } else { + require_Contains(t, ci.Error, errMsgTraceOnlyNoSupport) + } + default: + t.Fatalf("Unexpected egress: %+v", ci) + } + } + default: + t.Fatalf("Unexpected ingress: %+v", ingress) + } + // We should not get a second trace + if msg, err := traceSub.NextMsg(100 * time.Millisecond); err != nats.ErrTimeout { + t.Fatalf("Did not expect other trace, got %s", msg.Data) + } + }) + } +} + +func TestMsgTraceServiceImport(t *testing.T) { + for _, mainTest := range []struct { + name string + allow bool + }{ + {"allowed", true}, + {"not allowed", false}, + } { + t.Run(mainTest.name, func(t *testing.T) { + conf := createConfFile(t, []byte(fmt.Sprintf(` + listen: 127.0.0.1:-1 + accounts { + A { + users: [{user: a, password: pwd}] + exports: [ { service: ">", allow_trace: %v} ] + mappings = { + bar: bozo + } + } + B { + users: [{user: b, password: pwd}] + imports: [ { service: {account: "A", subject:">"} } ] + exports: [ { service: ">", allow_trace: %v} ] + } + C { + users: [{user: c, password: pwd}] + exports: [ { service: ">", allow_trace: %v } ] + } + D { + users: [{user: d, password: pwd}] + imports: [ + { service: {account: "B", subject:"bar"}, to: baz } + { service: {account: "C", subject:">"} } + ] + mappings = { + bat: baz + } + } + } + `, mainTest.allow, mainTest.allow, mainTest.allow))) + s, _ := RunServerWithConfig(conf) + defer s.Shutdown() + + nc := natsConnect(t, s.ClientURL(), nats.UserInfo("d", "pwd"), nats.Name("Requestor")) + defer nc.Close() + + traceSub := natsSubSync(t, nc, "my.trace.subj") + sub := natsSubSync(t, nc, "my.service.response.inbox") + + nc2 := natsConnect(t, s.ClientURL(), nats.UserInfo("a", "pwd"), nats.Name("ServiceA")) + defer nc2.Close() + recv := int32(0) + natsQueueSub(t, nc2, "*", "my_queue", func(m *nats.Msg) { + atomic.AddInt32(&recv, 1) + m.Respond(m.Data) + }) + natsFlush(t, nc2) + + nc3 := natsConnect(t, s.ClientURL(), nats.UserInfo("c", "pwd"), nats.Name("ServiceC")) + defer nc3.Close() + natsSub(t, nc3, "baz", func(m *nats.Msg) { + atomic.AddInt32(&recv, 1) + m.Respond(m.Data) + }) + natsFlush(t, nc3) + + for _, test := range []struct { + name string + deliverMsg bool + }{ + {"just trace", false}, 
+ {"deliver msg", true}, + } { + t.Run(test.name, func(t *testing.T) { + msg := nats.NewMsg("bat") + msg.Header.Set(MsgTraceSendTo, traceSub.Subject) + if !test.deliverMsg { + msg.Header.Set(MsgTraceOnly, "true") + } + if !test.deliverMsg { + msg.Data = []byte("request1") + } else { + msg.Data = []byte("request2") + } + msg.Reply = sub.Subject + + err := nc.PublishMsg(msg) + require_NoError(t, err) + + if test.deliverMsg { + for i := 0; i < 2; i++ { + appMsg := natsNexMsg(t, sub, time.Second) + require_Equal[string](t, string(appMsg.Data), "request2") + } + } + // Check that no (more) messages are received. + if msg, err := sub.NextMsg(100 * time.Millisecond); msg != nil || err != nats.ErrTimeout { + t.Fatalf("Did not expect application message, got msg=%v err=%v", msg, err) + } + if !test.deliverMsg { + // Just to make sure that message was not delivered to service + // responders, wait a bit and check the recv value. + time.Sleep(50 * time.Millisecond) + if n := atomic.LoadInt32(&recv); n != 0 { + t.Fatalf("Expected no message to be received, but service callback fired %d times", n) + } + } + + traceMsg := natsNexMsg(t, traceSub, time.Second) + var e MsgTraceEvent + json.Unmarshal(traceMsg.Data, &e) + + require_Equal[string](t, e.Server.Name, s.Name()) + ingress := e.Ingress() + require_True(t, ingress != nil) + require_True(t, ingress.Kind == CLIENT) + require_Equal[string](t, ingress.Account, "D") + require_Equal[string](t, ingress.Subject, "bat") + sm := e.SubjectMapping() + require_True(t, sm != nil) + require_Equal[string](t, sm.MappedTo, "baz") + simps := e.ServiceImports() + require_True(t, simps != nil) + var expectedServices int + if mainTest.allow { + expectedServices = 3 + } else { + expectedServices = 2 + } + require_Equal[int](t, len(simps), expectedServices) + for _, si := range simps { + require_True(t, si.Timestamp != time.Time{}) + switch si.Account { + case "C": + require_Equal[string](t, si.From, "baz") + require_Equal[string](t, si.To, "baz") + case "B": + require_Equal[string](t, si.From, "baz") + require_Equal[string](t, si.To, "bar") + case "A": + if !mainTest.allow { + t.Fatalf("Without allow_trace, we should not see service for account A") + } + require_Equal[string](t, si.From, "bar") + require_Equal[string](t, si.To, "bozo") + default: + t.Fatalf("Unexpected account name: %s", si.Account) + } + } + egress := e.Egresses() + if !mainTest.allow { + require_Equal[int](t, len(egress), 0) + } else { + require_Equal[int](t, len(egress), 2) + var gotA, gotC bool + for _, eg := range egress { + // All Egress should be clients + require_True(t, eg.Kind == CLIENT) + // We should have one for ServiceA and one for ServiceC + if eg.Name == "ServiceA" { + require_Equal[string](t, eg.Account, "A") + require_Equal[string](t, eg.Subscription, "*") + require_Equal[string](t, eg.Queue, "my_queue") + gotA = true + } else if eg.Name == "ServiceC" { + require_Equal[string](t, eg.Account, "C") + require_Equal[string](t, eg.Queue, _EMPTY_) + gotC = true + } + } + if !gotA { + t.Fatalf("Did not get Egress for serviceA: %+v", egress) + } + if !gotC { + t.Fatalf("Did not get Egress for serviceC: %+v", egress) + } + } + + // Make sure we properly remove the responses. 
+ checkResp := func(an string) { + acc, err := s.lookupAccount(an) + require_NoError(t, err) + checkFor(t, time.Second, 15*time.Millisecond, func() error { + if n := acc.NumPendingAllResponses(); n != 0 { + return fmt.Errorf("Still %d responses pending for account %q on server %s", n, acc, s) + } + return nil + }) + } + for _, acc := range []string{"A", "B", "C", "D"} { + checkResp(acc) + } + }) + } + }) + } +} + +func TestMsgTraceServiceImportWithSuperCluster(t *testing.T) { + for _, mainTest := range []struct { + name string + allowStr string + allow bool + }{ + {"allowed", "true", true}, + {"not allowed", "false", false}, + } { + t.Run(mainTest.name, func(t *testing.T) { + tmpl := ` + listen: 127.0.0.1:-1 + server_name: %s + jetstream: {max_mem_store: 256MB, max_file_store: 2GB, store_dir: '%s'} + + cluster { + name: %s + listen: 127.0.0.1:%d + routes = [%s] + } + accounts { + A { + users: [{user: a, password: pwd}] + exports: [ { service: ">", allow_trace: ` + mainTest.allowStr + ` } ] + mappings = { + bar: bozo + } + } + B { + users: [{user: b, password: pwd}] + imports: [ { service: {account: "A", subject:">"} } ] + exports: [ { service: ">" , allow_trace: ` + mainTest.allowStr + ` } ] + } + C { + users: [{user: c, password: pwd}] + exports: [ { service: ">" , allow_trace: ` + mainTest.allowStr + ` } ] + } + D { + users: [{user: d, password: pwd}] + imports: [ + { service: {account: "B", subject:"bar"}, to: baz } + { service: {account: "C", subject:">"} } + ] + mappings = { + bat: baz + } + } + $SYS { users = [ { user: "admin", pass: "s3cr3t!" } ] } + } + ` + sc := createJetStreamSuperClusterWithTemplate(t, tmpl, 3, 2) + defer sc.shutdown() + + sfornc := sc.clusters[0].servers[0] + nc := natsConnect(t, sfornc.ClientURL(), nats.UserInfo("d", "pwd"), nats.Name("Requestor")) + defer nc.Close() + + traceSub := natsSubSync(t, nc, "my.trace.subj") + sub := natsSubSync(t, nc, "my.service.response.inbox") + + sfornc2 := sc.clusters[0].servers[1] + nc2 := natsConnect(t, sfornc2.ClientURL(), nats.UserInfo("a", "pwd"), nats.Name("ServiceA")) + defer nc2.Close() + subSvcA := natsQueueSubSync(t, nc2, "*", "my_queue") + natsFlush(t, nc2) + + sfornc3 := sc.clusters[1].servers[0] + nc3 := natsConnect(t, sfornc3.ClientURL(), nats.UserInfo("c", "pwd"), nats.Name("ServiceC")) + defer nc3.Close() + subSvcC := natsSubSync(t, nc3, "baz") + natsFlush(t, nc3) + + for _, test := range []struct { + name string + deliverMsg bool + }{ + {"just trace", false}, + {"deliver msg", true}, + } { + t.Run(test.name, func(t *testing.T) { + msg := nats.NewMsg("bat") + msg.Header.Set(MsgTraceSendTo, traceSub.Subject) + if !test.deliverMsg { + msg.Header.Set(MsgTraceOnly, "true") + } + if !test.deliverMsg { + msg.Data = []byte("request1") + } else { + msg.Data = []byte("request2") + } + msg.Reply = sub.Subject + + err := nc.PublishMsg(msg) + require_NoError(t, err) + + if test.deliverMsg { + processSvc := func(sub *nats.Subscription) { + t.Helper() + appMsg := natsNexMsg(t, sub, time.Second) + // This test causes a message to be routed to the + // service responders. When not allowing, we need + // to make sure that the trace header has been + // disabled. Not receiving the trace event from + // the remote is not enough to verify since the + // trace would not reach the origin server because + // the origin account header will not be present. 
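+ // (when not allowed, the header is deactivated rather than removed: its
+ // first letter is rewritten to 'X' so the value is preserved but tracing
+ // no longer triggers downstream)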
+ if mainTest.allow { + if hv := appMsg.Header.Get(MsgTraceSendTo); hv != traceSub.Subject { + t.Fatalf("Expecting header with %q, but got %q", traceSub.Subject, hv) + } + } else { + if hv := appMsg.Header.Get(MsgTraceSendTo); hv != _EMPTY_ { + t.Fatalf("Expecting no header, but header was present with value: %q", hv) + } + // We don't really need to check that, but we + // should see the header with the first letter + // being an `X`. + hnb := []byte(MsgTraceSendTo) + hnb[0] = 'X' + hn := string(hnb) + if hv := appMsg.Header.Get(hn); hv != traceSub.Subject { + t.Fatalf("Expected header %q to be %q, got %q", hn, traceSub.Subject, hv) + } + } + appMsg.Respond(appMsg.Data) + } + processSvc(subSvcA) + processSvc(subSvcC) + + for i := 0; i < 2; i++ { + appMsg := natsNexMsg(t, sub, time.Second) + require_Equal[string](t, string(appMsg.Data), "request2") + } + } + // Check that no (more) messages are received. + if msg, err := sub.NextMsg(100 * time.Millisecond); msg != nil || err != nats.ErrTimeout { + t.Fatalf("Did not expect application message, got msg=%v err=%v", msg, err) + } + if !test.deliverMsg { + // Just to make sure that message was not delivered to service + // responders, wait a bit and check the recv value. + time.Sleep(50 * time.Millisecond) + for _, sub := range []*nats.Subscription{subSvcA, subSvcC} { + if msg, err := sub.NextMsg(250 * time.Millisecond); err == nil { + t.Fatalf("Expected no message to be received, but service subscription got %s", msg.Data) + } + } + } + + check := func() { + traceMsg := natsNexMsg(t, traceSub, time.Second) + var e MsgTraceEvent + json.Unmarshal(traceMsg.Data, &e) + + ingress := e.Ingress() + require_True(t, ingress != nil) + switch ingress.Kind { + case CLIENT: + require_Equal[string](t, e.Server.Name, sfornc.Name()) + require_Equal[string](t, ingress.Account, "D") + require_Equal[string](t, ingress.Subject, "bat") + sm := e.SubjectMapping() + require_True(t, sm != nil) + require_Equal[string](t, sm.MappedTo, "baz") + simps := e.ServiceImports() + require_True(t, simps != nil) + var expectedServices int + if mainTest.allow { + expectedServices = 3 + } else { + expectedServices = 2 + } + require_Equal[int](t, len(simps), expectedServices) + for _, si := range simps { + switch si.Account { + case "C": + require_Equal[string](t, si.From, "baz") + require_Equal[string](t, si.To, "baz") + case "B": + require_Equal[string](t, si.From, "baz") + require_Equal[string](t, si.To, "bar") + case "A": + if !mainTest.allow { + t.Fatalf("Without allow_trace, we should not see service for account A") + } + require_Equal[string](t, si.From, "bar") + require_Equal[string](t, si.To, "bozo") + default: + t.Fatalf("Unexpected account name: %s", si.Account) + } + } + egress := e.Egresses() + if !mainTest.allow { + require_Equal[int](t, len(egress), 0) + } else { + require_Equal[int](t, len(egress), 2) + for _, eg := range egress { + switch eg.Kind { + case ROUTER: + require_Equal[string](t, eg.Name, sfornc2.Name()) + require_Equal[string](t, eg.Account, _EMPTY_) + case GATEWAY: + require_Equal[string](t, eg.Name, sfornc3.Name()) + require_Equal[string](t, eg.Account, _EMPTY_) + } + } + } + case ROUTER: + require_Equal[string](t, e.Server.Name, sfornc2.Name()) + require_Equal[string](t, ingress.Account, "A") + require_Equal[string](t, ingress.Subject, "bozo") + egress := e.Egresses() + require_Equal[int](t, len(egress), 1) + eg := egress[0] + require_True(t, eg.Kind == CLIENT) + require_Equal[string](t, eg.Name, "ServiceA") + require_Equal[string](t, eg.Account, 
_EMPTY_) + require_Equal[string](t, eg.Subscription, "*") + require_Equal[string](t, eg.Queue, "my_queue") + case GATEWAY: + require_Equal[string](t, e.Server.Name, sfornc3.Name()) + require_Equal[string](t, ingress.Account, "C") + require_Equal[string](t, ingress.Subject, "baz") + egress := e.Egresses() + require_Equal[int](t, len(egress), 1) + eg := egress[0] + require_True(t, eg.Kind == CLIENT) + require_Equal[string](t, eg.Name, "ServiceC") + require_Equal[string](t, eg.Account, _EMPTY_) + require_Equal[string](t, eg.Subscription, "baz") + require_Equal[string](t, eg.Queue, _EMPTY_) + default: + t.Fatalf("Unexpected ingress: %+v", ingress) + } + } + // We should receive 3 events when allowed, a single when not. + check() + if mainTest.allow { + check() + check() + } + // Make sure we are not receiving more traces + if tm, err := traceSub.NextMsg(250 * time.Millisecond); err == nil { + t.Fatalf("Should not have received trace message: %s", tm.Data) + } + + // Make sure we properly remove the responses. + checkResp := func(an string) { + for _, s := range []*Server{sfornc, sfornc2, sfornc3} { + acc, err := s.lookupAccount(an) + require_NoError(t, err) + checkFor(t, time.Second, 15*time.Millisecond, func() error { + if n := acc.NumPendingAllResponses(); n != 0 { + return fmt.Errorf("Still %d responses pending for account %q on server %s", n, acc, s) + } + return nil + }) + } + } + for _, acc := range []string{"A", "B", "C", "D"} { + checkResp(acc) + } + }) + } + }) + } +} + +func TestMsgTraceServiceImportWithLeafNodeHub(t *testing.T) { + confHub := createConfFile(t, []byte(` + listen: 127.0.0.1:-1 + server_name: "S1" + accounts { + A { + users: [{user: a, password: pwd}] + exports: [ { service: ">", allow_trace: true } ] + mappings = { + bar: bozo + } + } + B { + users: [{user: b, password: pwd}] + imports: [ { service: {account: "A", subject:">"} } ] + exports: [ { service: ">", allow_trace: true } ] + } + C { + users: [{user: c, password: pwd}] + exports: [ { service: ">", allow_trace: true } ] + } + D { + users: [{user: d, password: pwd}] + imports: [ + { service: {account: "B", subject:"bar"}, to: baz } + { service: {account: "C", subject:">"} } + ] + mappings = { + bat: baz + } + } + $SYS { users = [ { user: "admin", pass: "s3cr3t!" 
} ] } + } + leafnodes { + port: -1 + } + `)) + hub, ohub := RunServerWithConfig(confHub) + defer hub.Shutdown() + + confLeaf := createConfFile(t, []byte(fmt.Sprintf(` + listen: 127.0.0.1:-1 + server_name: "S2" + leafnodes { + remotes [{url: "nats://d:pwd@127.0.0.1:%d"}] + } + `, ohub.LeafNode.Port))) + leaf, _ := RunServerWithConfig(confLeaf) + defer leaf.Shutdown() + + checkLeafNodeConnectedCount(t, hub, 1) + checkLeafNodeConnectedCount(t, leaf, 1) + + nc2 := natsConnect(t, hub.ClientURL(), nats.UserInfo("a", "pwd"), nats.Name("ServiceA")) + defer nc2.Close() + recv := int32(0) + natsQueueSub(t, nc2, "*", "my_queue", func(m *nats.Msg) { + atomic.AddInt32(&recv, 1) + m.Respond(m.Data) + }) + natsFlush(t, nc2) + + nc3 := natsConnect(t, hub.ClientURL(), nats.UserInfo("c", "pwd"), nats.Name("ServiceC")) + defer nc3.Close() + natsSub(t, nc3, "baz", func(m *nats.Msg) { + atomic.AddInt32(&recv, 1) + m.Respond(m.Data) + }) + natsFlush(t, nc3) + + nc := natsConnect(t, leaf.ClientURL(), nats.Name("Requestor")) + defer nc.Close() + + traceSub := natsSubSync(t, nc, "my.trace.subj") + sub := natsSubSync(t, nc, "my.service.response.inbox") + + checkSubInterest(t, leaf, globalAccountName, "bat", time.Second) + + for _, test := range []struct { + name string + deliverMsg bool + }{ + {"just trace", false}, + {"deliver msg", true}, + } { + t.Run(test.name, func(t *testing.T) { + msg := nats.NewMsg("bat") + msg.Header.Set(MsgTraceSendTo, traceSub.Subject) + if !test.deliverMsg { + msg.Header.Set(MsgTraceOnly, "true") + } + if !test.deliverMsg { + msg.Data = []byte("request1") + } else { + msg.Data = []byte("request2") + } + msg.Reply = sub.Subject + + err := nc.PublishMsg(msg) + require_NoError(t, err) + + if test.deliverMsg { + for i := 0; i < 2; i++ { + appMsg := natsNexMsg(t, sub, time.Second) + require_Equal[string](t, string(appMsg.Data), "request2") + } + } + // Check that no (more) messages are received. + if msg, err := sub.NextMsg(100 * time.Millisecond); msg != nil || err != nats.ErrTimeout { + t.Fatalf("Did not expect application message, got msg=%v err=%v", msg, err) + } + if !test.deliverMsg { + // Just to make sure that message was not delivered to service + // responders, wait a bit and check the recv value. 
+ time.Sleep(50 * time.Millisecond) + if n := atomic.LoadInt32(&recv); n != 0 { + t.Fatalf("Expected no message to be received, but service callback fired %d times", n) + } + } + + check := func() { + traceMsg := natsNexMsg(t, traceSub, time.Second) + var e MsgTraceEvent + json.Unmarshal(traceMsg.Data, &e) + + ingress := e.Ingress() + require_True(t, ingress != nil) + switch ingress.Kind { + case CLIENT: + require_Equal[string](t, e.Server.Name, "S2") + require_Equal[string](t, ingress.Account, globalAccountName) + require_Equal[string](t, ingress.Subject, "bat") + require_True(t, e.SubjectMapping() == nil) + egress := e.Egresses() + require_Equal[int](t, len(egress), 1) + eg := egress[0] + require_True(t, eg.Kind == LEAF) + require_Equal[string](t, eg.Name, "S1") + require_Equal[string](t, eg.Account, _EMPTY_) + case LEAF: + require_Equal[string](t, e.Server.Name, hub.Name()) + require_Equal[string](t, ingress.Name, leaf.Name()) + require_Equal[string](t, ingress.Account, "D") + require_Equal[string](t, ingress.Subject, "bat") + sm := e.SubjectMapping() + require_True(t, sm != nil) + require_Equal[string](t, sm.MappedTo, "baz") + simps := e.ServiceImports() + require_True(t, simps != nil) + require_Equal[int](t, len(simps), 3) + for _, si := range simps { + switch si.Account { + case "C": + require_Equal[string](t, si.From, "baz") + require_Equal[string](t, si.To, "baz") + case "B": + require_Equal[string](t, si.From, "baz") + require_Equal[string](t, si.To, "bar") + case "A": + require_Equal[string](t, si.From, "bar") + require_Equal[string](t, si.To, "bozo") + default: + t.Fatalf("Unexpected account name: %s", si.Account) + } + } + egress := e.Egresses() + require_Equal[int](t, len(egress), 2) + for _, eg := range egress { + require_True(t, eg.Kind == CLIENT) + switch eg.Account { + case "C": + require_Equal[string](t, eg.Name, "ServiceC") + require_Equal[string](t, eg.Subscription, "baz") + require_Equal[string](t, eg.Queue, _EMPTY_) + case "A": + require_Equal[string](t, eg.Name, "ServiceA") + require_Equal[string](t, eg.Subscription, "*") + require_Equal[string](t, eg.Queue, "my_queue") + default: + t.Fatalf("Unexpected egress: %+v", eg) + } + } + default: + t.Fatalf("Unexpected ingress: %+v", ingress) + } + } + // We should receive 2 events. + check() + check() + // Make sure we are not receiving more traces + if tm, err := traceSub.NextMsg(250 * time.Millisecond); err == nil { + t.Fatalf("Should not have received trace message: %s", tm.Data) + } + + // Make sure we properly remove the responses. + checkResp := func(an string) { + acc, err := hub.lookupAccount(an) + require_NoError(t, err) + checkFor(t, time.Second, 15*time.Millisecond, func() error { + if n := acc.NumPendingAllResponses(); n != 0 { + return fmt.Errorf("Still %d responses for account %q pending on %s", n, an, hub) + } + return nil + }) + } + for _, acc := range []string{"A", "B", "C", "D"} { + checkResp(acc) + } + }) + } +} + +func TestMsgTraceServiceImportWithLeafNodeLeaf(t *testing.T) { + confHub := createConfFile(t, []byte(` + listen: 127.0.0.1:-1 + server_name: "S1" + accounts { + A { + users: [{user: a, password: pwd}] + exports: [ { service: "bar", allow_trace: true } ] + } + B { + users: [{user: b, password: pwd}] + imports: [{ service: {account: "A", subject:"bar"}, to: baz }] + } + $SYS { users = [ { user: "admin", pass: "s3cr3t!" 
} ] } + } + leafnodes { + port: -1 + } + `)) + hub, ohub := RunServerWithConfig(confHub) + defer hub.Shutdown() + + confLeaf := createConfFile(t, []byte(fmt.Sprintf(` + listen: 127.0.0.1:-1 + server_name: "S2" + accounts { + A { + users: [{user: a, password: pwd}] + exports: [ { service: "bar"} ] + } + B { users: [{user: b, password: pwd}] } + $SYS { users = [ { user: "admin", pass: "s3cr3t!" } ] } + } + leafnodes { + remotes [ + { + url: "nats://a:pwd@127.0.0.1:%d" + account: A + } + { + url: "nats://b:pwd@127.0.0.1:%d" + account: B + } + ] + } + `, ohub.LeafNode.Port, ohub.LeafNode.Port))) + leaf, _ := RunServerWithConfig(confLeaf) + defer leaf.Shutdown() + + checkLeafNodeConnectedCount(t, hub, 2) + checkLeafNodeConnectedCount(t, leaf, 2) + + nc2 := natsConnect(t, leaf.ClientURL(), nats.UserInfo("a", "pwd"), nats.Name("ServiceA")) + defer nc2.Close() + recv := int32(0) + natsQueueSub(t, nc2, "*", "my_queue", func(m *nats.Msg) { + atomic.AddInt32(&recv, 1) + m.Respond(m.Data) + }) + natsFlush(t, nc2) + + nc := natsConnect(t, hub.ClientURL(), nats.UserInfo("b", "pwd"), nats.Name("Requestor")) + defer nc.Close() + + traceSub := natsSubSync(t, nc, "my.trace.subj") + sub := natsSubSync(t, nc, "my.service.response.inbox") + + // Check that hub has a subscription interest on "baz" + checkSubInterest(t, hub, "A", "baz", time.Second) + // And check that the leaf has the sub interest on the trace subject + checkSubInterest(t, leaf, "B", traceSub.Subject, time.Second) + + for _, test := range []struct { + name string + deliverMsg bool + }{ + {"just trace", false}, + {"deliver msg", true}, + } { + t.Run(test.name, func(t *testing.T) { + msg := nats.NewMsg("baz") + msg.Header.Set(MsgTraceSendTo, traceSub.Subject) + if !test.deliverMsg { + msg.Header.Set(MsgTraceOnly, "true") + } + if !test.deliverMsg { + msg.Data = []byte("request1") + } else { + msg.Data = []byte("request2") + } + msg.Reply = sub.Subject + + err := nc.PublishMsg(msg) + require_NoError(t, err) + + if test.deliverMsg { + appMsg := natsNexMsg(t, sub, time.Second) + require_Equal[string](t, string(appMsg.Data), "request2") + } + // Check that no (more) messages are received. + if msg, err := sub.NextMsg(100 * time.Millisecond); msg != nil || err != nats.ErrTimeout { + t.Fatalf("Did not expect application message, got msg=%v err=%v", msg, err) + } + if !test.deliverMsg { + // Just to make sure that message was not delivered to service + // responders, wait a bit and check the recv value. 
+ time.Sleep(50 * time.Millisecond) + if n := atomic.LoadInt32(&recv); n != 0 { + t.Fatalf("Expected no message to be received, but service callback fired %d times", n) + } + } + + check := func() { + traceMsg := natsNexMsg(t, traceSub, time.Second) + var e MsgTraceEvent + json.Unmarshal(traceMsg.Data, &e) + + ingress := e.Ingress() + require_True(t, ingress != nil) + + switch ingress.Kind { + case CLIENT: + require_Equal[string](t, e.Server.Name, "S1") + require_Equal[string](t, ingress.Name, "Requestor") + require_Equal[string](t, ingress.Account, "B") + require_Equal[string](t, ingress.Subject, "baz") + require_True(t, e.SubjectMapping() == nil) + simps := e.ServiceImports() + require_True(t, simps != nil) + require_Equal[int](t, len(simps), 1) + si := simps[0] + require_Equal[string](t, si.Account, "A") + require_Equal[string](t, si.From, "baz") + require_Equal[string](t, si.To, "bar") + egress := e.Egresses() + require_Equal[int](t, len(egress), 1) + eg := egress[0] + require_True(t, eg.Kind == LEAF) + require_Equal[string](t, eg.Name, "S2") + require_Equal[string](t, eg.Account, "A") + require_Equal[string](t, eg.Subscription, _EMPTY_) + case LEAF: + require_Equal[string](t, e.Server.Name, leaf.Name()) + require_Equal[string](t, ingress.Name, hub.Name()) + require_Equal[string](t, ingress.Account, "A") + require_Equal[string](t, ingress.Subject, "bar") + require_True(t, e.SubjectMapping() == nil) + require_True(t, e.ServiceImports() == nil) + egress := e.Egresses() + require_Equal[int](t, len(egress), 1) + eg := egress[0] + require_True(t, eg.Kind == CLIENT) + require_Equal[string](t, eg.Name, "ServiceA") + require_Equal[string](t, eg.Subscription, "*") + require_Equal[string](t, eg.Queue, "my_queue") + default: + t.Fatalf("Unexpected ingress: %+v", ingress) + } + } + // We should receive 2 events. + check() + check() + // Make sure we are not receiving more traces + if tm, err := traceSub.NextMsg(250 * time.Millisecond); err == nil { + t.Fatalf("Should not have received trace message: %s", tm.Data) + } + + // Make sure we properly remove the responses. 
+ checkResp := func(an string) { + acc, err := leaf.lookupAccount(an) + require_NoError(t, err) + checkFor(t, time.Second, 15*time.Millisecond, func() error { + if n := acc.NumPendingAllResponses(); n != 0 { + return fmt.Errorf("Still %d responses for account %q pending on %s", n, an, leaf) + } + return nil + }) + } + for _, acc := range []string{"A", "B"} { + checkResp(acc) + } + }) + } +} + +func TestMsgTraceStreamExport(t *testing.T) { + for _, mainTest := range []struct { + name string + allow bool + }{ + {"allowed", true}, + {"not allowed", false}, + } { + t.Run(mainTest.name, func(t *testing.T) { + conf := createConfFile(t, []byte(fmt.Sprintf(` + listen: 127.0.0.1:-1 + accounts { + A { + users: [{user: a, password: pwd}] + exports: [ + { stream: "info.*.*.>"} + ] + } + B { + users: [{user: b, password: pwd}] + imports: [ { stream: {account: "A", subject:"info.*.*.>"}, to: "B.info.$2.$1.>", allow_trace: %v } ] + } + C { + users: [{user: c, password: pwd}] + imports: [ { stream: {account: "A", subject:"info.*.*.>"}, to: "C.info.$1.$2.>", allow_trace: %v } ] + } + } + `, mainTest.allow, mainTest.allow))) + s, _ := RunServerWithConfig(conf) + defer s.Shutdown() + + nc := natsConnect(t, s.ClientURL(), nats.UserInfo("a", "pwd"), nats.Name("Tracer")) + defer nc.Close() + traceSub := natsSubSync(t, nc, "my.trace.subj") + + nc2 := natsConnect(t, s.ClientURL(), nats.UserInfo("b", "pwd"), nats.Name("sub1")) + defer nc2.Close() + sub1 := natsSubSync(t, nc2, "B.info.*.*.>") + natsFlush(t, nc2) + + nc3 := natsConnect(t, s.ClientURL(), nats.UserInfo("c", "pwd"), nats.Name("sub2")) + defer nc3.Close() + sub2 := natsQueueSubSync(t, nc3, "C.info.>", "my_queue") + natsFlush(t, nc3) + + for _, test := range []struct { + name string + deliverMsg bool + }{ + {"just trace", false}, + {"deliver msg", true}, + } { + t.Run(test.name, func(t *testing.T) { + msg := nats.NewMsg("info.11.22.bar") + msg.Header.Set(MsgTraceSendTo, traceSub.Subject) + if !test.deliverMsg { + msg.Header.Set(MsgTraceOnly, "true") + } + msg.Data = []byte("hello") + + err := nc.PublishMsg(msg) + require_NoError(t, err) + + if test.deliverMsg { + appMsg := natsNexMsg(t, sub1, time.Second) + require_Equal[string](t, appMsg.Subject, "B.info.22.11.bar") + appMsg = natsNexMsg(t, sub2, time.Second) + require_Equal[string](t, appMsg.Subject, "C.info.11.22.bar") + } + // Check that no (more) messages are received. 
+ for _, sub := range []*nats.Subscription{sub1, sub2} { + if msg, err := sub.NextMsg(100 * time.Millisecond); msg != nil || err != nats.ErrTimeout { + t.Fatalf("Did not expect application message, got msg=%v err=%v", msg, err) + } + } + + traceMsg := natsNexMsg(t, traceSub, time.Second) + var e MsgTraceEvent + json.Unmarshal(traceMsg.Data, &e) + + require_Equal[string](t, e.Server.Name, s.Name()) + ingress := e.Ingress() + require_True(t, ingress != nil) + require_True(t, ingress.Kind == CLIENT) + require_Equal[string](t, ingress.Account, "A") + require_Equal[string](t, ingress.Subject, "info.11.22.bar") + require_True(t, e.SubjectMapping() == nil) + require_True(t, e.ServiceImports() == nil) + stexps := e.StreamExports() + require_True(t, stexps != nil) + require_Equal[int](t, len(stexps), 2) + for _, se := range stexps { + require_True(t, se.Timestamp != time.Time{}) + switch se.Account { + case "B": + require_Equal[string](t, se.To, "B.info.22.11.bar") + case "C": + require_Equal[string](t, se.To, "C.info.11.22.bar") + default: + t.Fatalf("Unexpected stream export: %+v", se) + } + } + egress := e.Egresses() + if mainTest.allow { + require_Equal[int](t, len(egress), 2) + for _, eg := range egress { + require_True(t, eg.Kind == CLIENT) + switch eg.Account { + case "B": + require_Equal[string](t, eg.Name, "sub1") + require_Equal[string](t, eg.Subscription, "info.*.*.>") + require_Equal[string](t, eg.Queue, _EMPTY_) + case "C": + require_Equal[string](t, eg.Name, "sub2") + require_Equal[string](t, eg.Subscription, "info.*.*.>") + require_Equal[string](t, eg.Queue, "my_queue") + default: + t.Fatalf("Unexpected egress: %+v", eg) + } + } + } else { + require_Equal[int](t, len(egress), 0) + } + }) + } + }) + } +} + +func TestMsgTraceStreamExportWithSuperCluster(t *testing.T) { + for _, mainTest := range []struct { + name string + allowStr string + allow bool + }{ + {"allowed", "true", true}, + {"not allowed", "false", false}, + } { + t.Run(mainTest.name, func(t *testing.T) { + tmpl := ` + listen: 127.0.0.1:-1 + server_name: %s + jetstream: {max_mem_store: 256MB, max_file_store: 2GB, store_dir: '%s'} + + cluster { + name: %s + listen: 127.0.0.1:%d + routes = [%s] + } + accounts { + A { + users: [{user: a, password: pwd}] + exports: [ + { stream: "info.*.*.>"} + ] + } + B { + users: [{user: b, password: pwd}] + imports: [ { stream: {account: "A", subject:"info.*.*.>"}, to: "B.info.$2.$1.>", allow_trace: ` + mainTest.allowStr + ` } ] + } + C { + users: [{user: c, password: pwd}] + imports: [ { stream: {account: "A", subject:"info.*.*.>"}, to: "C.info.$1.$2.>", allow_trace: ` + mainTest.allowStr + ` } ] + } + $SYS { users = [ { user: "admin", pass: "s3cr3t!" 
} ] } + } + ` + sc := createJetStreamSuperClusterWithTemplate(t, tmpl, 2, 2) + defer sc.shutdown() + + sfornc := sc.clusters[0].servers[0] + nc := natsConnect(t, sfornc.ClientURL(), nats.UserInfo("a", "pwd"), nats.Name("Tracer")) + defer nc.Close() + traceSub := natsSubSync(t, nc, "my.trace.subj") + + sfornc2 := sc.clusters[0].servers[1] + nc2 := natsConnect(t, sfornc2.ClientURL(), nats.UserInfo("b", "pwd"), nats.Name("sub1")) + defer nc2.Close() + sub1 := natsSubSync(t, nc2, "B.info.*.*.>") + natsFlush(t, nc2) + checkSubInterest(t, sfornc2, "A", traceSub.Subject, time.Second) + + sfornc3 := sc.clusters[1].servers[0] + nc3 := natsConnect(t, sfornc3.ClientURL(), nats.UserInfo("c", "pwd"), nats.Name("sub2")) + defer nc3.Close() + sub2 := natsQueueSubSync(t, nc3, "C.info.>", "my_queue") + natsFlush(t, nc3) + + checkSubInterest(t, sfornc, "A", "info.1.2.3.4", time.Second) + for _, s := range sc.clusters[0].servers { + checkForRegisteredQSubInterest(t, s, "C2", "A", "info.1.2.3", 1, time.Second) + } + + for _, test := range []struct { + name string + deliverMsg bool + }{ + {"just trace", false}, + {"deliver msg", true}, + } { + t.Run(test.name, func(t *testing.T) { + msg := nats.NewMsg("info.11.22.bar") + msg.Header.Set(MsgTraceSendTo, traceSub.Subject) + if !test.deliverMsg { + msg.Header.Set(MsgTraceOnly, "true") + } + msg.Data = []byte("hello") + + err := nc.PublishMsg(msg) + require_NoError(t, err) + + if test.deliverMsg { + appMsg := natsNexMsg(t, sub1, time.Second) + require_Equal[string](t, appMsg.Subject, "B.info.22.11.bar") + appMsg = natsNexMsg(t, sub2, time.Second) + require_Equal[string](t, appMsg.Subject, "C.info.11.22.bar") + } + // Check that no (more) messages are received. + for _, sub := range []*nats.Subscription{sub1, sub2} { + if msg, err := sub.NextMsg(100 * time.Millisecond); msg != nil || err != nats.ErrTimeout { + t.Fatalf("Did not expect application message, got msg=%v err=%v", msg, err) + } + } + + check := func() { + traceMsg := natsNexMsg(t, traceSub, time.Second) + var e MsgTraceEvent + json.Unmarshal(traceMsg.Data, &e) + + ingress := e.Ingress() + require_True(t, ingress != nil) + switch ingress.Kind { + case CLIENT: + require_Equal[string](t, e.Server.Name, sfornc.Name()) + require_Equal[string](t, ingress.Name, "Tracer") + require_Equal[string](t, ingress.Account, "A") + require_Equal[string](t, ingress.Subject, "info.11.22.bar") + require_True(t, e.SubjectMapping() == nil) + require_True(t, e.ServiceImports() == nil) + require_True(t, e.StreamExports() == nil) + egress := e.Egresses() + require_Equal[int](t, len(egress), 2) + for _, eg := range egress { + switch eg.Kind { + case ROUTER: + require_Equal[string](t, eg.Name, sfornc2.Name()) + require_Equal[string](t, eg.Account, _EMPTY_) + case GATEWAY: + require_Equal[string](t, eg.Name, sfornc3.Name()) + require_Equal[string](t, eg.Account, _EMPTY_) + default: + t.Fatalf("Unexpected egress: %+v", eg) + } + } + case ROUTER: + require_Equal[string](t, e.Server.Name, sfornc2.Name()) + require_Equal[string](t, ingress.Name, sfornc.Name()) + require_Equal[string](t, ingress.Account, "A") + require_Equal[string](t, ingress.Subject, "info.11.22.bar") + require_True(t, e.SubjectMapping() == nil) + require_True(t, e.ServiceImports() == nil) + stexps := e.StreamExports() + require_True(t, stexps != nil) + require_Equal[int](t, len(stexps), 1) + se := stexps[0] + require_Equal[string](t, se.Account, "B") + require_Equal[string](t, se.To, "B.info.22.11.bar") + egress := e.Egresses() + if mainTest.allow { + 
require_Equal[int](t, len(egress), 1) + eg := egress[0] + require_True(t, eg.Kind == CLIENT) + require_Equal[string](t, eg.Name, "sub1") + require_Equal[string](t, eg.Subscription, "info.*.*.>") + require_Equal[string](t, eg.Queue, _EMPTY_) + } else { + require_Equal[int](t, len(egress), 0) + } + case GATEWAY: + require_Equal[string](t, e.Server.Name, sfornc3.Name()) + require_Equal[string](t, ingress.Name, sfornc.Name()) + require_Equal[string](t, ingress.Account, "A") + require_Equal[string](t, ingress.Subject, "info.11.22.bar") + require_True(t, e.SubjectMapping() == nil) + require_True(t, e.ServiceImports() == nil) + stexps := e.StreamExports() + require_True(t, stexps != nil) + require_Equal[int](t, len(stexps), 1) + se := stexps[0] + require_Equal[string](t, se.Account, "C") + require_Equal[string](t, se.To, "C.info.11.22.bar") + egress := e.Egresses() + if mainTest.allow { + require_Equal[int](t, len(egress), 1) + eg := egress[0] + require_True(t, eg.Kind == CLIENT) + require_Equal[string](t, eg.Name, "sub2") + require_Equal[string](t, eg.Subscription, "info.*.*.>") + require_Equal[string](t, eg.Queue, "my_queue") + } else { + require_Equal[int](t, len(egress), 0) + } + default: + t.Fatalf("Unexpected ingress: %+v", ingress) + } + } + // We expect 3 events + check() + check() + check() + // Make sure we are not receiving more traces + if tm, err := traceSub.NextMsg(250 * time.Millisecond); err == nil { + t.Fatalf("Should not have received trace message: %s", tm.Data) + } + }) + } + }) + } +} + +func TestMsgTraceStreamExportWithLeafNode_Hub(t *testing.T) { + confHub := createConfFile(t, []byte(` + listen: 127.0.0.1:-1 + server_name: "S1" + accounts { + A { + users: [{user: a, password: pwd}] + exports: [ + { stream: "info.*.*.>"} + ] + } + B { + users: [{user: b, password: pwd}] + imports: [ { stream: {account: "A", subject:"info.*.*.>"}, to: "B.info.$2.$1.>", allow_trace: true } ] + } + C { + users: [{user: c, password: pwd}] + imports: [ { stream: {account: "A", subject:"info.*.*.>"}, to: "C.info.$1.$2.>", allow_trace: true } ] + } + } + leafnodes { + port: -1 + } + `)) + hub, ohub := RunServerWithConfig(confHub) + defer hub.Shutdown() + + confLeaf := createConfFile(t, []byte(fmt.Sprintf(` + listen: 127.0.0.1:-1 + server_name: "S2" + accounts { + LEAF { users: [{user: leaf, password: pwd}] } + } + leafnodes { + remotes [ + { url: "nats://a:pwd@127.0.0.1:%d", account: "LEAF" } + ] + } + `, ohub.LeafNode.Port))) + leaf, _ := RunServerWithConfig(confLeaf) + defer leaf.Shutdown() + + checkLeafNodeConnectedCount(t, hub, 1) + checkLeafNodeConnectedCount(t, leaf, 1) + + nc := natsConnect(t, leaf.ClientURL(), nats.UserInfo("leaf", "pwd"), nats.Name("Tracer")) + defer nc.Close() + traceSub := natsSubSync(t, nc, "my.trace.subj") + + checkSubInterest(t, hub, "A", traceSub.Subject, time.Second) + + nc2 := natsConnect(t, hub.ClientURL(), nats.UserInfo("b", "pwd"), nats.Name("sub1")) + defer nc2.Close() + sub1 := natsSubSync(t, nc2, "B.info.*.*.>") + natsFlush(t, nc2) + + nc3 := natsConnect(t, hub.ClientURL(), nats.UserInfo("c", "pwd"), nats.Name("sub2")) + defer nc3.Close() + sub2 := natsQueueSubSync(t, nc3, "C.info.>", "my_queue") + natsFlush(t, nc3) + + acc, err := leaf.LookupAccount("LEAF") + require_NoError(t, err) + checkFor(t, time.Second, 50*time.Millisecond, func() error { + acc.mu.RLock() + sl := acc.sl + acc.mu.RUnlock() + r := sl.Match("info.1.2.3") + ok := len(r.psubs) > 0 + if ok && (len(r.qsubs) == 0 || len(r.qsubs[0]) == 0) { + ok = false + } + if !ok { + return 
fmt.Errorf("Subscription interest not yet propagated") + } + return nil + }) + + for _, test := range []struct { + name string + deliverMsg bool + }{ + + {"just trace", false}, + {"deliver msg", true}, + } { + + t.Run(test.name, func(t *testing.T) { + msg := nats.NewMsg("info.11.22.bar") + msg.Header.Set(MsgTraceSendTo, traceSub.Subject) + if !test.deliverMsg { + msg.Header.Set(MsgTraceOnly, "true") + } + msg.Data = []byte("hello") + + err := nc.PublishMsg(msg) + require_NoError(t, err) + + if test.deliverMsg { + appMsg := natsNexMsg(t, sub1, time.Second) + require_Equal[string](t, appMsg.Subject, "B.info.22.11.bar") + appMsg = natsNexMsg(t, sub2, time.Second) + require_Equal[string](t, appMsg.Subject, "C.info.11.22.bar") + } + // Check that no (more) messages are received. + for _, sub := range []*nats.Subscription{sub1, sub2} { + if msg, err := sub.NextMsg(100 * time.Millisecond); msg != nil || err != nats.ErrTimeout { + t.Fatalf("Did not expect application message, got msg=%v err=%v", msg, err) + } + } + check := func() { + traceMsg := natsNexMsg(t, traceSub, time.Second) + var e MsgTraceEvent + json.Unmarshal(traceMsg.Data, &e) + + ingress := e.Ingress() + require_True(t, ingress != nil) + + switch ingress.Kind { + case CLIENT: + require_Equal[string](t, e.Server.Name, leaf.Name()) + require_Equal[string](t, ingress.Name, "Tracer") + require_Equal[string](t, ingress.Account, "LEAF") + require_Equal[string](t, ingress.Subject, "info.11.22.bar") + require_True(t, e.SubjectMapping() == nil) + require_True(t, e.ServiceImports() == nil) + require_True(t, e.StreamExports() == nil) + egress := e.Egresses() + require_Equal[int](t, len(egress), 1) + eg := egress[0] + require_True(t, eg.Kind == LEAF) + require_Equal[string](t, eg.Name, hub.Name()) + require_Equal[string](t, eg.Account, _EMPTY_) + require_Equal[string](t, eg.Subscription, _EMPTY_) + require_Equal[string](t, eg.Queue, _EMPTY_) + case LEAF: + require_Equal[string](t, e.Server.Name, hub.Name()) + require_Equal[string](t, ingress.Name, leaf.Name()) + require_Equal[string](t, ingress.Account, "A") + require_Equal[string](t, ingress.Subject, "info.11.22.bar") + require_True(t, e.SubjectMapping() == nil) + require_True(t, e.ServiceImports() == nil) + stexps := e.StreamExports() + require_True(t, stexps != nil) + require_Equal[int](t, len(stexps), 2) + for _, se := range stexps { + switch se.Account { + case "B": + require_Equal[string](t, se.To, "B.info.22.11.bar") + case "C": + require_Equal[string](t, se.To, "C.info.11.22.bar") + default: + t.Fatalf("Unexpected stream export: %+v", se) + } + } + egress := e.Egresses() + require_Equal[int](t, len(egress), 2) + for _, eg := range egress { + require_True(t, eg.Kind == CLIENT) + switch eg.Account { + case "B": + require_Equal[string](t, eg.Name, "sub1") + require_Equal[string](t, eg.Subscription, "info.*.*.>") + require_Equal[string](t, eg.Queue, _EMPTY_) + case "C": + require_Equal[string](t, eg.Name, "sub2") + require_Equal[string](t, eg.Subscription, "info.*.*.>") + require_Equal[string](t, eg.Queue, "my_queue") + default: + t.Fatalf("Unexpected egress: %+v", eg) + } + } + default: + t.Fatalf("Unexpected ingress: %+v", ingress) + } + } + // We expect 2 events + check() + check() + // Make sure we are not receiving more traces + if tm, err := traceSub.NextMsg(250 * time.Millisecond); err == nil { + t.Fatalf("Should not have received trace message: %s", tm.Data) + } + }) + } +} + +func TestMsgTraceStreamExportWithLeafNode_Leaf(t *testing.T) { + confHub := createConfFile(t, []byte(` + 
listen: 127.0.0.1:-1 + server_name: "S1" + accounts { + HUB { users: [{user: hub, password: pwd}] } + } + leafnodes { + port: -1 + } + `)) + hub, ohub := RunServerWithConfig(confHub) + defer hub.Shutdown() + + confLeaf := createConfFile(t, []byte(fmt.Sprintf(` + listen: 127.0.0.1:-1 + server_name: "S2" + accounts { + A { + users: [{user: a, password: pwd}] + exports: [ + { stream: "info.*.*.>"} + ] + } + B { + users: [{user: b, password: pwd}] + imports: [ { stream: {account: "A", subject:"info.*.*.>"}, to: "B.info.$2.$1.>", allow_trace: true } ] + } + C { + users: [{user: c, password: pwd}] + imports: [ { stream: {account: "A", subject:"info.*.*.>"}, to: "C.info.$1.$2.>", allow_trace: true } ] + } + } + leafnodes { + remotes [ + { url: "nats://hub:pwd@127.0.0.1:%d", account: "A" } + ] + } + `, ohub.LeafNode.Port))) + leaf, _ := RunServerWithConfig(confLeaf) + defer leaf.Shutdown() + + checkLeafNodeConnectedCount(t, hub, 1) + checkLeafNodeConnectedCount(t, leaf, 1) + + nc := natsConnect(t, hub.ClientURL(), nats.UserInfo("hub", "pwd"), nats.Name("Tracer")) + defer nc.Close() + traceSub := natsSubSync(t, nc, "my.trace.subj") + + checkSubInterest(t, leaf, "A", traceSub.Subject, time.Second) + + nc2 := natsConnect(t, leaf.ClientURL(), nats.UserInfo("b", "pwd"), nats.Name("sub1")) + defer nc2.Close() + sub1 := natsSubSync(t, nc2, "B.info.*.*.>") + natsFlush(t, nc2) + + nc3 := natsConnect(t, leaf.ClientURL(), nats.UserInfo("c", "pwd"), nats.Name("sub2")) + defer nc3.Close() + sub2 := natsQueueSubSync(t, nc3, "C.info.>", "my_queue") + natsFlush(t, nc3) + + acc, err := hub.LookupAccount("HUB") + require_NoError(t, err) + checkFor(t, time.Second, 50*time.Millisecond, func() error { + acc.mu.RLock() + sl := acc.sl + acc.mu.RUnlock() + r := sl.Match("info.1.2.3") + ok := len(r.psubs) > 0 + if ok && (len(r.qsubs) == 0 || len(r.qsubs[0]) == 0) { + ok = false + } + if !ok { + return fmt.Errorf("Subscription interest not yet propagated") + } + return nil + }) + + for _, test := range []struct { + name string + deliverMsg bool + }{ + + {"just trace", false}, + {"deliver msg", true}, + } { + + t.Run(test.name, func(t *testing.T) { + msg := nats.NewMsg("info.11.22.bar") + msg.Header.Set(MsgTraceSendTo, traceSub.Subject) + if !test.deliverMsg { + msg.Header.Set(MsgTraceOnly, "true") + } + msg.Data = []byte("hello") + + err := nc.PublishMsg(msg) + require_NoError(t, err) + + if test.deliverMsg { + appMsg := natsNexMsg(t, sub1, time.Second) + require_Equal[string](t, appMsg.Subject, "B.info.22.11.bar") + appMsg = natsNexMsg(t, sub2, time.Second) + require_Equal[string](t, appMsg.Subject, "C.info.11.22.bar") + } + // Check that no (more) messages are received. 
+ for _, sub := range []*nats.Subscription{sub1, sub2} { + if msg, err := sub.NextMsg(100 * time.Millisecond); msg != nil || err != nats.ErrTimeout { + t.Fatalf("Did not expect application message, got msg=%v err=%v", msg, err) + } + } + check := func() { + traceMsg := natsNexMsg(t, traceSub, time.Second) + var e MsgTraceEvent + json.Unmarshal(traceMsg.Data, &e) + + ingress := e.Ingress() + require_True(t, ingress != nil) + + switch ingress.Kind { + case CLIENT: + require_Equal[string](t, e.Server.Name, hub.Name()) + require_Equal[string](t, ingress.Name, "Tracer") + require_Equal[string](t, ingress.Account, "HUB") + require_Equal[string](t, ingress.Subject, "info.11.22.bar") + require_True(t, e.SubjectMapping() == nil) + require_True(t, e.ServiceImports() == nil) + require_True(t, e.StreamExports() == nil) + egress := e.Egresses() + require_Equal[int](t, len(egress), 1) + eg := egress[0] + require_True(t, eg.Kind == LEAF) + require_Equal[string](t, eg.Name, leaf.Name()) + require_Equal[string](t, eg.Account, _EMPTY_) + require_Equal[string](t, eg.Subscription, _EMPTY_) + require_Equal[string](t, eg.Queue, _EMPTY_) + case LEAF: + require_Equal[string](t, e.Server.Name, leaf.Name()) + require_Equal[string](t, ingress.Name, hub.Name()) + require_Equal[string](t, ingress.Account, "A") + require_Equal[string](t, ingress.Subject, "info.11.22.bar") + require_True(t, e.SubjectMapping() == nil) + require_True(t, e.ServiceImports() == nil) + stexps := e.StreamExports() + require_True(t, stexps != nil) + require_Equal[int](t, len(stexps), 2) + for _, se := range stexps { + switch se.Account { + case "B": + require_Equal[string](t, se.To, "B.info.22.11.bar") + case "C": + require_Equal[string](t, se.To, "C.info.11.22.bar") + default: + t.Fatalf("Unexpected stream export: %+v", se) + } + } + egress := e.Egresses() + require_Equal[int](t, len(egress), 2) + for _, eg := range egress { + require_True(t, eg.Kind == CLIENT) + switch eg.Account { + case "B": + require_Equal[string](t, eg.Name, "sub1") + require_Equal[string](t, eg.Subscription, "info.*.*.>") + require_Equal[string](t, eg.Queue, _EMPTY_) + case "C": + require_Equal[string](t, eg.Name, "sub2") + require_Equal[string](t, eg.Subscription, "info.*.*.>") + require_Equal[string](t, eg.Queue, "my_queue") + default: + t.Fatalf("Unexpected egress: %+v", eg) + } + } + default: + t.Fatalf("Unexpected ingress: %+v", ingress) + } + } + // We expect 2 events + check() + check() + // Make sure we are not receiving more traces + if tm, err := traceSub.NextMsg(250 * time.Millisecond); err == nil { + t.Fatalf("Should not have received trace message: %s", tm.Data) + } + }) + } +} + +func TestMsgTraceJetStream(t *testing.T) { + opts := DefaultTestOptions + opts.Port = -1 + opts.JetStream = true + opts.JetStreamMaxMemory = 270 + opts.StoreDir = t.TempDir() + s := RunServer(&opts) + defer s.Shutdown() + + nc, js := jsClientConnect(t, s) + defer nc.Close() + + cfg := &nats.StreamConfig{ + Name: "TEST", + Storage: nats.MemoryStorage, + Subjects: []string{"foo"}, + Replicas: 1, + AllowRollup: true, + SubjectTransform: &nats.SubjectTransformConfig{ + Source: "foo", + Destination: "bar", + }, + } + _, err := js.AddStream(cfg) + require_NoError(t, err) + + nct := natsConnect(t, s.ClientURL(), nats.Name("Tracer")) + defer nct.Close() + + traceSub := natsSubSync(t, nct, "my.trace.subj") + natsFlush(t, nct) + + msg := nats.NewMsg("foo") + msg.Header.Set(JSMsgId, "MyId") + msg.Data = make([]byte, 50) + _, err = js.PublishMsg(msg) + require_NoError(t, err) + + checkStream 
:= func(t *testing.T, expected int) { + t.Helper() + checkFor(t, time.Second, 15*time.Millisecond, func() error { + si, err := js.StreamInfo("TEST") + if err != nil { + return err + } + if n := si.State.Msgs; int(n) != expected { + return fmt.Errorf("Expected %d messages, got %v", expected, n) + } + return nil + }) + } + checkStream(t, 1) + + for _, test := range []struct { + name string + deliverMsg bool + }{ + {"just trace", false}, + {"deliver msg", true}, + } { + t.Run(test.name, func(t *testing.T) { + msg := nats.NewMsg("foo") + msg.Header.Set(MsgTraceSendTo, traceSub.Subject) + if !test.deliverMsg { + msg.Header.Set(MsgTraceOnly, "true") + } + msg.Data = make([]byte, 50) + err = nct.PublishMsg(msg) + require_NoError(t, err) + + // Wait a bit and check if message should be in the stream or not. + time.Sleep(50 * time.Millisecond) + if test.deliverMsg { + checkStream(t, 2) + } else { + checkStream(t, 1) + } + + traceMsg := natsNexMsg(t, traceSub, time.Second) + var e MsgTraceEvent + json.Unmarshal(traceMsg.Data, &e) + require_Equal[string](t, e.Server.Name, s.Name()) + ingress := e.Ingress() + require_True(t, ingress != nil) + require_True(t, ingress.Kind == CLIENT) + require_Equal[string](t, ingress.Name, "Tracer") + require_Equal[int](t, len(e.Egresses()), 0) + js := e.JetStream() + require_True(t, js != nil) + require_True(t, js.Timestamp != time.Time{}) + require_Equal[string](t, js.Stream, "TEST") + require_Equal[string](t, js.Subject, "bar") + require_False(t, js.NoInterest) + require_Equal[string](t, js.Error, _EMPTY_) + }) + } + + jst, err := nct.JetStream() + require_NoError(t, err) + + mset, err := s.globalAccount().lookupStream("TEST") + require_NoError(t, err) + + // Now we will not ask for delivery and use headers that will fail checks + // and make sure that message is not added, that the stream's clfs is not + // increased, and that the JS trace shows the error. + newMsg := func() *nats.Msg { + msg = nats.NewMsg("foo") + msg.Header.Set(MsgTraceSendTo, traceSub.Subject) + msg.Header.Set(MsgTraceOnly, "true") + msg.Data = []byte("hello") + return msg + } + + msgCount := 2 + for _, test := range []struct { + name string + headerName string + headerVal string + expectedErr string + special int + }{ + {"unexpected stream name", JSExpectedStream, "WRONG", "expected stream does not match", 0}, + {"duplicate id", JSMsgId, "MyId", "duplicate", 0}, + {"last seq by subject mismatch", JSExpectedLastSubjSeq, "10", "last sequence by subject mismatch", 0}, + {"last seq mismatch", JSExpectedLastSeq, "10", "last sequence mismatch", 0}, + {"last msgid mismatch", JSExpectedLastMsgId, "MyId3", "last msgid mismatch", 0}, + {"invalid rollup command", JSMsgRollup, "wrong", "rollup value invalid: \"wrong\"", 0}, + {"rollup not permitted", JSMsgRollup, JSMsgRollupSubject, "rollup not permitted", 1}, + {"max msg size", _EMPTY_, _EMPTY_, ErrMaxPayload.Error(), 2}, + {"normal message ok", _EMPTY_, _EMPTY_, _EMPTY_, 3}, + {"insufficient resources", _EMPTY_, _EMPTY_, NewJSInsufficientResourcesError().Error(), 0}, + {"stream sealed", _EMPTY_, _EMPTY_, NewJSStreamSealedError().Error(), 4}, + } { + t.Run(test.name, func(t *testing.T) { + msg = newMsg() + if test.headerName != _EMPTY_ { + msg.Header.Set(test.headerName, test.headerVal) + } + switch test.special { + case 1: + // Update stream to prevent rollups, and set a max size. 
+ cfg.AllowRollup = false + cfg.MaxMsgSize = 100 + _, err = js.UpdateStream(cfg) + require_NoError(t, err) + case 2: + msg.Data = make([]byte, 200) + case 3: + pa, err := jst.Publish("foo", make([]byte, 100)) + require_NoError(t, err) + msgCount++ + checkStream(t, msgCount) + require_Equal[uint64](t, pa.Sequence, 3) + return + case 4: + cfg.Sealed = true + _, err = js.UpdateStream(cfg) + require_NoError(t, err) + default: + } + jst.PublishMsg(msg) + + // Message count should not have increased and stay at 2. + checkStream(t, msgCount) + // Check that clfs does not increase + mset.mu.RLock() + clfs := mset.getCLFS() + mset.mu.RUnlock() + if clfs != 0 { + t.Fatalf("Stream's clfs was expected to be 0, is %d", clfs) + } + traceMsg := natsNexMsg(t, traceSub, time.Second) + var e MsgTraceEvent + json.Unmarshal(traceMsg.Data, &e) + require_Equal[string](t, e.Server.Name, s.Name()) + ingress := e.Ingress() + require_True(t, ingress != nil) + require_True(t, ingress.Kind == CLIENT) + require_Equal[string](t, ingress.Name, "Tracer") + require_Equal[int](t, len(e.Egresses()), 0) + js := e.JetStream() + require_True(t, js != nil) + require_Equal[string](t, js.Stream, "TEST") + require_Equal[string](t, js.Subject, _EMPTY_) + require_False(t, js.NoInterest) + if et := js.Error; !strings.Contains(et, test.expectedErr) { + t.Fatalf("Expected JS error to contain %q, got %q", test.expectedErr, et) + } + }) + } + + // Create a stream with interest retention policy + _, err = js.AddStream(&nats.StreamConfig{ + Name: "NO_INTEREST", + Subjects: []string{"baz"}, + Retention: nats.InterestPolicy, + }) + require_NoError(t, err) + msg = nats.NewMsg("baz") + msg.Header.Set(MsgTraceSendTo, traceSub.Subject) + msg.Header.Set(MsgTraceOnly, "true") + msg.Data = []byte("hello") + err = nct.PublishMsg(msg) + require_NoError(t, err) + + traceMsg := natsNexMsg(t, traceSub, time.Second) + var e MsgTraceEvent + json.Unmarshal(traceMsg.Data, &e) + require_Equal[string](t, e.Server.Name, s.Name()) + ingress := e.Ingress() + require_True(t, ingress != nil) + require_True(t, ingress.Kind == CLIENT) + require_Equal[string](t, ingress.Name, "Tracer") + require_Equal[int](t, len(e.Egresses()), 0) + ejs := e.JetStream() + require_True(t, ejs != nil) + require_Equal[string](t, ejs.Stream, "NO_INTEREST") + require_Equal[string](t, ejs.Subject, "baz") + require_True(t, ejs.NoInterest) + require_Equal[string](t, ejs.Error, _EMPTY_) +} + +func TestMsgTraceJetStreamWithSuperCluster(t *testing.T) { + sc := createJetStreamSuperCluster(t, 3, 2) + defer sc.shutdown() + + c1 := sc.clusters[0] + c2 := sc.clusters[1] + nc, js := jsClientConnect(t, c1.randomServer()) + defer nc.Close() + + checkStream := func(t *testing.T, stream string, expected int) { + t.Helper() + checkFor(t, time.Second, 15*time.Millisecond, func() error { + si, err := js.StreamInfo(stream) + if err != nil { + return err + } + if n := si.State.Msgs; int(n) != expected { + return fmt.Errorf("Expected %d messages for stream %q, got %v", expected, stream, n) + } + return nil + }) + } + + for mainIter, mainTest := range []struct { + name string + stream string + }{ + {"from stream leader", "TEST1"}, + {"from non stream leader", "TEST2"}, + {"from other cluster", "TEST3"}, + } { + t.Run(mainTest.name, func(t *testing.T) { + cfg := &nats.StreamConfig{ + Name: mainTest.stream, + Replicas: 3, + AllowRollup: true, + } + _, err := js.AddStream(cfg) + require_NoError(t, err) + sc.waitOnStreamLeader(globalAccountName, mainTest.stream) + + // The streams are created from c1 cluster. 
+ slSrv := c1.streamLeader(globalAccountName, mainTest.stream) + + // Store some messages + payload := make([]byte, 50) + for i := 0; i < 5; i++ { + _, err = js.Publish(mainTest.stream, payload) + require_NoError(t, err) + } + + // We will connect the app that sends the trace message to a server + // that is either the stream leader, a random server in c1, or + // a server in c2 (to go through a GW). + var s *Server + switch mainIter { + case 0: + s = slSrv + case 1: + s = c1.randomNonStreamLeader(globalAccountName, mainTest.stream) + case 2: + s = c2.randomServer() + } + + nct := natsConnect(t, s.ClientURL(), nats.Name("Tracer")) + defer nct.Close() + + traceSub := natsSubSync(t, nct, "my.trace.subj") + natsFlush(t, nct) + + for _, test := range []struct { + name string + deliverMsg bool + }{ + {"just trace", false}, + {"deliver msg", true}, + } { + t.Run(test.name, func(t *testing.T) { + msg := nats.NewMsg(mainTest.stream) + msg.Header.Set(MsgTraceSendTo, traceSub.Subject) + if !test.deliverMsg { + msg.Header.Set(MsgTraceOnly, "true") + } + msg.Header.Set(JSMsgId, "MyId") + msg.Data = payload + err = nct.PublishMsg(msg) + require_NoError(t, err) + + if test.deliverMsg { + checkStream(t, mainTest.stream, 6) + } else { + checkStream(t, mainTest.stream, 5) + } + + check := func() bool { + traceMsg := natsNexMsg(t, traceSub, time.Second) + var e MsgTraceEvent + json.Unmarshal(traceMsg.Data, &e) + + checkJS := func() { + t.Helper() + js := e.JetStream() + require_True(t, js != nil) + require_Equal[string](t, js.Stream, mainTest.stream) + require_Equal[string](t, js.Subject, mainTest.stream) + require_False(t, js.NoInterest) + require_Equal[string](t, js.Error, _EMPTY_) + } + + ingress := e.Ingress() + require_True(t, ingress != nil) + switch mainIter { + case 0: + require_Equal[string](t, e.Server.Name, s.Name()) + require_True(t, ingress.Kind == CLIENT) + require_Equal[string](t, ingress.Name, "Tracer") + require_Equal[int](t, len(e.Egresses()), 0) + checkJS() + case 1: + switch ingress.Kind { + case CLIENT: + require_Equal[string](t, e.Server.Name, s.Name()) + require_Equal[string](t, ingress.Name, "Tracer") + egress := e.Egresses() + require_Equal[int](t, len(egress), 1) + ci := egress[0] + require_True(t, ci.Kind == ROUTER) + require_Equal[string](t, ci.Name, slSrv.Name()) + case ROUTER: + require_Equal[string](t, e.Server.Name, slSrv.Name()) + require_Equal[int](t, len(e.Egresses()), 0) + checkJS() + default: + t.Fatalf("Unexpected ingress: %+v", ingress) + } + case 2: + switch ingress.Kind { + case CLIENT: + require_Equal[string](t, e.Server.Name, s.Name()) + require_Equal[string](t, ingress.Name, "Tracer") + egress := e.Egresses() + require_Equal[int](t, len(egress), 1) + ci := egress[0] + require_True(t, ci.Kind == GATEWAY) + // It could have gone to any server in the C1 cluster. + // If it is not the stream leader, it should be + // routed to it. + case GATEWAY: + require_Equal[string](t, ingress.Name, s.Name()) + // If the server that emitted this event is the + // stream leader, then we should have the stream, + // otherwise, it should be routed. 
+ if e.Server.Name == slSrv.Name() { + require_Equal[int](t, len(e.Egresses()), 0) + checkJS() + } else { + egress := e.Egresses() + require_Equal[int](t, len(egress), 1) + ci := egress[0] + require_True(t, ci.Kind == ROUTER) + require_Equal[string](t, ci.Name, slSrv.Name()) + return true + } + case ROUTER: + require_Equal[string](t, e.Server.Name, slSrv.Name()) + require_Equal[int](t, len(e.Egresses()), 0) + checkJS() + default: + t.Fatalf("Unexpected ingress: %+v", ingress) + } + } + return false + } + check() + if mainIter > 0 { + if check() { + check() + } + } + // Make sure we are not receiving more traces + if tm, err := traceSub.NextMsg(250 * time.Millisecond); err == nil { + t.Fatalf("Should not have received trace message: %s", tm.Data) + } + }) + } + + jst, err := nct.JetStream() + require_NoError(t, err) + + newMsg := func() *nats.Msg { + msg := nats.NewMsg(mainTest.stream) + msg.Header.Set(MsgTraceSendTo, traceSub.Subject) + msg.Header.Set(MsgTraceOnly, "true") + msg.Data = []byte("hello") + return msg + } + + msgCount := 6 + for _, subtest := range []struct { + name string + headerName string + headerVal string + expectedErr string + special int + }{ + {"unexpected stream name", JSExpectedStream, "WRONG", "expected stream does not match", 0}, + {"duplicate id", JSMsgId, "MyId", "duplicate", 0}, + {"last seq by subject mismatch", JSExpectedLastSubjSeq, "3", "last sequence by subject mismatch", 0}, + {"last seq mismatch", JSExpectedLastSeq, "10", "last sequence mismatch", 0}, + {"last msgid mismatch", JSExpectedLastMsgId, "MyId3", "last msgid mismatch", 0}, + {"invalid rollup command", JSMsgRollup, "wrong", "rollup value invalid: \"wrong\"", 0}, + {"rollup not permitted", JSMsgRollup, JSMsgRollupSubject, "rollup not permitted", 1}, + {"max msg size", _EMPTY_, _EMPTY_, ErrMaxPayload.Error(), 2}, + {"new message ok", _EMPTY_, _EMPTY_, _EMPTY_, 3}, + {"stream sealed", _EMPTY_, _EMPTY_, NewJSStreamSealedError().Error(), 4}, + } { + t.Run(subtest.name, func(t *testing.T) { + msg := newMsg() + if subtest.headerName != _EMPTY_ { + msg.Header.Set(subtest.headerName, subtest.headerVal) + } + switch subtest.special { + case 1: + // Update stream to prevent rollups, and set a max size. + cfg.AllowRollup = false + cfg.MaxMsgSize = 100 + _, err = js.UpdateStream(cfg) + require_NoError(t, err) + case 2: + msg.Data = make([]byte, 200) + case 3: + pa, err := jst.Publish(mainTest.stream, []byte("hello")) + require_NoError(t, err) + msgCount++ + checkStream(t, mainTest.stream, msgCount) + require_Equal[uint64](t, pa.Sequence, 7) + return + case 4: + cfg.Sealed = true + _, err = js.UpdateStream(cfg) + require_NoError(t, err) + default: + } + jst.PublishMsg(msg) + checkStream(t, mainTest.stream, msgCount) + checkJSTrace := func() bool { + traceMsg := natsNexMsg(t, traceSub, time.Second) + var e MsgTraceEvent + json.Unmarshal(traceMsg.Data, &e) + + checkJS := func() { + t.Helper() + js := e.JetStream() + require_True(t, e.JetStream() != nil) + require_Equal[string](t, js.Stream, mainTest.stream) + require_Equal[string](t, js.Subject, _EMPTY_) + require_False(t, js.NoInterest) + if et := js.Error; !strings.Contains(et, subtest.expectedErr) { + t.Fatalf("Expected JS error to contain %q, got %q", subtest.expectedErr, et) + } + } + + ingress := e.Ingress() + require_True(t, ingress != nil) + // We will focus only on the trace message that + // includes the JetStream event. 
+ switch mainIter { + case 0: + require_Equal[string](t, e.Server.Name, s.Name()) + require_True(t, ingress.Kind == CLIENT) + require_Equal[string](t, ingress.Name, "Tracer") + require_Equal[int](t, len(e.Egresses()), 0) + checkJS() + case 1: + if ingress.Kind == ROUTER { + require_Equal[string](t, e.Server.Name, slSrv.Name()) + require_Equal[int](t, len(e.Egresses()), 0) + require_True(t, e.JetStream() != nil) + checkJS() + } + case 2: + switch ingress.Kind { + case GATEWAY: + require_Equal[string](t, ingress.Name, s.Name()) + // If the server that emitted this event is the + // stream leader, then we should have the stream, + // otherwise, it should be routed. + if e.Server.Name == slSrv.Name() { + require_Equal[int](t, len(e.Egresses()), 0) + checkJS() + } else { + return true + } + case ROUTER: + require_Equal[string](t, e.Server.Name, slSrv.Name()) + require_Equal[int](t, len(e.Egresses()), 0) + checkJS() + } + } + return false + } + checkJSTrace() + if mainIter > 0 { + if checkJSTrace() { + checkJSTrace() + } + } + }) + } + }) + } + + // Now cause a step-down, and verify count is as expected. + for _, stream := range []string{"TEST1", "TEST2", "TEST3"} { + _, err := nc.Request(fmt.Sprintf(JSApiStreamLeaderStepDownT, stream), nil, time.Second) + require_NoError(t, err) + sc.waitOnStreamLeader(globalAccountName, stream) + checkStream(t, stream, 7) + } + + s := c1.randomNonStreamLeader(globalAccountName, "TEST1") + // Try to get a message that will come from a route and make sure that + // this does not trigger a trace message, that is, that headers have + // been properly removed so that they don't trigger it. + nct := natsConnect(t, s.ClientURL(), nats.Name("Tracer")) + defer nct.Close() + traceSub := natsSubSync(t, nct, "my.trace.subj") + natsFlush(t, nct) + + jct, err := nct.JetStream() + require_NoError(t, err) + + sub, err := jct.SubscribeSync("TEST1") + require_NoError(t, err) + for i := 0; i < 7; i++ { + jmsg, err := sub.NextMsg(time.Second) + require_NoError(t, err) + require_Equal[string](t, jmsg.Header.Get(MsgTraceSendTo), _EMPTY_) + } + + msg, err := traceSub.NextMsg(250 * time.Millisecond) + if err != nats.ErrTimeout { + if msg != nil { + t.Fatalf("Expected timeout, got msg headers=%+v data=%s", msg.Header, msg.Data) + } + t.Fatalf("Expected timeout, got err=%v", err) + } +} + +func TestMsgTraceWithCompression(t *testing.T) { + o := DefaultOptions() + s := RunServer(o) + defer s.Shutdown() + + nc := natsConnect(t, s.ClientURL()) + defer nc.Close() + + traceSub := natsSubSync(t, nc, "my.trace.subj") + natsFlush(t, nc) + + for _, test := range []struct { + compressAlgo string + expectedHdr string + unsupported bool + }{ + {"gzip", "gzip", false}, + {"snappy", "snappy", false}, + {"s2", "snappy", false}, + {"bad one", "identity", true}, + } { + t.Run(test.compressAlgo, func(t *testing.T) { + msg := nats.NewMsg("foo") + msg.Header.Set(MsgTraceSendTo, traceSub.Subject) + msg.Header.Set(acceptEncodingHeader, test.compressAlgo) + msg.Data = []byte("hello!") + err := nc.PublishMsg(msg) + require_NoError(t, err) + + traceMsg := natsNexMsg(t, traceSub, time.Second) + data := traceMsg.Data + eh := traceMsg.Header.Get(contentEncodingHeader) + require_Equal[string](t, eh, test.expectedHdr) + if test.unsupported { + // We should be able to unmarshal directly + } else { + switch test.expectedHdr { + case "gzip": + zr, err := gzip.NewReader(bytes.NewReader(data)) + require_NoError(t, err) + data, err = io.ReadAll(zr) + if err != nil && err != io.ErrUnexpectedEOF { + t.Fatalf("Unexpected 
error: %v", err) + } + err = zr.Close() + require_NoError(t, err) + case "snappy": + sr := s2.NewReader(bytes.NewReader(data)) + data, err = io.ReadAll(sr) + if err != nil && err != io.ErrUnexpectedEOF { + t.Fatalf("Unexpected error: %v", err) + } + } + } + var e MsgTraceEvent + err = json.Unmarshal(data, &e) + require_NoError(t, err) + ingress := e.Ingress() + require_True(t, ingress != nil) + require_Equal[string](t, e.Server.Name, s.Name()) + require_Equal[string](t, ingress.Subject, "foo") + }) + } +} + +func TestMsgTraceHops(t *testing.T) { + // Will have a test with the following topology + // + // =================== =================== + // = C1 cluster = = C2 cluster = + // =================== <--- Gateway ---> =================== + // = C1-S1 <-> C1-S2 = = C2-S1 = + // =================== =================== + // ^ ^ ^ + // | Leafnode | | Leafnode + // | | | + // =================== =================== + // = C3 cluster = = C4 cluster = + // =================== =================== + // = C3-S1 <-> C3-S2 = = C4-S1 = + // =================== =================== + // ^ + // | Leafnode + // |-------| + // =================== + // = C5 cluster = + // =================== + // = C5-S1 <-> C5-S2 = + // =================== + // + // And a subscription on "foo" attached to all servers, and the subscription + // on the trace subject attached to c1-s1 (and where the trace message will + // be sent from). + // + commonTmpl := ` + port: -1 + server_name: "%s-%s" + accounts { + A { users: [{user:"a", pass: "pwd"}] } + $SYS { users: [{user:"admin", pass: "s3cr3t!"}] } + } + system_account: "$SYS" + cluster { + port: -1 + name: "%s" + %s + } + ` + genCommon := func(cname, sname string, routePort int) string { + var routes string + if routePort > 0 { + routes = fmt.Sprintf(`routes: ["nats://127.0.0.1:%d"]`, routePort) + } + return fmt.Sprintf(commonTmpl, cname, sname, cname, routes) + } + c1s1conf := createConfFile(t, []byte(fmt.Sprintf(` + %s + gateway { + port: -1 + name: "C1" + } + leafnodes { + port: -1 + } + `, genCommon("C1", "S1", 0)))) + c1s1, oc1s1 := RunServerWithConfig(c1s1conf) + defer c1s1.Shutdown() + + c1s2conf := createConfFile(t, []byte(fmt.Sprintf(` + %s + gateway { + port: -1 + name: "C1" + } + leafnodes { + port: -1 + } + `, genCommon("C1", "S2", oc1s1.Cluster.Port)))) + c1s2, oc1s2 := RunServerWithConfig(c1s2conf) + defer c1s2.Shutdown() + + checkClusterFormed(t, c1s1, c1s2) + + c2s1conf := createConfFile(t, []byte(fmt.Sprintf(` + %s + gateway { + port: -1 + name: "C2" + gateways [ + { + name: "C1" + url: "nats://a:pwd@127.0.0.1:%d" + } + ] + } + leafnodes { + port: -1 + } + `, genCommon("C2", "S1", 0), oc1s1.Gateway.Port))) + c2s1, oc2s1 := RunServerWithConfig(c2s1conf) + defer c2s1.Shutdown() + + c4s1conf := createConfFile(t, []byte(fmt.Sprintf(` + %s + leafnodes { + remotes [{url: "nats://a:pwd@127.0.0.1:%d", account: "A"}] + } + `, genCommon("C4", "S1", 0), oc2s1.LeafNode.Port))) + c4s1, _ := RunServerWithConfig(c4s1conf) + defer c4s1.Shutdown() + + for _, s := range []*Server{c1s1, c1s2, c2s1} { + waitForOutboundGateways(t, s, 1, time.Second) + } + waitForInboundGateways(t, c2s1, 2, time.Second) + + c3s1conf := createConfFile(t, []byte(fmt.Sprintf(` + %s + leafnodes { + port: -1 + remotes [{url: "nats://a:pwd@127.0.0.1:%d", account: "A"}] + } + `, genCommon("C3", "S1", 0), oc1s1.LeafNode.Port))) + c3s1, oc3s1 := RunServerWithConfig(c3s1conf) + defer c3s1.Shutdown() + + c3s2conf := createConfFile(t, []byte(fmt.Sprintf(` + %s + leafnodes { + port: -1 + remotes [{url: 
"nats://a:pwd@127.0.0.1:%d", account: "A"}] + } + system_account: "$SYS" + `, genCommon("C3", "S2", oc3s1.Cluster.Port), oc1s2.LeafNode.Port))) + c3s2, oc3s2 := RunServerWithConfig(c3s2conf) + defer c3s2.Shutdown() + + checkClusterFormed(t, c3s1, c3s2) + checkLeafNodeConnected(t, c1s1) + checkLeafNodeConnected(t, c1s2) + checkLeafNodeConnected(t, c3s1) + checkLeafNodeConnected(t, c3s2) + + c5s1conf := createConfFile(t, []byte(fmt.Sprintf(` + %s + leafnodes { + remotes [{url: "nats://a:pwd@127.0.0.1:%d", account: "A"}] + } + `, genCommon("C5", "S1", 0), oc3s2.LeafNode.Port))) + c5s1, oc5s1 := RunServerWithConfig(c5s1conf) + defer c5s1.Shutdown() + + c5s2conf := createConfFile(t, []byte(fmt.Sprintf(` + %s + leafnodes { + remotes [{url: "nats://a:pwd@127.0.0.1:%d", account: "A"}] + } + `, genCommon("C5", "S2", oc5s1.Cluster.Port), oc3s2.LeafNode.Port))) + c5s2, _ := RunServerWithConfig(c5s2conf) + defer c5s2.Shutdown() + + checkLeafNodeConnected(t, c5s1) + checkLeafNodeConnected(t, c5s2) + checkLeafNodeConnectedCount(t, c3s2, 3) + + nct := natsConnect(t, c1s1.ClientURL(), nats.UserInfo("a", "pwd"), nats.Name("Tracer")) + defer nct.Close() + traceSub := natsSubSync(t, nct, "my.trace.subj") + natsFlush(t, nct) + + allServers := []*Server{c1s1, c1s2, c2s1, c3s1, c3s2, c4s1, c5s1, c5s2} + // Check that the subscription interest on the trace subject reaches all servers. + for _, s := range allServers { + if s == c2s1 { + // Gateway needs to be checked differently. + checkGWInterestOnlyModeInterestOn(t, c2s1, "C1", "A", traceSub.Subject) + continue + } + checkSubInterest(t, s, "A", traceSub.Subject, time.Second) + } + + var subs []*nats.Subscription + // Now create a subscription on "foo" on all servers (do in reverse order). + for i := len(allServers) - 1; i >= 0; i-- { + s := allServers[i] + nc := natsConnect(t, s.ClientURL(), nats.UserInfo("a", "pwd"), nats.Name(fmt.Sprintf("sub%d", i+1))) + defer nc.Close() + subs = append(subs, natsSubSync(t, nc, "foo")) + natsFlush(t, nc) + } + + // Check sub interest on "foo" on all servers. + for _, s := range allServers { + checkSubInterest(t, s, "A", "foo", time.Second) + } + + // Now send a trace message from c1s1 + msg := nats.NewMsg("foo") + msg.Header.Set(MsgTraceSendTo, traceSub.Subject) + msg.Data = []byte("hello!") + err := nct.PublishMsg(msg) + require_NoError(t, err) + + // Check that all subscriptions received the message + for i, sub := range subs { + appMsg, err := sub.NextMsg(time.Second) + if err != nil { + t.Fatalf("Error getting app message for server %q", allServers[i]) + } + require_Equal[string](t, string(appMsg.Data), "hello!") + // Check that no (more) messages are received. 
+ if msg, err := sub.NextMsg(100 * time.Millisecond); err != nats.ErrTimeout { + t.Fatalf("Did not expect application message, got %s", msg.Data) + } + } + + events := make(map[string]*MsgTraceEvent, 8) + // We expect 8 events + for i := 0; i < 8; i++ { + traceMsg := natsNexMsg(t, traceSub, time.Second) + var e MsgTraceEvent + json.Unmarshal(traceMsg.Data, &e) + + hop := e.Request.Header.Get(MsgTraceHop) + events[hop] = &e + } + // Make sure we are not receiving more traces + if tm, err := traceSub.NextMsg(250 * time.Millisecond); err == nil { + t.Fatalf("Should not have received trace message: %s", tm.Data) + } + + checkIngress := func(e *MsgTraceEvent, kind int, name, hop string) *MsgTraceIngress { + t.Helper() + ingress := e.Ingress() + require_True(t, ingress != nil) + require_True(t, ingress.Kind == kind) + require_Equal[string](t, ingress.Account, "A") + require_Equal[string](t, ingress.Subject, "foo") + require_Equal[string](t, ingress.Name, name) + require_Equal[string](t, e.Request.Header.Get(MsgTraceHop), hop) + return ingress + } + + checkEgressClient := func(eg *MsgTraceEgress, name string) { + t.Helper() + require_True(t, eg.Kind == CLIENT) + require_Equal[string](t, eg.Name, name) + require_Equal[string](t, eg.Hop, _EMPTY_) + require_Equal[string](t, eg.Subscription, "foo") + require_Equal[string](t, eg.Queue, _EMPTY_) + } + + // First, we should have an event without a "hop" header, that is the + // ingress from the client. + e, ok := events[_EMPTY_] + require_True(t, ok) + checkIngress(e, CLIENT, "Tracer", _EMPTY_) + require_Equal[int](t, e.Hops, 3) + egress := e.Egresses() + require_Equal[int](t, len(egress), 4) + var ( + leafC3S1Hop string + leafC3S2Hop string + leafC4S1Hop string + leafC5S1Hop string + leafC5S2Hop string + routeC1S2Hop string + gwC2S1Hop string + ) + for _, eg := range egress { + switch eg.Kind { + case CLIENT: + checkEgressClient(eg, "sub1") + case ROUTER: + require_Equal[string](t, eg.Name, c1s2.Name()) + routeC1S2Hop = eg.Hop + case LEAF: + require_Equal[string](t, eg.Name, c3s1.Name()) + leafC3S1Hop = eg.Hop + case GATEWAY: + require_Equal[string](t, eg.Name, c2s1.Name()) + gwC2S1Hop = eg.Hop + default: + t.Fatalf("Unexpected egress: %+v", eg) + } + } + // All "hop" ids should be not empty and different from each other + require_True(t, leafC3S1Hop != _EMPTY_ && routeC1S2Hop != _EMPTY_ && gwC2S1Hop != _EMPTY_) + require_True(t, leafC3S1Hop != routeC1S2Hop && leafC3S1Hop != gwC2S1Hop && routeC1S2Hop != gwC2S1Hop) + + // Now check the routed server in C1 (c1s2) + e, ok = events[routeC1S2Hop] + require_True(t, ok) + checkIngress(e, ROUTER, c1s1.Name(), routeC1S2Hop) + require_Equal[int](t, e.Hops, 1) + egress = e.Egresses() + require_Equal[int](t, len(egress), 2) + for _, eg := range egress { + switch eg.Kind { + case CLIENT: + checkEgressClient(eg, "sub2") + case LEAF: + require_Equal[string](t, eg.Name, c3s2.Name()) + require_Equal[string](t, eg.Hop, routeC1S2Hop+".1") + leafC3S2Hop = eg.Hop + default: + t.Fatalf("Unexpected egress: %+v", eg) + } + } + require_True(t, leafC3S2Hop != _EMPTY_) + + // Let's check the gateway server + e, ok = events[gwC2S1Hop] + require_True(t, ok) + checkIngress(e, GATEWAY, c1s1.Name(), gwC2S1Hop) + require_Equal[int](t, e.Hops, 1) + egress = e.Egresses() + require_Equal[int](t, len(egress), 2) + for _, eg := range egress { + switch eg.Kind { + case CLIENT: + checkEgressClient(eg, "sub3") + case LEAF: + require_Equal[string](t, eg.Name, c4s1.Name()) + require_Equal[string](t, eg.Hop, gwC2S1Hop+".1") + leafC4S1Hop = 
eg.Hop + default: + t.Fatalf("Unexpected egress: %+v", eg) + } + } + require_True(t, leafC4S1Hop != _EMPTY_) + + // Let's check the C3 cluster, starting at C3-S1 + e, ok = events[leafC3S1Hop] + require_True(t, ok) + checkIngress(e, LEAF, c1s1.Name(), leafC3S1Hop) + require_Equal[int](t, e.Hops, 0) + egress = e.Egresses() + require_Equal[int](t, len(egress), 1) + checkEgressClient(egress[0], "sub4") + + // Now C3-S2 + e, ok = events[leafC3S2Hop] + require_True(t, ok) + checkIngress(e, LEAF, c1s2.Name(), leafC3S2Hop) + require_Equal[int](t, e.Hops, 2) + egress = e.Egresses() + require_Equal[int](t, len(egress), 3) + for _, eg := range egress { + switch eg.Kind { + case CLIENT: + checkEgressClient(eg, "sub5") + case LEAF: + require_True(t, eg.Name == c5s1.Name() || eg.Name == c5s2.Name()) + require_True(t, eg.Hop == leafC3S2Hop+".1" || eg.Hop == leafC3S2Hop+".2") + if eg.Name == c5s1.Name() { + leafC5S1Hop = eg.Hop + } else { + leafC5S2Hop = eg.Hop + } + default: + t.Fatalf("Unexpected egress: %+v", eg) + } + } + // The leafC5SxHop must be different and not empty + require_True(t, leafC5S1Hop != _EMPTY_ && leafC5S1Hop != leafC5S2Hop && leafC5S2Hop != _EMPTY_) + + // Check the C4 cluster + e, ok = events[leafC4S1Hop] + require_True(t, ok) + checkIngress(e, LEAF, c2s1.Name(), leafC4S1Hop) + require_Equal[int](t, e.Hops, 0) + egress = e.Egresses() + require_Equal[int](t, len(egress), 1) + checkEgressClient(egress[0], "sub6") + + // Finally, the C5 cluster, starting with C5-S1 + e, ok = events[leafC5S1Hop] + require_True(t, ok) + checkIngress(e, LEAF, c3s2.Name(), leafC5S1Hop) + require_Equal[int](t, e.Hops, 0) + egress = e.Egresses() + require_Equal[int](t, len(egress), 1) + checkEgressClient(egress[0], "sub7") + + // Then C5-S2 + e, ok = events[leafC5S2Hop] + require_True(t, ok) + checkIngress(e, LEAF, c3s2.Name(), leafC5S2Hop) + require_Equal[int](t, e.Hops, 0) + egress = e.Egresses() + require_Equal[int](t, len(egress), 1) + checkEgressClient(egress[0], "sub8") +} diff --git a/server/opts.go b/server/opts.go index 721d6be8436..73635f95c9e 100644 --- a/server/opts.go +++ b/server/opts.go @@ -402,7 +402,7 @@ type Options struct { // private fields, used for testing gatewaysSolicitDelay time.Duration - routeProto int + overrideProto int // JetStream maxMemSet bool @@ -2709,14 +2709,16 @@ type export struct { lat *serviceLatency rthr time.Duration tPos uint + atrc bool // allow_trace } type importStream struct { - acc *Account - an string - sub string - to string - pre string + acc *Account + an string + sub string + to string + pre string + atrc bool // allow_trace } type importService struct { @@ -3147,6 +3149,14 @@ func parseAccounts(v interface{}, opts *Options, errors *[]error, warnings *[]er continue } } + + if service.atrc { + if err := service.acc.SetServiceExportAllowTrace(service.sub, true); err != nil { + msg := fmt.Sprintf("Error adding allow_trace for %q: %v", service.sub, err) + *errors = append(*errors, &configErr{tk, msg}) + continue + } + } } for _, stream := range importStreams { ta := am[stream.an] @@ -3156,13 +3166,13 @@ func parseAccounts(v interface{}, opts *Options, errors *[]error, warnings *[]er continue } if stream.pre != _EMPTY_ { - if err := stream.acc.AddStreamImport(ta, stream.sub, stream.pre); err != nil { + if err := stream.acc.addStreamImportWithClaim(ta, stream.sub, stream.pre, stream.atrc, nil); err != nil { msg := fmt.Sprintf("Error adding stream import %q: %v", stream.sub, err) *errors = append(*errors, &configErr{tk, msg}) continue } } else { - if err := 
stream.acc.AddMappedStreamImport(ta, stream.sub, stream.to); err != nil { + if err := stream.acc.addMappedStreamImportWithClaim(ta, stream.sub, stream.to, stream.atrc, nil); err != nil { msg := fmt.Sprintf("Error adding stream import %q: %v", stream.sub, err) *errors = append(*errors, &configErr{tk, msg}) continue @@ -3320,6 +3330,9 @@ func parseExportStreamOrService(v interface{}, errors, warnings *[]error) (*expo latToken token lt token accTokPos uint + atrc bool + atrcSeen bool + atrcToken token ) defer convertPanicToErrorList(<, errors) @@ -3347,6 +3360,11 @@ func parseExportStreamOrService(v interface{}, errors, warnings *[]error) (*expo *errors = append(*errors, err) continue } + if atrcToken != nil { + err := &configErr{atrcToken, "Detected allow_trace directive on non-service"} + *errors = append(*errors, err) + continue + } mvs, ok := mv.(string) if !ok { err := &configErr{tk, fmt.Sprintf("Expected stream name to be string, got %T", mv)} @@ -3382,6 +3400,9 @@ func parseExportStreamOrService(v interface{}, errors, warnings *[]error) (*expo if threshSeen { curService.rthr = thresh } + if atrcSeen { + curService.atrc = atrc + } case "response", "response_type": if rtSeen { err := &configErr{tk, "Duplicate response type definition"} @@ -3470,6 +3491,18 @@ func parseExportStreamOrService(v interface{}, errors, warnings *[]error) (*expo } case "account_token_position": accTokPos = uint(mv.(int64)) + case "allow_trace": + atrcSeen = true + atrcToken = tk + atrc = mv.(bool) + if curStream != nil { + *errors = append(*errors, + &configErr{tk, "Detected allow_trace directive on non-service"}) + continue + } + if curService != nil { + curService.atrc = atrc + } default: if !tk.IsUsedVariable() { err := &unknownConfigFieldErr{ @@ -3580,6 +3613,9 @@ func parseImportStreamOrService(v interface{}, errors, warnings *[]error) (*impo pre, to string share bool lt token + atrc bool + atrcSeen bool + atrcToken token ) defer convertPanicToErrorList(<, errors) @@ -3621,12 +3657,20 @@ func parseImportStreamOrService(v interface{}, errors, warnings *[]error) (*impo if pre != _EMPTY_ { curStream.pre = pre } + if atrcSeen { + curStream.atrc = atrc + } case "service": if curStream != nil { err := &configErr{tk, "Detected service but already saw a stream"} *errors = append(*errors, err) continue } + if atrcToken != nil { + err := &configErr{atrcToken, "Detected allow_trace directive on a non-stream"} + *errors = append(*errors, err) + continue + } ac, ok := mv.(map[string]interface{}) if !ok { err := &configErr{tk, fmt.Sprintf("Service entry should be an account map, got %T", mv)} @@ -3674,6 +3718,18 @@ func parseImportStreamOrService(v interface{}, errors, warnings *[]error) (*impo if curService != nil { curService.share = share } + case "allow_trace": + if curService != nil { + err := &configErr{tk, "Detected allow_trace directive on a non-stream"} + *errors = append(*errors, err) + continue + } + atrcSeen = true + atrc = mv.(bool) + atrcToken = tk + if curStream != nil { + curStream.atrc = atrc + } default: if !tk.IsUsedVariable() { err := &unknownConfigFieldErr{ diff --git a/server/parser.go b/server/parser.go index 74f55f576d2..101e352d6ea 100644 --- a/server/parser.go +++ b/server/parser.go @@ -1,4 +1,4 @@ -// Copyright 2012-2020 The NATS Authors +// Copyright 2012-2024 The NATS Authors // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. 
// You may obtain a copy of the License at @@ -49,6 +49,7 @@ type pubArg struct { size int hdr int psi []*serviceImport + trace *msgTrace } // Parser constants @@ -285,7 +286,11 @@ func (c *client) parse(buf []byte) error { if trace { c.traceInOp("HPUB", arg) } - if err := c.processHeaderPub(arg); err != nil { + var remaining []byte + if i < len(buf) { + remaining = buf[i+1:] + } + if err := c.processHeaderPub(arg, remaining); err != nil { return err } @@ -483,11 +488,17 @@ func (c *client) parse(buf []byte) error { c.msgBuf = buf[c.as : i+1] } + var mt *msgTrace + if c.pa.hdr > 0 { + mt = c.initMsgTrace() + } // Check for mappings. if (c.kind == CLIENT || c.kind == LEAF) && c.in.flags.isSet(hasMappings) { changed := c.selectMappedSubject() - if trace && changed { + if (trace || mt != nil) && changed { c.traceInOp("MAPPING", []byte(fmt.Sprintf("%s -> %s", c.pa.mapped, c.pa.subject))) + // c.pa.subject is the subject the original is now mapped to. + mt.addSubjectMappingEvent(c.pa.subject) } } if trace { @@ -495,11 +506,14 @@ func (c *client) parse(buf []byte) error { } c.processInboundMsg(c.msgBuf) + + mt.sendEvent() c.argBuf, c.msgBuf, c.header = nil, nil, nil c.drop, c.as, c.state = 0, i+1, OP_START // Drop all pub args c.pa.arg, c.pa.pacache, c.pa.origin, c.pa.account, c.pa.subject, c.pa.mapped = nil, nil, nil, nil, nil, nil c.pa.reply, c.pa.hdr, c.pa.size, c.pa.szb, c.pa.hdb, c.pa.queues = nil, -1, 0, nil, nil, nil + c.pa.trace = nil lmsg = false case OP_A: switch b { @@ -1270,7 +1284,7 @@ func (c *client) clonePubArg(lmsg bool) error { if c.pa.hdr < 0 { return c.processPub(c.argBuf) } else { - return c.processHeaderPub(c.argBuf) + return c.processHeaderPub(c.argBuf, nil) } } } diff --git a/server/parser_test.go b/server/parser_test.go index 458d76cee5e..47ec24c7659 100644 --- a/server/parser_test.go +++ b/server/parser_test.go @@ -1,4 +1,4 @@ -// Copyright 2012-2020 The NATS Authors +// Copyright 2012-2024 The NATS Authors // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at @@ -424,7 +424,7 @@ func TestParseHeaderPubArg(t *testing.T) { {arg: "\t \tfoo\t \t \t\t\t11\t\t 2222\t \t", subject: "foo", reply: "", hdr: 11, size: 2222, szb: "2222"}, } { t.Run(test.arg, func(t *testing.T) { - if err := c.processHeaderPub([]byte(test.arg)); err != nil { + if err := c.processHeaderPub([]byte(test.arg), nil); err != nil { t.Fatalf("Unexpected parse error: %v\n", err) } if !bytes.Equal(c.pa.subject, []byte(test.subject)) { diff --git a/server/reload_test.go b/server/reload_test.go index 2bc22b6ee58..52e0fa87798 100644 --- a/server/reload_test.go +++ b/server/reload_test.go @@ -2557,7 +2557,7 @@ func TestConfigReloadClusterPermsOldServer(t *testing.T) { optsB := DefaultOptions() optsB.Routes = RoutesFromStr(fmt.Sprintf("nats://127.0.0.1:%d", srva.ClusterAddr().Port)) // Make server B behave like an old server - optsB.routeProto = setRouteProtoForTest(RouteProtoZero) + optsB.overrideProto = setServerProtoForTest(RouteProtoZero) srvb := RunServer(optsB) defer srvb.Shutdown() diff --git a/server/route.go b/server/route.go index f8a8623d6a3..b154f309cc2 100644 --- a/server/route.go +++ b/server/route.go @@ -1,4 +1,4 @@ -// Copyright 2013-2023 The NATS Authors +// Copyright 2013-2024 The NATS Authors // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. 
// You may obtain a copy of the License at @@ -42,17 +42,6 @@ const ( Explicit ) -const ( - // RouteProtoZero is the original Route protocol from 2009. - // http://nats.io/documentation/internals/nats-protocol/ - RouteProtoZero = iota - // RouteProtoInfo signals a route can receive more then the original INFO block. - // This can be used to update remote cluster permissions, etc... - RouteProtoInfo - // RouteProtoV2 is the new route/cluster protocol that provides account support. - RouteProtoV2 -) - // Include the space for the proto var ( aSubBytes = []byte{'A', '+', ' '} @@ -63,11 +52,6 @@ var ( lUnsubBytes = []byte{'L', 'S', '-', ' '} ) -// Used by tests -func setRouteProtoForTest(wantedProto int) int { - return (wantedProto + 1) * -1 -} - type route struct { remoteID string remoteName string @@ -754,7 +738,8 @@ func (c *client) processRouteInfo(info *Info) { // Mark that the INFO protocol has been received, so we can detect updates. c.flags.set(infoReceived) - // Get the route's proto version + // Get the route's proto version. It will be used to check if the connection + // supports certain features, such as message tracing. c.opts.Protocol = info.Proto // Headers @@ -2457,17 +2442,6 @@ func (s *Server) startRouteAcceptLoop() { s.Noticef("Listening for route connections on %s", net.JoinHostPort(opts.Cluster.Host, strconv.Itoa(l.Addr().(*net.TCPAddr).Port))) - proto := RouteProtoV2 - // For tests, we want to be able to make this server behave - // as an older server so check this option to see if we should override - if opts.routeProto < 0 { - // We have a private option that allows test to override the route - // protocol. We want this option initial value to be 0, however, - // since original proto is RouteProtoZero, tests call setRouteProtoForTest(), - // which sets as negative value the (desired proto + 1) * -1. - // Here we compute back the real value. - proto = (opts.routeProto * -1) - 1 - } // Check for TLSConfig tlsReq := opts.Cluster.TLSConfig != nil info := Info{ @@ -2480,7 +2454,7 @@ func (s *Server) startRouteAcceptLoop() { TLSVerify: tlsReq, MaxPayload: s.info.MaxPayload, JetStream: s.info.JetStream, - Proto: proto, + Proto: s.getServerProto(), GatewayURL: s.getGatewayURL(), Headers: s.supportsHeaders(), Cluster: s.info.Cluster, diff --git a/server/server.go b/server/server.go index 75d6f393953..f437cbceca6 100644 --- a/server/server.go +++ b/server/server.go @@ -58,6 +58,49 @@ const ( firstClientPingInterval = 2 * time.Second ) +// These are protocol versions sent between server connections: ROUTER, LEAF and +// GATEWAY. We may have protocol versions that have a meaning only for a certain +// type of connections, but we don't have to have separate enums for that. +// However, it is CRITICAL to not change the order of those constants since they +// are exchanged between servers. When adding a new protocol version, add to the +// end of the list, don't try to group them by connection types. +const ( + // RouteProtoZero is the original Route protocol from 2009. + // http://nats.io/documentation/internals/nats-protocol/ + RouteProtoZero = iota + // RouteProtoInfo signals a route can receive more then the original INFO block. + // This can be used to update remote cluster permissions, etc... + RouteProtoInfo + // RouteProtoV2 is the new route/cluster protocol that provides account support. + RouteProtoV2 + // MsgTraceProto indicates that this server understands distributed message tracing. 
+ MsgTraceProto +) + +// Will return the latest server-to-server protocol versions, unless the +// option to override it is set. +func (s *Server) getServerProto() int { + opts := s.getOpts() + // Initialize with the latest protocol version. + proto := MsgTraceProto + // For tests, we want to be able to make this server behave + // as an older server so check this option to see if we should override. + if opts.overrideProto < 0 { + // The option overrideProto is set to 0 by default (when creating an + // Options structure). Since this is the same value than the original + // proto RouteProtoZero, tests call setServerProtoForTest() with the + // desired protocol level, which sets it as negative value equal to: + // (wantedProto + 1) * -1. Here we compute back the real value. + proto = (opts.overrideProto * -1) - 1 + } + return proto +} + +// Used by tests. +func setServerProtoForTest(wantedProto int) int { + return (wantedProto + 1) * -1 +} + // Info is the information sent to clients, routes, gateways, and leaf nodes, // to help them understand information about this server. type Info struct { diff --git a/server/stream.go b/server/stream.go index 18b7ef229bd..c6e9102d61f 100644 --- a/server/stream.go +++ b/server/stream.go @@ -268,6 +268,9 @@ type stream struct { sch chan struct{} sigq *ipQueue[*cMsg] csl *Sublist // Consumer Sublist + // Leader will store seq/msgTrace in clustering mode. Used in applyStreamEntries + // to know if trace event should be sent after processing. + mt map[uint64]*msgTrace // For non limits policy streams when they process an ack before the actual msg. // Can happen in stretch clusters, multi-cloud, or during catchup for a restarted server. @@ -2304,7 +2307,7 @@ func (mset *stream) processInboundMirrorMsg(m *inMsg) bool { err = node.Propose(encodeStreamMsg(m.subj, _EMPTY_, m.hdr, m.msg, sseq-1, ts)) } } else { - err = mset.processJetStreamMsg(m.subj, _EMPTY_, m.hdr, m.msg, sseq-1, ts) + err = mset.processJetStreamMsg(m.subj, _EMPTY_, m.hdr, m.msg, sseq-1, ts, nil) } if err != nil { if strings.Contains(err.Error(), "no space left") { @@ -2650,7 +2653,7 @@ func (mset *stream) setupMirrorConsumer() error { msgs := mirror.msgs sub, err := mset.subscribeInternal(deliverSubject, func(sub *subscription, c *client, _ *Account, subject, reply string, rmsg []byte) { hdr, msg := c.msgParts(copyBytes(rmsg)) // Need to copy. - mset.queueInbound(msgs, subject, reply, hdr, msg) + mset.queueInbound(msgs, subject, reply, hdr, msg, nil) }) if err != nil { mirror.err = NewJSMirrorConsumerSetupFailedError(err, Unless(err)) @@ -3020,7 +3023,7 @@ func (mset *stream) setSourceConsumer(iname string, seq uint64, startTime time.T msgs := si.msgs sub, err := mset.subscribeInternal(deliverSubject, func(sub *subscription, c *client, _ *Account, subject, reply string, rmsg []byte) { hdr, msg := c.msgParts(copyBytes(rmsg)) // Need to copy. - mset.queueInbound(msgs, subject, reply, hdr, msg) + mset.queueInbound(msgs, subject, reply, hdr, msg, nil) }) if err != nil { si.err = NewJSSourceConsumerSetupFailedError(err, Unless(err)) @@ -3249,9 +3252,9 @@ func (mset *stream) processInboundSourceMsg(si *sourceInfo, m *inMsg) bool { var err error // If we are clustered we need to propose this message to the underlying raft group. 
if node != nil { - err = mset.processClusteredInboundMsg(m.subj, _EMPTY_, hdr, msg) + err = mset.processClusteredInboundMsg(m.subj, _EMPTY_, hdr, msg, nil) } else { - err = mset.processJetStreamMsg(m.subj, _EMPTY_, hdr, msg, 0, 0) + err = mset.processJetStreamMsg(m.subj, _EMPTY_, hdr, msg, 0, 0, nil) } if err != nil { @@ -3978,21 +3981,11 @@ type inMsg struct { rply string hdr []byte msg []byte + mt *msgTrace } -func (mset *stream) queueInbound(ib *ipQueue[*inMsg], subj, rply string, hdr, msg []byte) { - ib.push(&inMsg{subj, rply, hdr, msg}) -} - -func (mset *stream) queueInboundMsg(subj, rply string, hdr, msg []byte) { - // Copy these. - if len(hdr) > 0 { - hdr = copyBytes(hdr) - } - if len(msg) > 0 { - msg = copyBytes(msg) - } - mset.queueInbound(mset.msgs, subj, rply, hdr, msg) +func (mset *stream) queueInbound(ib *ipQueue[*inMsg], subj, rply string, hdr, msg []byte, mt *msgTrace) { + ib.push(&inMsg{subj, rply, hdr, msg, mt}) } var dgPool = sync.Pool{ @@ -4183,7 +4176,27 @@ func (mset *stream) getDirectRequest(req *JSApiMsgGetRequest, reply string) { // processInboundJetStreamMsg handles processing messages bound for a stream. func (mset *stream) processInboundJetStreamMsg(_ *subscription, c *client, _ *Account, subject, reply string, rmsg []byte) { hdr, msg := c.msgParts(rmsg) - mset.queueInboundMsg(subject, reply, hdr, msg) + // Copy these. + if len(hdr) > 0 { + hdr = copyBytes(hdr) + } + if len(msg) > 0 { + msg = copyBytes(msg) + } + if mt, traceOnly := c.isMsgTraceEnabled(); mt != nil { + // If message is delivered, we need to disable the message trace destination + // header to prevent a trace event to be generated when a stored message + // is delivered to a consumer and routed. + if !traceOnly { + mt.disableTraceHeader(c, hdr) + } + // This will add the jetstream event while in the client read loop. + // Since the event will be updated in a different go routine, the + // tracing object will have a separate reference to the JS trace + // object. + mt.addJetStreamEvent(mset.name()) + } + mset.queueInbound(mset.msgs, subject, reply, hdr, msg, c.pa.trace) } var ( @@ -4194,7 +4207,15 @@ var ( ) // processJetStreamMsg is where we try to actually process the stream msg. -func (mset *stream) processJetStreamMsg(subject, reply string, hdr, msg []byte, lseq uint64, ts int64) error { +func (mset *stream) processJetStreamMsg(subject, reply string, hdr, msg []byte, lseq uint64, ts int64, mt *msgTrace) (retErr error) { + if mt != nil { + // Only the leader/standalone will have mt!=nil. On exit, send the + // message trace event. + defer func() { + mt.sendEventFromJetStream(retErr) + }() + } + if mset.closed.Load() { return errStreamClosed } @@ -4202,6 +4223,15 @@ func (mset *stream) processJetStreamMsg(subject, reply string, hdr, msg []byte, mset.mu.Lock() s, store := mset.srv, mset.store + traceOnly := mt.traceOnly() + bumpCLFS := func() { + // Do not bump if tracing and not doing message delivery. + if traceOnly { + return + } + mset.clfs++ + } + // Apply the input subject transform if any if mset.itr != nil { ts, err := mset.itr.Match(subject) @@ -4230,6 +4260,7 @@ func (mset *stream) processJetStreamMsg(subject, reply string, hdr, msg []byte, // Bail here if sealed. 
if isSealed { outq := mset.outq + bumpCLFS() mset.mu.Unlock() if canRespond && outq != nil { resp.PubAck = &PubAck{Stream: name} @@ -4292,10 +4323,12 @@ func (mset *stream) processJetStreamMsg(subject, reply string, hdr, msg []byte, isClustered := mset.isClustered() // Certain checks have already been performed if in clustered mode, so only check if not. - if !isClustered { + // Note, for cluster mode but with message tracing (without message delivery), we need + // to do this check here since it was not done in processClusteredInboundMsg(). + if !isClustered || traceOnly { // Expected stream. if sname := getExpectedStream(hdr); sname != _EMPTY_ && sname != name { - mset.clfs++ + bumpCLFS() mset.mu.Unlock() if canRespond { resp.PubAck = &PubAck{Stream: name} @@ -4310,7 +4343,7 @@ func (mset *stream) processJetStreamMsg(subject, reply string, hdr, msg []byte, // Dedupe detection. if msgId = getMsgId(hdr); msgId != _EMPTY_ { if dde := mset.checkMsgId(msgId); dde != nil { - mset.clfs++ + bumpCLFS() mset.mu.Unlock() if canRespond { response := append(pubAck, strconv.FormatUint(dde.seq, 10)...) @@ -4334,7 +4367,7 @@ func (mset *stream) processJetStreamMsg(subject, reply string, hdr, msg []byte, fseq, err = 0, nil } if err != nil || fseq != seq { - mset.clfs++ + bumpCLFS() mset.mu.Unlock() if canRespond { resp.PubAck = &PubAck{Stream: name} @@ -4349,7 +4382,7 @@ func (mset *stream) processJetStreamMsg(subject, reply string, hdr, msg []byte, // Expected last sequence. if seq, exists := getExpectedLastSeq(hdr); exists && seq != mset.lseq { mlseq := mset.lseq - mset.clfs++ + bumpCLFS() mset.mu.Unlock() if canRespond { resp.PubAck = &PubAck{Stream: name} @@ -4366,7 +4399,7 @@ func (mset *stream) processJetStreamMsg(subject, reply string, hdr, msg []byte, } if lmsgId != mset.lmsgId { last := mset.lmsgId - mset.clfs++ + bumpCLFS() mset.mu.Unlock() if canRespond { resp.PubAck = &PubAck{Stream: name} @@ -4380,7 +4413,7 @@ func (mset *stream) processJetStreamMsg(subject, reply string, hdr, msg []byte, // Check for any rollups. if rollup := getRollup(hdr); rollup != _EMPTY_ { if !mset.cfg.AllowRollup || mset.cfg.DenyPurge { - mset.clfs++ + bumpCLFS() mset.mu.Unlock() if canRespond { resp.PubAck = &PubAck{Stream: name} @@ -4396,8 +4429,16 @@ func (mset *stream) processJetStreamMsg(subject, reply string, hdr, msg []byte, case JSMsgRollupAll: rollupAll = true default: + bumpCLFS() mset.mu.Unlock() - return fmt.Errorf("rollup value invalid: %q", rollup) + err := fmt.Errorf("rollup value invalid: %q", rollup) + if canRespond { + resp.PubAck = &PubAck{Stream: name} + resp.Error = NewJSStreamRollupFailedError(err) + b, _ := json.Marshal(resp) + outq.sendMsg(reply, b) + } + return err } } } @@ -4411,7 +4452,7 @@ func (mset *stream) processJetStreamMsg(subject, reply string, hdr, msg []byte, // Check to see if we are over the max msg size. if maxMsgSize >= 0 && (len(hdr)+len(msg)) > maxMsgSize { - mset.clfs++ + bumpCLFS() mset.mu.Unlock() if canRespond { resp.PubAck = &PubAck{Stream: name} @@ -4423,7 +4464,7 @@ func (mset *stream) processJetStreamMsg(subject, reply string, hdr, msg []byte, } if len(hdr) > math.MaxUint16 { - mset.clfs++ + bumpCLFS() mset.mu.Unlock() if canRespond { resp.PubAck = &PubAck{Stream: name} @@ -4437,7 +4478,7 @@ func (mset *stream) processJetStreamMsg(subject, reply string, hdr, msg []byte, // Check to see if we have exceeded our limits. 
if js.limitsExceeded(stype) { s.resourcesExceededError() - mset.clfs++ + bumpCLFS() mset.mu.Unlock() if canRespond { resp.PubAck = &PubAck{Stream: name} @@ -4477,6 +4518,12 @@ func (mset *stream) processJetStreamMsg(subject, reply string, hdr, msg []byte, ts = time.Now().UnixNano() } + mt.updateJetStreamEvent(subject, noInterest) + if traceOnly { + mset.mu.Unlock() + return nil + } + // Skip msg here. if noInterest { mset.lseq = store.SkipMsg() @@ -4543,7 +4590,7 @@ func (mset *stream) processJetStreamMsg(subject, reply string, hdr, msg []byte, mset.store.FastState(&state) mset.lseq = state.LastSeq mset.lmsgId = olmsgId - mset.clfs++ + bumpCLFS() mset.mu.Unlock() switch err { @@ -4946,9 +4993,9 @@ func (mset *stream) internalLoop() { for _, im := range ims { // If we are clustered we need to propose this message to the underlying raft group. if isClustered { - mset.processClusteredInboundMsg(im.subj, im.rply, im.hdr, im.msg) + mset.processClusteredInboundMsg(im.subj, im.rply, im.hdr, im.msg, im.mt) } else { - mset.processJetStreamMsg(im.subj, im.rply, im.hdr, im.msg, 0, 0) + mset.processJetStreamMsg(im.subj, im.rply, im.hdr, im.msg, 0, 0, im.mt) } } msgs.recycle(&ims) diff --git a/test/new_routes_test.go b/test/new_routes_test.go index 174c6d9a437..55004fdb60f 100644 --- a/test/new_routes_test.go +++ b/test/new_routes_test.go @@ -44,7 +44,7 @@ func TestNewRouteInfoOnConnect(t *testing.T) { // Make sure we advertise new proto. if info.Proto < server.RouteProtoV2 { - t.Fatalf("Expected routeProtoV2, got %d", info.Proto) + t.Fatalf("Expected routeProtoV2 or above, got %d", info.Proto) } // New proto should always send nonce too. if info.Nonce == "" {
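
A rough usage sketch follows. The account names, subjects, and temp-file plumbing below are illustrative assumptions and not taken from this change; per the parsing added in parseExportStreamOrService and parseImportStreamOrService above, the new allow_trace directive is only accepted on service exports and on stream imports, and the snippet simply runs such a configuration through the public config loader to show where the directive sits.

// Sketch: load a configuration that uses the new allow_trace directive.
// Account and subject names are made up for illustration.
package main

import (
	"fmt"
	"log"
	"os"

	"github.com/nats-io/nats-server/v2/server"
)

// conf uses allow_trace in the two places the parser above accepts it:
// on a service export and on a stream import.
const conf = `
accounts {
	A {
		exports: [
			{ service: "req.help", allow_trace: true }
			{ stream: "updates.>" }
		]
	}
	B {
		imports: [
			{ service: { account: "A", subject: "req.help" } }
			{ stream: { account: "A", subject: "updates.>" }, allow_trace: true }
		]
	}
}
`

func main() {
	// Write the config to a temporary file so it can be processed
	// by the regular config loader.
	f, err := os.CreateTemp("", "allow_trace_*.conf")
	if err != nil {
		log.Fatal(err)
	}
	defer os.Remove(f.Name())
	if _, err := f.WriteString(conf); err != nil {
		log.Fatal(err)
	}
	f.Close()

	opts, err := server.ProcessConfigFile(f.Name())
	if err != nil {
		log.Fatal(err)
	}
	fmt.Printf("parsed %d accounts\n", len(opts.Accounts))
}

Placing allow_trace on a stream export or on a service import is rejected by the parser above with a configuration error ("Detected allow_trace directive on non-service" / "on a non-stream").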