From 37eb179c1041d1379468a596e0982e33cd90241e Mon Sep 17 00:00:00 2001 From: Neil Twigg Date: Tue, 25 Jun 2024 09:22:36 +0100 Subject: [PATCH 1/6] Import missing test helper changes from reverted PR Signed-off-by: Neil Twigg --- server/test_test.go | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/server/test_test.go b/server/test_test.go index 58af5d76f43..cf7c5e9baae 100644 --- a/server/test_test.go +++ b/server/test_test.go @@ -52,14 +52,14 @@ func RunRandClientPortServer(t *testing.T) *Server { return RunServer(&opts) } -func require_True(t *testing.T, b bool) { +func require_True(t testing.TB, b bool) { t.Helper() if !b { t.Fatalf("require true, but got false") } } -func require_False(t *testing.T, b bool) { +func require_False(t testing.TB, b bool) { t.Helper() if b { t.Fatalf("require false, but got true") @@ -89,7 +89,7 @@ func require_Contains(t *testing.T, s string, subStrs ...string) { } } -func require_Error(t *testing.T, err error, expected ...error) { +func require_Error(t testing.TB, err error, expected ...error) { t.Helper() if err == nil { t.Fatalf("require error, but got none") @@ -112,21 +112,21 @@ func require_Error(t *testing.T, err error, expected ...error) { t.Fatalf("Expected one of %v, got '%v'", expected, err) } -func require_Equal[T comparable](t *testing.T, a, b T) { +func require_Equal[T comparable](t testing.TB, a, b T) { t.Helper() if a != b { t.Fatalf("require %T equal, but got: %v != %v", a, a, b) } } -func require_NotEqual[T comparable](t *testing.T, a, b T) { +func require_NotEqual[T comparable](t testing.TB, a, b T) { t.Helper() if a == b { t.Fatalf("require %T not equal, but got: %v == %v", a, a, b) } } -func require_Len(t *testing.T, a, b int) { +func require_Len(t testing.TB, a, b int) { t.Helper() if a != b { t.Fatalf("require len, but got: %v != %v", a, b) From 4d4ba3287ab2c1b987cbac4a296a2f8d212606ca Mon Sep 17 00:00:00 2001 From: Derek Collison Date: Sat, 22 Jun 2024 14:01:30 -0700 Subject: 
[PATCH 2/6] Make sure on a miss from a starting sequence that if no other msgs exists we avoid loading and blocks. Signed-off-by: Derek Collison --- server/filestore.go | 30 ++++++++++----- server/filestore_test.go | 81 ++++++++++++++++++++++++++++++++++++++++ 2 files changed, 101 insertions(+), 10 deletions(-) diff --git a/server/filestore.go b/server/filestore.go index 689f0163e30..8620cd9a026 100644 --- a/server/filestore.go +++ b/server/filestore.go @@ -2596,28 +2596,38 @@ func (fs *fileStore) FilteredState(sseq uint64, subj string) SimpleState { // This is used to see if we can selectively jump start blocks based on filter subject and a floor block index. // Will return -1 if no matches at all. -func (fs *fileStore) checkSkipFirstBlock(filter string, wc bool) int { - start := uint32(math.MaxUint32) +func (fs *fileStore) checkSkipFirstBlock(filter string, wc bool) (int, int) { + start, stop := uint32(math.MaxUint32), uint32(0) if wc { fs.psim.Match(stringToBytes(filter), func(_ []byte, psi *psi) { if psi.fblk < start { start = psi.fblk } + if psi.lblk > stop { + stop = psi.lblk + } }) } else if psi, ok := fs.psim.Find(stringToBytes(filter)); ok { - start = psi.fblk + start, stop = psi.fblk, psi.lblk } // Nothing found. if start == uint32(math.MaxUint32) { - return -1 + return -1, -1 } - // Here we need to translate this to index into fs.blks. + // Here we need to translate this to index into fs.blks properly. mb := fs.bim[start] if mb == nil { - return -1 + return -1, -1 } - bi, _ := fs.selectMsgBlockWithIndex(atomic.LoadUint64(&mb.last.seq)) - return bi + fi, _ := fs.selectMsgBlockWithIndex(atomic.LoadUint64(&mb.last.seq)) + + mb = fs.bim[stop] + if mb == nil { + return -1, -1 + } + li, _ := fs.selectMsgBlockWithIndex(atomic.LoadUint64(&mb.last.seq)) + + return fi, li } // Optimized way for getting all num pending matching a filter subject. 
@@ -6362,9 +6372,9 @@ func (fs *fileStore) LoadNextMsg(filter string, wc bool, start uint64, sm *Store // Similar to above if start <= first seq. // TODO(dlc) - For v2 track these by filter subject since they will represent filtered consumers. if i == bi { - nbi := fs.checkSkipFirstBlock(filter, wc) + nbi, lbi := fs.checkSkipFirstBlock(filter, wc) // Nothing available. - if nbi < 0 { + if nbi < 0 || lbi <= bi { return nil, fs.state.LastSeq, ErrStoreEOF } // See if we can jump ahead here. diff --git a/server/filestore_test.go b/server/filestore_test.go index 84e9e979583..0789abbef7f 100644 --- a/server/filestore_test.go +++ b/server/filestore_test.go @@ -7162,6 +7162,58 @@ func TestFileStoreFilteredPendingPSIMFirstBlockUpdateNextBlock(t *testing.T) { require_Equal(t, psi.lblk, 4) } +func TestFileStoreLargeSparseMsgsDoNotLoadAfterLast(t *testing.T) { + sd := t.TempDir() + fs, err := newFileStore( + FileStoreConfig{StoreDir: sd, BlockSize: 128}, + StreamConfig{Name: "zzz", Subjects: []string{"foo.*.*"}, Storage: FileStorage}) + require_NoError(t, err) + defer fs.Stop() + + msg := []byte("hello") + // Create 2 blocks with each, each block holds 2 msgs + for i := 0; i < 2; i++ { + fs.StoreMsg("foo.22.bar", nil, msg) + fs.StoreMsg("foo.22.baz", nil, msg) + } + // Now create 8 more blocks with just baz. So no matches for these 8 blocks + // for "foo.22.bar". + for i := 0; i < 8; i++ { + fs.StoreMsg("foo.22.baz", nil, msg) + fs.StoreMsg("foo.22.baz", nil, msg) + } + require_Equal(t, fs.numMsgBlocks(), 10) + + // Remove all blk cache and fss. + fs.mu.RLock() + for _, mb := range fs.blks { + mb.mu.Lock() + mb.fss, mb.cache = nil, nil + mb.mu.Unlock() + } + fs.mu.RUnlock() + + // "foo.22.bar" is at sequence 1 and 3. + // Make sure if we do a LoadNextMsg() starting at 4 that we do not load + // all the tail blocks. + _, _, err = fs.LoadNextMsg("foo.*.bar", true, 4, nil) + require_Error(t, err, ErrStoreEOF) + + // Now make sure we did not load fss and cache. 
+ var loaded int + fs.mu.RLock() + for _, mb := range fs.blks { + mb.mu.RLock() + if mb.cache != nil || mb.fss != nil { + loaded++ + } + mb.mu.RUnlock() + } + fs.mu.RUnlock() + // We will load first block for starting seq 4, but no others should have loaded. + require_Equal(t, loaded, 1) +} + /////////////////////////////////////////////////////////////////////////// // Benchmarks /////////////////////////////////////////////////////////////////////////// @@ -7419,3 +7471,32 @@ func Benchmark_FileStoreLoadNextMsgVerySparseMsgsInBetweenWithWildcard(b *testin require_NoError(b, err) } } + +func Benchmark_FileStoreLoadNextMsgVerySparseMsgsLargeTail(b *testing.B) { + fs, err := newFileStore( + FileStoreConfig{StoreDir: b.TempDir()}, + StreamConfig{Name: "zzz", Subjects: []string{"foo.*.*"}, Storage: FileStorage}) + require_NoError(b, err) + defer fs.Stop() + + // Small on purpose. + msg := []byte("ok") + + // Make first msg one that would match as well. + fs.StoreMsg("foo.1.baz", nil, msg) + // Add in a bunch of msgs. + // We need to make sure we have a range of subjects that could kick in a linear scan. + for i := 0; i < 1_000_000; i++ { + subj := fmt.Sprintf("foo.%d.bar", rand.Intn(100_000)+2) + fs.StoreMsg(subj, nil, msg) + } + + b.ResetTimer() + + var smv StoreMsg + for i := 0; i < b.N; i++ { + // Make sure not first seq. + _, _, err := fs.LoadNextMsg("foo.*.baz", true, 2, &smv) + require_Error(b, err, ErrStoreEOF) + } +} From d86cb30f2ca21f21c64e4e46315ec40778f4dced Mon Sep 17 00:00:00 2001 From: Neil Twigg Date: Mon, 24 Jun 2024 12:33:00 +0100 Subject: [PATCH 3/6] Add `node48` to stree A `node256` is nearly 4KB in memory whereas a `node48` is closer to 1KB.
Signed-off-by: Neil Twigg --- server/stree/dump.go | 1 + server/stree/node16.go | 2 +- server/stree/node256.go | 4 +- server/stree/node48.go | 110 +++++++++++++++++++++++++++++++++++++ server/stree/stree_test.go | 110 +++++++++++++++++++++++++++++++++++-- 5 files changed, 220 insertions(+), 7 deletions(-) create mode 100644 server/stree/node48.go diff --git a/server/stree/dump.go b/server/stree/dump.go index 4a7d76fb586..60f03e4aad1 100644 --- a/server/stree/dump.go +++ b/server/stree/dump.go @@ -51,6 +51,7 @@ func (t *SubjectTree[T]) dump(w io.Writer, n node, depth int) { func (n *leaf[T]) kind() string { return "LEAF" } func (n *node4) kind() string { return "NODE4" } func (n *node16) kind() string { return "NODE16" } +func (n *node48) kind() string { return "NODE48" } func (n *node256) kind() string { return "NODE256" } // Calculates the indendation, etc. diff --git a/server/stree/node16.go b/server/stree/node16.go index 7da5df89d99..c0c12aafd57 100644 --- a/server/stree/node16.go +++ b/server/stree/node16.go @@ -50,7 +50,7 @@ func (n *node16) findChild(c byte) *node { func (n *node16) isFull() bool { return n.size >= 16 } func (n *node16) grow() node { - nn := newNode256(n.prefix) + nn := newNode48(n.prefix) for i := 0; i < 16; i++ { nn.addChild(n.key[i], n.child[i]) } diff --git a/server/stree/node256.go b/server/stree/node256.go index f5bf69bc93c..5d08b1487ab 100644 --- a/server/stree/node256.go +++ b/server/stree/node256.go @@ -51,10 +51,10 @@ func (n *node256) deleteChild(c byte) { // Shrink if needed and return new node, otherwise return nil. 
func (n *node256) shrink() node { - if n.size > 16 { + if n.size > 48 { return nil } - nn := newNode16(nil) + nn := newNode48(nil) for c, child := range n.child { if child != nil { nn.addChild(byte(c), n.child[c]) diff --git a/server/stree/node48.go b/server/stree/node48.go new file mode 100644 index 00000000000..fe7ef543529 --- /dev/null +++ b/server/stree/node48.go @@ -0,0 +1,110 @@ +// Copyright 2023-2024 The NATS Authors +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package stree + +// Node with 48 children +// Memory saving vs node256 comes from the fact that the child array is 16 bytes +// per `node` entry, so node256's 256*16=4096 vs node48's 256+(48*16)=1024 +// Note that `key` is effectively 1-indexed, as 0 means no entry, so offset by 1 +// Order of struct fields for best memory alignment (as per govet/fieldalignment) +type node48 struct { + child [48]node + meta + key [256]byte +} + +func newNode48(prefix []byte) *node48 { + nn := &node48{} + nn.setPrefix(prefix) + return nn +} + +func (n *node48) addChild(c byte, nn node) { + if n.size >= 48 { + panic("node48 full!") + } + n.child[n.size] = nn + n.key[c] = byte(n.size + 1) // 1-indexed + n.size++ +} + +func (n *node48) findChild(c byte) *node { + i := n.key[c] + if i == 0 { + return nil + } + return &n.child[i-1] +} + +func (n *node48) isFull() bool { return n.size >= 48 } + +func (n *node48) grow() node { + nn := newNode256(n.prefix) + for c := 0; c < 256; c++ { + if i := n.key[c]; i > 0
{ + nn.addChild(byte(c), n.child[i-1]) + } + } + return nn +} + +// Deletes a child from the node. +func (n *node48) deleteChild(c byte) { + i := n.key[c] + if i == 0 { + return + } + i-- // Adjust for 1-indexing + last := byte(n.size - 1) + if i < last { + n.child[i] = n.child[last] + for c := 0; c < 256; c++ { + if n.key[c] == last+1 { + n.key[c] = i + 1 + break + } + } + } + n.child[last] = nil + n.key[c] = 0 + n.size-- +} + +// Shrink if needed and return new node, otherwise return nil. +func (n *node48) shrink() node { + if n.size > 16 { + return nil + } + nn := newNode16(nil) + for c := 0; c < 256; c++ { + if i := n.key[c]; i > 0 { + nn.addChild(byte(c), n.child[i-1]) + } + } + return nn +} + +// Iterate over all children calling func f. +func (n *node48) iter(f func(node) bool) { + for _, c := range n.child { + if c != nil && !f(c) { + return + } + } +} + +// Return our children as a slice. +func (n *node48) children() []node { + return n.child[:n.size] +} diff --git a/server/stree/stree_test.go b/server/stree/stree_test.go index 7421bcf6314..e6435b08c87 100644 --- a/server/stree/stree_test.go +++ b/server/stree/stree_test.go @@ -78,7 +78,6 @@ func TestSubjectTreeNodeGrow(t *testing.T) { require_False(t, updated) _, ok = st.root.(*node16) require_True(t, ok) - // We do not have node48, so once we fill this we should jump to node256. for i := 5; i < 16; i++ { subj := b(fmt.Sprintf("foo.bar.%c", 'A'+i)) old, updated := st.Insert(subj, 22) @@ -89,6 +88,20 @@ func TestSubjectTreeNodeGrow(t *testing.T) { old, updated = st.Insert(b("foo.bar.Q"), 22) require_True(t, old == nil) require_False(t, updated) + _, ok = st.root.(*node48) + require_True(t, ok) + // Fill the node48. + for i := 17; i < 48; i++ { + subj := b(fmt.Sprintf("foo.bar.%c", 'A'+i)) + old, updated := st.Insert(subj, 22) + require_True(t, old == nil) + require_False(t, updated) + } + // This one will trigger us to grow.
+ subj := b(fmt.Sprintf("foo.bar.%c", 'A'+49)) + old, updated = st.Insert(subj, 22) + require_True(t, old == nil) + require_False(t, updated) _, ok = st.root.(*node256) require_True(t, ok) } @@ -160,13 +173,13 @@ func TestSubjectTreeNodeDelete(t *testing.T) { require_Equal(t, *v, 22) _, ok = st.root.(*node4) require_True(t, ok) - // Now pop up to node256 + // Now pop up to node48 st = NewSubjectTree[int]() for i := 0; i < 17; i++ { subj := fmt.Sprintf("foo.bar.%c", 'A'+i) st.Insert(b(subj), 22) } - _, ok = st.root.(*node256) + _, ok = st.root.(*node48) require_True(t, ok) v, found = st.Delete(b("foo.bar.A")) require_True(t, found) @@ -176,6 +189,22 @@ func TestSubjectTreeNodeDelete(t *testing.T) { v, found = st.Find(b("foo.bar.B")) require_True(t, found) require_Equal(t, *v, 22) + // Now pop up to node256 + st = NewSubjectTree[int]() + for i := 0; i < 49; i++ { + subj := fmt.Sprintf("foo.bar.%c", 'A'+i) + st.Insert(b(subj), 22) + } + _, ok = st.root.(*node256) + require_True(t, ok) + v, found = st.Delete(b("foo.bar.A")) + require_True(t, found) + require_Equal(t, *v, 22) + _, ok = st.root.(*node48) + require_True(t, ok) + v, found = st.Find(b("foo.bar.B")) + require_True(t, found) + require_Equal(t, *v, 22) } func TestSubjectTreeNodesAndPaths(t *testing.T) { @@ -341,7 +370,7 @@ func TestSubjectTreeNoPrefix(t *testing.T) { require_True(t, old == nil) require_False(t, updated) } - n, ok := st.root.(*node256) + n, ok := st.root.(*node48) require_True(t, ok) require_Equal(t, n.numChildren(), 26) v, found := st.Delete(b("B")) @@ -636,3 +665,76 @@ func TestSubjectTreeMatchAllPerf(t *testing.T) { t.Logf("Match %q took %s and matched %d entries", f, time.Since(start), count) } } + +func TestSubjectTreeNode48(t *testing.T) { + var a, b, c leaf[int] + var n node48 + + n.addChild('A', &a) + require_Equal(t, n.key['A'], 1) + require_True(t, n.child[0] != nil) + require_Equal(t, n.child[0].(*leaf[int]), &a) + require_Equal(t, len(n.children()), 1) + + child := n.findChild('A') 
+ require_True(t, child != nil) + require_Equal(t, (*child).(*leaf[int]), &a) + + n.addChild('B', &b) + require_Equal(t, n.key['B'], 2) + require_True(t, n.child[1] != nil) + require_Equal(t, n.child[1].(*leaf[int]), &b) + require_Equal(t, len(n.children()), 2) + + child = n.findChild('B') + require_True(t, child != nil) + require_Equal(t, (*child).(*leaf[int]), &b) + + n.addChild('C', &c) + require_Equal(t, n.key['C'], 3) + require_True(t, n.child[2] != nil) + require_Equal(t, n.child[2].(*leaf[int]), &c) + require_Equal(t, len(n.children()), 3) + + child = n.findChild('C') + require_True(t, child != nil) + require_Equal(t, (*child).(*leaf[int]), &c) + + n.deleteChild('A') + require_Equal(t, len(n.children()), 2) + require_Equal(t, n.key['A'], 0) // Now deleted + require_Equal(t, n.key['B'], 2) // Untouched + require_Equal(t, n.key['C'], 1) // Where A was + + child = n.findChild('A') + require_Equal(t, child, nil) + require_True(t, n.child[0] != nil) + require_Equal(t, n.child[0].(*leaf[int]), &c) + + child = n.findChild('B') + require_True(t, child != nil) + require_Equal(t, (*child).(*leaf[int]), &b) + require_True(t, n.child[1] != nil) + require_Equal(t, n.child[1].(*leaf[int]), &b) + + child = n.findChild('C') + require_True(t, child != nil) + require_Equal(t, (*child).(*leaf[int]), &c) + require_True(t, n.child[2] == nil) + + var gotB, gotC bool + var iterations int + n.iter(func(n node) bool { + iterations++ + if gb, ok := n.(*leaf[int]); ok && &b == gb { + gotB = true + } + if gc, ok := n.(*leaf[int]); ok && &c == gc { + gotC = true + } + return true + }) + require_Equal(t, iterations, 2) + require_True(t, gotB) + require_True(t, gotC) +} From ad8c97dd1cd11ba2d1d615fa0bbb7dedad81a44e Mon Sep 17 00:00:00 2001 From: Neil Twigg Date: Mon, 24 Jun 2024 17:53:30 +0100 Subject: [PATCH 4/6] NRG: Fix leaving observer state when applies are paused Signed-off-by: Neil Twigg --- server/raft.go | 8 ++++++++ server/raft_test.go | 39 
+++++++++++++++++++++++++++++++++++++++ 2 files changed, 47 insertions(+) diff --git a/server/raft.go b/server/raft.go index 2e72f71d431..1e1bc680976 100644 --- a/server/raft.go +++ b/server/raft.go @@ -1886,6 +1886,14 @@ func (n *raft) setObserver(isObserver bool, extSt extensionState) { n.Lock() defer n.Unlock() + if n.paused { + // Applies are paused so we're already in observer state. + // Resuming the applies will set the state back to whatever + // is in "pobserver", so update that instead. + n.pobserver = isObserver + return + } + wasObserver := n.observer n.observer = isObserver n.extSt = extSt diff --git a/server/raft_test.go b/server/raft_test.go index 91226407583..beb15d63480 100644 --- a/server/raft_test.go +++ b/server/raft_test.go @@ -549,3 +549,42 @@ func TestNRGSystemClientCleanupFromAccount(t *testing.T) { finish := numClients() require_Equal(t, start, finish) } + +func TestNRGLeavesObserverAfterPause(t *testing.T) { + c := createJetStreamClusterExplicit(t, "R3S", 3) + defer c.shutdown() + + rg := c.createMemRaftGroup("TEST", 3, newStateAdder) + rg.waitOnLeader() + + n := rg.nonLeader().node().(*raft) + + checkState := func(observer, pobserver bool) { + t.Helper() + n.RLock() + defer n.RUnlock() + require_Equal(t, n.observer, observer) + require_Equal(t, n.pobserver, pobserver) + } + + // Assume this has happened because of jetstream_cluster_migrate + // or similar. + n.SetObserver(true) + checkState(true, false) + + // Now something like a catchup has started, but since we were + // already in observer mode, pobserver is set to true. + n.PauseApply() + checkState(true, true) + + // Now jetstream_cluster_migrate is happy that the leafnodes are + // back up so it tries to leave observer mode, but the catchup + // hasn't finished yet. This will instead set pobserver to false. + n.SetObserver(false) + checkState(true, false) + + // The catchup finishes, so we should correctly leave the observer + // state by setting observer to the pobserver value. 
+ n.ResumeApply() + checkState(false, false) +} From b68ce834a4549a24de7b93729a7f13c86cda3713 Mon Sep 17 00:00:00 2001 From: Derek Collison Date: Mon, 24 Jun 2024 10:32:21 -0700 Subject: [PATCH 5/6] Allow kick to work on leafnodes as well. Signed-off-by: Derek Collison --- server/leafnode_test.go | 44 +++++++++++++++++++++++++++++++++++++++++ server/server.go | 5 ++++- 2 files changed, 48 insertions(+), 1 deletion(-) diff --git a/server/leafnode_test.go b/server/leafnode_test.go index 59563c2a7f1..2d8f2d83ecc 100644 --- a/server/leafnode_test.go +++ b/server/leafnode_test.go @@ -7715,3 +7715,47 @@ func TestLeafNodeDupeDeliveryQueueSubAndPlainSub(t *testing.T) { require_NoError(t, err) require_Equal(t, n, 1) } + +func TestLeafNodeServerKickClient(t *testing.T) { + stmpl := ` + listen: 127.0.0.1:-1 + server_name: test-server + leaf { listen: 127.0.0.1:-1 } + ` + conf := createConfFile(t, []byte(stmpl)) + s, o := RunServerWithConfig(conf) + defer s.Shutdown() + + tmpl := ` + listen: 127.0.0.1:-1 + server_name: test-leaf + leaf { remotes: [ { urls: [ nats-leaf://127.0.0.1:{LEAF_PORT} ] } ] } + ` + tmpl = strings.Replace(tmpl, "{LEAF_PORT}", fmt.Sprintf("%d", o.LeafNode.Port), 1) + lConf := createConfFile(t, []byte(tmpl)) + l, _ := RunServerWithConfig(lConf) + defer l.Shutdown() + + checkLeafNodeConnected(t, l) + + // We want to make sure we can kick the leafnode connections as well as client connections. + conns, err := s.Connz(&ConnzOptions{Account: globalAccountName}) + require_NoError(t, err) + require_Equal(t, len(conns.Conns), 1) + lid := conns.Conns[0].Cid + + disconnectTime := time.Now() + err = s.DisconnectClientByID(lid) + require_NoError(t, err) + + // Wait until we are reconnected. + checkLeafNodeConnected(t, s) + + // Look back up again and make sure start time indicates a restart, meaning kick worked. 
+ conns, err = s.Connz(&ConnzOptions{Account: globalAccountName}) + require_NoError(t, err) + require_Equal(t, len(conns.Conns), 1) + ln := conns.Conns[0] + require_True(t, lid != ln.Cid) + require_True(t, ln.Start.After(disconnectTime)) +} diff --git a/server/server.go b/server/server.go index 5f2c3e76dbb..62a09fad25d 100644 --- a/server/server.go +++ b/server/server.go @@ -4422,8 +4422,11 @@ func (s *Server) DisconnectClientByID(id uint64) error { if client := s.getClient(id); client != nil { client.closeConnection(Kicked) return nil + } else if client = s.GetLeafNode(id); client != nil { + client.closeConnection(Kicked) + return nil } - return errors.New("no such client id") + return errors.New("no such client or leafnode id") } // LDMClientByID sends a Lame Duck Mode info message to a client by connection ID From e019d7abcc28692b2e4f6dc6b825da4ed290f412 Mon Sep 17 00:00:00 2001 From: Waldemar Quevedo Date: Mon, 24 Jun 2024 22:05:45 -0700 Subject: [PATCH 6/6] Fix imports not being available for a client sometimes after a server restart (#5589) When a client would reconnect to a server that was still setting up the imports and exports a client that reconnected too soon to a server that had just been restarted might be missing some of the imports that were defined in its JWT. --------- Signed-off-by: Waldemar Quevedo Signed-off-by: Derek Collison Co-authored-by: Derek Collison --- server/accounts.go | 66 ++++++++------- server/client.go | 5 +- server/jwt_test.go | 195 ++++++++++++++++++++++++++++++++++++++++++++- server/server.go | 4 +- 4 files changed, 235 insertions(+), 35 deletions(-) diff --git a/server/accounts.go b/server/accounts.go index 2c90421b3d0..4b24903a0d1 100644 --- a/server/accounts.go +++ b/server/accounts.go @@ -2822,9 +2822,12 @@ func (a *Account) isIssuerClaimTrusted(claims *jwt.ActivationClaims) bool { // check is done with the account's name, not the pointer. 
This is used // during config reload where we are comparing current and new config // in which pointers are different. -// No lock is acquired in this function, so it is assumed that the -// import maps are not changed while this executes. +// Acquires `a` read lock, but `b` is assumed to not be accessed +// by anyone but the caller (`b` is not registered anywhere). func (a *Account) checkStreamImportsEqual(b *Account) bool { + a.mu.RLock() + defer a.mu.RUnlock() + if len(a.imports.streams) != len(b.imports.streams) { return false } @@ -3192,6 +3195,9 @@ func (s *Server) updateAccountClaimsWithRefresh(a *Account, ac *jwt.AccountClaim a.nameTag = ac.Name a.tags = ac.Tags + // Grab trace label under lock. + tl := a.traceLabel() + // Check for external authorization. if ac.HasExternalAuthorization() { a.extAuth = &jwt.ExternalAuthorization{} @@ -3212,10 +3218,10 @@ func (s *Server) updateAccountClaimsWithRefresh(a *Account, ac *jwt.AccountClaim } if a.imports.services != nil { old.imports.services = make(map[string]*serviceImport, len(a.imports.services)) - } - for k, v := range a.imports.services { - old.imports.services[k] = v - delete(a.imports.services, k) + for k, v := range a.imports.services { + old.imports.services[k] = v + delete(a.imports.services, k) + } } alteredScope := map[string]struct{}{} @@ -3285,13 +3291,13 @@ func (s *Server) updateAccountClaimsWithRefresh(a *Account, ac *jwt.AccountClaim for _, e := range ac.Exports { switch e.Type { case jwt.Stream: - s.Debugf("Adding stream export %q for %s", e.Subject, a.traceLabel()) + s.Debugf("Adding stream export %q for %s", e.Subject, tl) if err := a.addStreamExportWithAccountPos( string(e.Subject), authAccounts(e.TokenReq), e.AccountTokenPosition); err != nil { - s.Debugf("Error adding stream export to account [%s]: %v", a.traceLabel(), err.Error()) + s.Debugf("Error adding stream export to account [%s]: %v", tl, err.Error()) } case jwt.Service: - s.Debugf("Adding service export %q for %s", e.Subject, 
a.traceLabel()) + s.Debugf("Adding service export %q for %s", e.Subject, tl) rt := Singleton switch e.ResponseType { case jwt.ResponseTypeStream: @@ -3301,7 +3307,7 @@ func (s *Server) updateAccountClaimsWithRefresh(a *Account, ac *jwt.AccountClaim } if err := a.addServiceExportWithResponseAndAccountPos( string(e.Subject), rt, authAccounts(e.TokenReq), e.AccountTokenPosition); err != nil { - s.Debugf("Error adding service export to account [%s]: %v", a.traceLabel(), err) + s.Debugf("Error adding service export to account [%s]: %v", tl, err) continue } sub := string(e.Subject) @@ -3311,13 +3317,13 @@ func (s *Server) updateAccountClaimsWithRefresh(a *Account, ac *jwt.AccountClaim if e.Latency.Sampling == jwt.Headers { hdrNote = " (using headers)" } - s.Debugf("Error adding latency tracking%s for service export to account [%s]: %v", hdrNote, a.traceLabel(), err) + s.Debugf("Error adding latency tracking%s for service export to account [%s]: %v", hdrNote, tl, err) } } if e.ResponseThreshold != 0 { // Response threshold was set in options. 
if err := a.SetServiceExportResponseThreshold(sub, e.ResponseThreshold); err != nil { - s.Debugf("Error adding service export response threshold for [%s]: %v", a.traceLabel(), err) + s.Debugf("Error adding service export response threshold for [%s]: %v", tl, err) } } } @@ -3362,34 +3368,31 @@ func (s *Server) updateAccountClaimsWithRefresh(a *Account, ac *jwt.AccountClaim } var incompleteImports []*jwt.Import for _, i := range ac.Imports { - // check tmpAccounts with priority - var acc *Account - var err error - if v, ok := s.tmpAccounts.Load(i.Account); ok { - acc = v.(*Account) - } else { - acc, err = s.lookupAccount(i.Account) - } + acc, err := s.lookupAccount(i.Account) if acc == nil || err != nil { s.Errorf("Can't locate account [%s] for import of [%v] %s (err=%v)", i.Account, i.Subject, i.Type, err) incompleteImports = append(incompleteImports, i) continue } - from := string(i.Subject) - to := i.GetTo() + // Capture trace labels. + acc.mu.RLock() + atl := acc.traceLabel() + acc.mu.RUnlock() + // Grab from and to + from, to := string(i.Subject), i.GetTo() switch i.Type { case jwt.Stream: if i.LocalSubject != _EMPTY_ { // set local subject implies to is empty to = string(i.LocalSubject) - s.Debugf("Adding stream import %s:%q for %s:%q", acc.traceLabel(), from, a.traceLabel(), to) + s.Debugf("Adding stream import %s:%q for %s:%q", atl, from, tl, to) err = a.AddMappedStreamImportWithClaim(acc, from, to, i) } else { - s.Debugf("Adding stream import %s:%q for %s:%q", acc.traceLabel(), from, a.traceLabel(), to) + s.Debugf("Adding stream import %s:%q for %s:%q", atl, from, tl, to) err = a.AddStreamImportWithClaim(acc, from, to, i) } if err != nil { - s.Debugf("Error adding stream import to account [%s]: %v", a.traceLabel(), err.Error()) + s.Debugf("Error adding stream import to account [%s]: %v", tl, err.Error()) incompleteImports = append(incompleteImports, i) } case jwt.Service: @@ -3397,9 +3400,9 @@ func (s *Server) updateAccountClaimsWithRefresh(a *Account, ac 
*jwt.AccountClaim from = string(i.LocalSubject) to = string(i.Subject) } - s.Debugf("Adding service import %s:%q for %s:%q", acc.traceLabel(), from, a.traceLabel(), to) + s.Debugf("Adding service import %s:%q for %s:%q", atl, from, tl, to) if err := a.AddServiceImportWithClaim(acc, from, to, i); err != nil { - s.Debugf("Error adding service import to account [%s]: %v", a.traceLabel(), err.Error()) + s.Debugf("Error adding service import to account [%s]: %v", tl, err.Error()) incompleteImports = append(incompleteImports, i) } } @@ -3570,7 +3573,7 @@ func (s *Server) updateAccountClaimsWithRefresh(a *Account, ac *jwt.AccountClaim // regardless of enabled or disabled. It handles both cases. if jsEnabled { if err := s.configJetStream(a); err != nil { - s.Errorf("Error configuring jetstream for account [%s]: %v", a.traceLabel(), err.Error()) + s.Errorf("Error configuring jetstream for account [%s]: %v", tl, err.Error()) a.mu.Lock() // Absent reload of js server cfg, this is going to be broken until js is disabled a.incomplete = true @@ -3707,8 +3710,13 @@ func (s *Server) buildInternalAccount(ac *jwt.AccountClaims) *Account { // We don't want to register an account that is in the process of // being built, however, to solve circular import dependencies, we // need to store it here. - s.tmpAccounts.Store(ac.Subject, acc) + if v, loaded := s.tmpAccounts.LoadOrStore(ac.Subject, acc); loaded { + return v.(*Account) + } + + // Update based on claims. s.UpdateAccountClaims(acc, ac) + return acc } diff --git a/server/client.go b/server/client.go index 3dd0ce6dc66..619db25aad0 100644 --- a/server/client.go +++ b/server/client.go @@ -2914,8 +2914,11 @@ func (c *client) addShadowSubscriptions(acc *Account, sub *subscription, enact b // Add in the shadow subscription. 
func (c *client) addShadowSub(sub *subscription, ime *ime, enact bool) (*subscription, error) { - im := ime.im + c.mu.Lock() nsub := *sub // copy + c.mu.Unlock() + + im := ime.im nsub.im = im if !im.usePub && ime.dyn && im.tr != nil { diff --git a/server/jwt_test.go b/server/jwt_test.go index ab7de041769..d6cdd1dd7d9 100644 --- a/server/jwt_test.go +++ b/server/jwt_test.go @@ -15,6 +15,7 @@ package server import ( "bufio" + "context" "encoding/base64" "encoding/json" "errors" @@ -1991,9 +1992,9 @@ func TestJWTAccountURLResolverPermanentFetchFailure(t *testing.T) { importErrCnt++ } case <-tmr.C: - // connecting and updating, each cause 3 traces (2 + 1 on iteration) - if importErrCnt != 6 { - t.Fatalf("Expected 6 debug traces, got %d", importErrCnt) + // connecting and updating, each causes 3 traces (2 + 1 on iteration) + 1 extra fetch + if importErrCnt != 7 { + t.Fatalf("Expected 7 debug traces, got %d", importErrCnt) } return } @@ -6842,3 +6843,191 @@ func TestJWTAccountNATSResolverWrongCreds(t *testing.T) { t.Fatalf("Expected auth error: %v", err) } } + +// Issue 5480: https://github.com/nats-io/nats-server/issues/5480 +func TestJWTImportsOnServerRestartAndClientsReconnect(t *testing.T) { + type namedCreds struct { + name string + creds nats.Option + } + preload := make(map[string]string) + users := make(map[string]*namedCreds) + + // sys account + _, sysAcc, sysAccClaim := NewJwtAccountClaim("sys") + sysAccJWT, err := sysAccClaim.Encode(oKp) + require_NoError(t, err) + preload[sysAcc] = sysAccJWT + + // main account, other accounts will import from this. 
+ mainAccKP, mainAcc, mainAccClaim := NewJwtAccountClaim("main") + mainAccClaim.Exports.Add(&jwt.Export{ + Type: jwt.Stream, + Subject: "city.>", + }) + + // main account user + mainUserClaim := jwt.NewUserClaims("publisher") + mainUserClaim.Permissions = jwt.Permissions{ + Pub: jwt.Permission{ + Allow: []string{"city.>"}, + }, + } + mainCreds := createUserCredsEx(t, mainUserClaim, mainAccKP) + + // The main account will be importing from all other accounts. + maxAccounts := 100 + for i := 0; i < maxAccounts; i++ { + name := fmt.Sprintf("secondary-%d", i) + accKP, acc, accClaim := NewJwtAccountClaim(name) + + accClaim.Exports.Add(&jwt.Export{ + Type: jwt.Stream, + Subject: "internal.*", + }) + accClaim.Imports.Add(&jwt.Import{ + Type: jwt.Stream, + Subject: jwt.Subject(fmt.Sprintf("city.%d-1.*", i)), + Account: mainAcc, + }) + + // main account imports from the secondary accounts + mainAccClaim.Imports.Add(&jwt.Import{ + Type: jwt.Stream, + Subject: jwt.Subject(fmt.Sprintf("internal.%d", i)), + Account: acc, + }) + + accJWT, err := accClaim.Encode(oKp) + require_NoError(t, err) + preload[acc] = accJWT + + userClaim := jwt.NewUserClaims("subscriber") + userClaim.Permissions = jwt.Permissions{ + Sub: jwt.Permission{ + Allow: []string{"city.>", "internal.*"}, + }, + Pub: jwt.Permission{ + Allow: []string{"internal.*"}, + }, + } + userCreds := createUserCredsEx(t, userClaim, accKP) + users[acc] = &namedCreds{name, userCreds} + } + mainAccJWT, err := mainAccClaim.Encode(oKp) + require_NoError(t, err) + preload[mainAcc] = mainAccJWT + + // Start the server with the preload. 
+ resolverPreload, err := json.Marshal(preload) + require_NoError(t, err) + conf := createConfFile(t, []byte(fmt.Sprintf(` + listen: 127.0.0.1:4747 + http: 127.0.0.1:8222 + operator: %s + system_account: %s + resolver: MEM + resolver_preload: %s + `, ojwt, sysAcc, string(resolverPreload)))) + s, _ := RunServerWithConfig(conf) + defer s.Shutdown() + + // Have a connection ready for each one of the accounts. + type namedSub struct { + name string + sub *nats.Subscription + } + subs := make(map[string]*namedSub) + for acc, user := range users { + nc := natsConnect(t, s.ClientURL(), user.creds, + // Make the clients attempt to reconnect too fast, + // changing this to be above ~200ms mitigates the issue. + nats.ReconnectWait(15*time.Millisecond), + nats.Name(user.name), + nats.MaxReconnects(-1), + ) + defer nc.Close() + + sub, err := nc.SubscribeSync("city.>") + require_NoError(t, err) + subs[acc] = &namedSub{user.name, sub} + } + + nc := natsConnect(t, s.ClientURL(), mainCreds, nats.ReconnectWait(15*time.Millisecond), nats.MaxReconnects(-1)) + defer nc.Close() + + send := func(t *testing.T) { + t.Helper() + for i := 0; i < maxAccounts; i++ { + nc.Publish(fmt.Sprintf("city.%d-1.A4BDB048-69DC-4F10-916C-2B998249DC11", i), []byte(fmt.Sprintf("test:%d", i))) + } + nc.Flush() + } + + ctx, done := context.WithCancel(context.Background()) + defer done() + go func() { + for range time.NewTicker(200 * time.Millisecond).C { + select { + case <-ctx.Done(): + default: + } + send(t) + } + }() + + receive := func(t *testing.T) { + t.Helper() + received := 0 + for _, nsub := range subs { + // Drain first any pending messages. 
+ pendingMsgs, _, _ := nsub.sub.Pending() + for i, _ := 0, 0; i < pendingMsgs; i++ { + nsub.sub.NextMsg(500 * time.Millisecond) + } + + _, err = nsub.sub.NextMsg(500 * time.Millisecond) + if err != nil { + t.Logf("WRN: Failed to receive message on account %q: %v", nsub.name, err) + } else { + received++ + } + } + if received < (maxAccounts / 2) { + t.Fatalf("Too many missed messages after restart. Received %d", received) + } + } + receive(t) + time.Sleep(1 * time.Second) + + restart := func(t *testing.T) *Server { + t.Helper() + s.Shutdown() + s.WaitForShutdown() + s, _ = RunServerWithConfig(conf) + + hctx, hcancel := context.WithTimeout(context.Background(), 5*time.Second) + defer hcancel() + for range time.NewTicker(2 * time.Second).C { + select { + case <-hctx.Done(): + t.Logf("WRN: Timed out waiting for healthz from %s", s) + default: + } + + status := s.healthz(nil) + if status.StatusCode == 200 { + return s + } + } + return nil + } + + // Takes a few restarts for issue to show up. + for i := 0; i < 5; i++ { + s := restart(t) + defer s.Shutdown() + time.Sleep(2 * time.Second) + receive(t) + } +} diff --git a/server/server.go b/server/server.go index 62a09fad25d..8a276958986 100644 --- a/server/server.go +++ b/server/server.go @@ -1097,11 +1097,11 @@ func (s *Server) configureAccounts(reloading bool) (map[string]struct{}, error) if reloading && acc.Name != globalAccountName { if ai, ok := s.accounts.Load(acc.Name); ok { a = ai.(*Account) - a.mu.Lock() // Before updating the account, check if stream imports have changed. if !a.checkStreamImportsEqual(acc) { awcsti[acc.Name] = struct{}{} } + a.mu.Lock() // Collect the sids for the service imports since we are going to // replace with new ones. 
var sids [][]byte @@ -2064,7 +2064,6 @@ func (s *Server) fetchAccount(name string) (*Account, error) { return nil, err } acc := s.buildInternalAccount(accClaims) - acc.claimJWT = claimJWT // Due to possible race, if registerAccount() returns a non // nil account, it means the same account was already // registered and we should use this one. @@ -2080,6 +2079,7 @@ func (s *Server) fetchAccount(name string) (*Account, error) { var needImportSubs bool acc.mu.Lock() + acc.claimJWT = claimJWT if len(acc.imports.services) > 0 { if acc.ic == nil { acc.ic = s.createInternalAccountClient()