From e019d7abcc28692b2e4f6dc6b825da4ed290f412 Mon Sep 17 00:00:00 2001 From: Waldemar Quevedo Date: Mon, 24 Jun 2024 22:05:45 -0700 Subject: [PATCH] Fix imports not being available for a client sometimes after a server restart (#5589) A client that reconnected too soon to a server that had just been restarted, and was therefore still setting up its imports and exports, might be missing some of the imports that were defined in its JWT. --------- Signed-off-by: Waldemar Quevedo Signed-off-by: Derek Collison Co-authored-by: Derek Collison --- server/accounts.go | 66 ++++++++------- server/client.go | 5 +- server/jwt_test.go | 195 ++++++++++++++++++++++++++++++++++++++++++++- server/server.go | 4 +- 4 files changed, 235 insertions(+), 35 deletions(-) diff --git a/server/accounts.go b/server/accounts.go index 2c90421b3d0..4b24903a0d1 100644 --- a/server/accounts.go +++ b/server/accounts.go @@ -2822,9 +2822,12 @@ func (a *Account) isIssuerClaimTrusted(claims *jwt.ActivationClaims) bool { // check is done with the account's name, not the pointer. This is used // during config reload where we are comparing current and new config // in which pointers are different. -// No lock is acquired in this function, so it is assumed that the -// import maps are not changed while this executes. +// Acquires `a` read lock, but `b` is assumed to not be accessed +// by anyone but the caller (`b` is not registered anywhere). func (a *Account) checkStreamImportsEqual(b *Account) bool { + a.mu.RLock() + defer a.mu.RUnlock() + if len(a.imports.streams) != len(b.imports.streams) { return false } @@ -3192,6 +3195,9 @@ func (s *Server) updateAccountClaimsWithRefresh(a *Account, ac *jwt.AccountClaim a.nameTag = ac.Name a.tags = ac.Tags + // Grab trace label under lock. + tl := a.traceLabel() + // Check for external authorization. 
if ac.HasExternalAuthorization() { a.extAuth = &jwt.ExternalAuthorization{} @@ -3212,10 +3218,10 @@ func (s *Server) updateAccountClaimsWithRefresh(a *Account, ac *jwt.AccountClaim } if a.imports.services != nil { old.imports.services = make(map[string]*serviceImport, len(a.imports.services)) - } - for k, v := range a.imports.services { - old.imports.services[k] = v - delete(a.imports.services, k) + for k, v := range a.imports.services { + old.imports.services[k] = v + delete(a.imports.services, k) + } } alteredScope := map[string]struct{}{} @@ -3285,13 +3291,13 @@ func (s *Server) updateAccountClaimsWithRefresh(a *Account, ac *jwt.AccountClaim for _, e := range ac.Exports { switch e.Type { case jwt.Stream: - s.Debugf("Adding stream export %q for %s", e.Subject, a.traceLabel()) + s.Debugf("Adding stream export %q for %s", e.Subject, tl) if err := a.addStreamExportWithAccountPos( string(e.Subject), authAccounts(e.TokenReq), e.AccountTokenPosition); err != nil { - s.Debugf("Error adding stream export to account [%s]: %v", a.traceLabel(), err.Error()) + s.Debugf("Error adding stream export to account [%s]: %v", tl, err.Error()) } case jwt.Service: - s.Debugf("Adding service export %q for %s", e.Subject, a.traceLabel()) + s.Debugf("Adding service export %q for %s", e.Subject, tl) rt := Singleton switch e.ResponseType { case jwt.ResponseTypeStream: @@ -3301,7 +3307,7 @@ func (s *Server) updateAccountClaimsWithRefresh(a *Account, ac *jwt.AccountClaim } if err := a.addServiceExportWithResponseAndAccountPos( string(e.Subject), rt, authAccounts(e.TokenReq), e.AccountTokenPosition); err != nil { - s.Debugf("Error adding service export to account [%s]: %v", a.traceLabel(), err) + s.Debugf("Error adding service export to account [%s]: %v", tl, err) continue } sub := string(e.Subject) @@ -3311,13 +3317,13 @@ func (s *Server) updateAccountClaimsWithRefresh(a *Account, ac *jwt.AccountClaim if e.Latency.Sampling == jwt.Headers { hdrNote = " (using headers)" } - s.Debugf("Error 
adding latency tracking%s for service export to account [%s]: %v", hdrNote, a.traceLabel(), err) + s.Debugf("Error adding latency tracking%s for service export to account [%s]: %v", hdrNote, tl, err) } } if e.ResponseThreshold != 0 { // Response threshold was set in options. if err := a.SetServiceExportResponseThreshold(sub, e.ResponseThreshold); err != nil { - s.Debugf("Error adding service export response threshold for [%s]: %v", a.traceLabel(), err) + s.Debugf("Error adding service export response threshold for [%s]: %v", tl, err) } } } @@ -3362,34 +3368,31 @@ func (s *Server) updateAccountClaimsWithRefresh(a *Account, ac *jwt.AccountClaim } var incompleteImports []*jwt.Import for _, i := range ac.Imports { - // check tmpAccounts with priority - var acc *Account - var err error - if v, ok := s.tmpAccounts.Load(i.Account); ok { - acc = v.(*Account) - } else { - acc, err = s.lookupAccount(i.Account) - } + acc, err := s.lookupAccount(i.Account) if acc == nil || err != nil { s.Errorf("Can't locate account [%s] for import of [%v] %s (err=%v)", i.Account, i.Subject, i.Type, err) incompleteImports = append(incompleteImports, i) continue } - from := string(i.Subject) - to := i.GetTo() + // Capture trace labels. 
+ acc.mu.RLock() + atl := acc.traceLabel() + acc.mu.RUnlock() + // Grab from and to + from, to := string(i.Subject), i.GetTo() switch i.Type { case jwt.Stream: if i.LocalSubject != _EMPTY_ { // set local subject implies to is empty to = string(i.LocalSubject) - s.Debugf("Adding stream import %s:%q for %s:%q", acc.traceLabel(), from, a.traceLabel(), to) + s.Debugf("Adding stream import %s:%q for %s:%q", atl, from, tl, to) err = a.AddMappedStreamImportWithClaim(acc, from, to, i) } else { - s.Debugf("Adding stream import %s:%q for %s:%q", acc.traceLabel(), from, a.traceLabel(), to) + s.Debugf("Adding stream import %s:%q for %s:%q", atl, from, tl, to) err = a.AddStreamImportWithClaim(acc, from, to, i) } if err != nil { - s.Debugf("Error adding stream import to account [%s]: %v", a.traceLabel(), err.Error()) + s.Debugf("Error adding stream import to account [%s]: %v", tl, err.Error()) incompleteImports = append(incompleteImports, i) } case jwt.Service: @@ -3397,9 +3400,9 @@ func (s *Server) updateAccountClaimsWithRefresh(a *Account, ac *jwt.AccountClaim from = string(i.LocalSubject) to = string(i.Subject) } - s.Debugf("Adding service import %s:%q for %s:%q", acc.traceLabel(), from, a.traceLabel(), to) + s.Debugf("Adding service import %s:%q for %s:%q", atl, from, tl, to) if err := a.AddServiceImportWithClaim(acc, from, to, i); err != nil { - s.Debugf("Error adding service import to account [%s]: %v", a.traceLabel(), err.Error()) + s.Debugf("Error adding service import to account [%s]: %v", tl, err.Error()) incompleteImports = append(incompleteImports, i) } } @@ -3570,7 +3573,7 @@ func (s *Server) updateAccountClaimsWithRefresh(a *Account, ac *jwt.AccountClaim // regardless of enabled or disabled. It handles both cases. 
if jsEnabled { if err := s.configJetStream(a); err != nil { - s.Errorf("Error configuring jetstream for account [%s]: %v", a.traceLabel(), err.Error()) + s.Errorf("Error configuring jetstream for account [%s]: %v", tl, err.Error()) a.mu.Lock() // Absent reload of js server cfg, this is going to be broken until js is disabled a.incomplete = true @@ -3707,8 +3710,13 @@ func (s *Server) buildInternalAccount(ac *jwt.AccountClaims) *Account { // We don't want to register an account that is in the process of // being built, however, to solve circular import dependencies, we // need to store it here. - s.tmpAccounts.Store(ac.Subject, acc) + if v, loaded := s.tmpAccounts.LoadOrStore(ac.Subject, acc); loaded { + return v.(*Account) + } + + // Update based on claims. s.UpdateAccountClaims(acc, ac) + return acc } diff --git a/server/client.go b/server/client.go index 3dd0ce6dc66..619db25aad0 100644 --- a/server/client.go +++ b/server/client.go @@ -2914,8 +2914,11 @@ func (c *client) addShadowSubscriptions(acc *Account, sub *subscription, enact b // Add in the shadow subscription. 
func (c *client) addShadowSub(sub *subscription, ime *ime, enact bool) (*subscription, error) { - im := ime.im + c.mu.Lock() nsub := *sub // copy + c.mu.Unlock() + + im := ime.im nsub.im = im if !im.usePub && ime.dyn && im.tr != nil { diff --git a/server/jwt_test.go b/server/jwt_test.go index ab7de041769..d6cdd1dd7d9 100644 --- a/server/jwt_test.go +++ b/server/jwt_test.go @@ -15,6 +15,7 @@ package server import ( "bufio" + "context" "encoding/base64" "encoding/json" "errors" @@ -1991,9 +1992,9 @@ func TestJWTAccountURLResolverPermanentFetchFailure(t *testing.T) { importErrCnt++ } case <-tmr.C: - // connecting and updating, each cause 3 traces (2 + 1 on iteration) - if importErrCnt != 6 { - t.Fatalf("Expected 6 debug traces, got %d", importErrCnt) + // connecting and updating, each cause 3 traces (2 + 1 on iteration) + 1 xtra fetch + if importErrCnt != 7 { + t.Fatalf("Expected 7 debug traces, got %d", importErrCnt) } return } @@ -6842,3 +6843,191 @@ func TestJWTAccountNATSResolverWrongCreds(t *testing.T) { t.Fatalf("Expected auth error: %v", err) } } + +// Issue 5480: https://github.com/nats-io/nats-server/issues/5480 +func TestJWTImportsOnServerRestartAndClientsReconnect(t *testing.T) { + type namedCreds struct { + name string + creds nats.Option + } + preload := make(map[string]string) + users := make(map[string]*namedCreds) + + // sys account + _, sysAcc, sysAccClaim := NewJwtAccountClaim("sys") + sysAccJWT, err := sysAccClaim.Encode(oKp) + require_NoError(t, err) + preload[sysAcc] = sysAccJWT + + // main account, other accounts will import from this. 
+ mainAccKP, mainAcc, mainAccClaim := NewJwtAccountClaim("main") + mainAccClaim.Exports.Add(&jwt.Export{ + Type: jwt.Stream, + Subject: "city.>", + }) + + // main account user + mainUserClaim := jwt.NewUserClaims("publisher") + mainUserClaim.Permissions = jwt.Permissions{ + Pub: jwt.Permission{ + Allow: []string{"city.>"}, + }, + } + mainCreds := createUserCredsEx(t, mainUserClaim, mainAccKP) + + // The main account will be importing from all other accounts. + maxAccounts := 100 + for i := 0; i < maxAccounts; i++ { + name := fmt.Sprintf("secondary-%d", i) + accKP, acc, accClaim := NewJwtAccountClaim(name) + + accClaim.Exports.Add(&jwt.Export{ + Type: jwt.Stream, + Subject: "internal.*", + }) + accClaim.Imports.Add(&jwt.Import{ + Type: jwt.Stream, + Subject: jwt.Subject(fmt.Sprintf("city.%d-1.*", i)), + Account: mainAcc, + }) + + // main account imports from the secondary accounts + mainAccClaim.Imports.Add(&jwt.Import{ + Type: jwt.Stream, + Subject: jwt.Subject(fmt.Sprintf("internal.%d", i)), + Account: acc, + }) + + accJWT, err := accClaim.Encode(oKp) + require_NoError(t, err) + preload[acc] = accJWT + + userClaim := jwt.NewUserClaims("subscriber") + userClaim.Permissions = jwt.Permissions{ + Sub: jwt.Permission{ + Allow: []string{"city.>", "internal.*"}, + }, + Pub: jwt.Permission{ + Allow: []string{"internal.*"}, + }, + } + userCreds := createUserCredsEx(t, userClaim, accKP) + users[acc] = &namedCreds{name, userCreds} + } + mainAccJWT, err := mainAccClaim.Encode(oKp) + require_NoError(t, err) + preload[mainAcc] = mainAccJWT + + // Start the server with the preload. 
+ resolverPreload, err := json.Marshal(preload) + require_NoError(t, err) + conf := createConfFile(t, []byte(fmt.Sprintf(` + listen: 127.0.0.1:4747 + http: 127.0.0.1:8222 + operator: %s + system_account: %s + resolver: MEM + resolver_preload: %s + `, ojwt, sysAcc, string(resolverPreload)))) + s, _ := RunServerWithConfig(conf) + defer s.Shutdown() + + // Have a connection ready for each one of the accounts. + type namedSub struct { + name string + sub *nats.Subscription + } + subs := make(map[string]*namedSub) + for acc, user := range users { + nc := natsConnect(t, s.ClientURL(), user.creds, + // Make the clients attempt to reconnect too fast, + // changing this to be above ~200ms mitigates the issue. + nats.ReconnectWait(15*time.Millisecond), + nats.Name(user.name), + nats.MaxReconnects(-1), + ) + defer nc.Close() + + sub, err := nc.SubscribeSync("city.>") + require_NoError(t, err) + subs[acc] = &namedSub{user.name, sub} + } + + nc := natsConnect(t, s.ClientURL(), mainCreds, nats.ReconnectWait(15*time.Millisecond), nats.MaxReconnects(-1)) + defer nc.Close() + + send := func(t *testing.T) { + t.Helper() + for i := 0; i < maxAccounts; i++ { + nc.Publish(fmt.Sprintf("city.%d-1.A4BDB048-69DC-4F10-916C-2B998249DC11", i), []byte(fmt.Sprintf("test:%d", i))) + } + nc.Flush() + } + + ctx, done := context.WithCancel(context.Background()) + defer done() + go func() { + for range time.NewTicker(200 * time.Millisecond).C { + select { + case <-ctx.Done(): + default: + } + send(t) + } + }() + + receive := func(t *testing.T) { + t.Helper() + received := 0 + for _, nsub := range subs { + // Drain first any pending messages. 
+ pendingMsgs, _, _ := nsub.sub.Pending() + for i, _ := 0, 0; i < pendingMsgs; i++ { + nsub.sub.NextMsg(500 * time.Millisecond) + } + + _, err = nsub.sub.NextMsg(500 * time.Millisecond) + if err != nil { + t.Logf("WRN: Failed to receive message on account %q: %v", nsub.name, err) + } else { + received++ + } + } + if received < (maxAccounts / 2) { + t.Fatalf("Too many missed messages after restart. Received %d", received) + } + } + receive(t) + time.Sleep(1 * time.Second) + + restart := func(t *testing.T) *Server { + t.Helper() + s.Shutdown() + s.WaitForShutdown() + s, _ = RunServerWithConfig(conf) + + hctx, hcancel := context.WithTimeout(context.Background(), 5*time.Second) + defer hcancel() + for range time.NewTicker(2 * time.Second).C { + select { + case <-hctx.Done(): + t.Logf("WRN: Timed out waiting for healthz from %s", s) + default: + } + + status := s.healthz(nil) + if status.StatusCode == 200 { + return s + } + } + return nil + } + + // Takes a few restarts for issue to show up. + for i := 0; i < 5; i++ { + s := restart(t) + defer s.Shutdown() + time.Sleep(2 * time.Second) + receive(t) + } +} diff --git a/server/server.go b/server/server.go index 62a09fad25d..8a276958986 100644 --- a/server/server.go +++ b/server/server.go @@ -1097,11 +1097,11 @@ func (s *Server) configureAccounts(reloading bool) (map[string]struct{}, error) if reloading && acc.Name != globalAccountName { if ai, ok := s.accounts.Load(acc.Name); ok { a = ai.(*Account) - a.mu.Lock() // Before updating the account, check if stream imports have changed. if !a.checkStreamImportsEqual(acc) { awcsti[acc.Name] = struct{}{} } + a.mu.Lock() // Collect the sids for the service imports since we are going to // replace with new ones. 
var sids [][]byte @@ -2064,7 +2064,6 @@ func (s *Server) fetchAccount(name string) (*Account, error) { return nil, err } acc := s.buildInternalAccount(accClaims) - acc.claimJWT = claimJWT // Due to possible race, if registerAccount() returns a non // nil account, it means the same account was already // registered and we should use this one. @@ -2080,6 +2079,7 @@ func (s *Server) fetchAccount(name string) (*Account, error) { var needImportSubs bool acc.mu.Lock() + acc.claimJWT = claimJWT if len(acc.imports.services) > 0 { if acc.ic == nil { acc.ic = s.createInternalAccountClient()