Skip to content

Commit

Permalink
Fix imports not being available for a client sometimes after a server…
Browse files Browse the repository at this point in the history
… restart (#5589)

When a client reconnected too soon to a server that had just been
restarted and was still setting up its imports and exports, the client
might be missing some of the imports that were defined in its JWT.

---------

Signed-off-by: Waldemar Quevedo <wally@nats.io>
Signed-off-by: Derek Collison <derek@nats.io>
Co-authored-by: Derek Collison <derek@nats.io>
  • Loading branch information
2 people authored and neilalexander committed Jun 25, 2024
1 parent b68ce83 commit e019d7a
Show file tree
Hide file tree
Showing 4 changed files with 235 additions and 35 deletions.
66 changes: 37 additions & 29 deletions server/accounts.go
Original file line number Diff line number Diff line change
Expand Up @@ -2822,9 +2822,12 @@ func (a *Account) isIssuerClaimTrusted(claims *jwt.ActivationClaims) bool {
// check is done with the account's name, not the pointer. This is used
// during config reload where we are comparing current and new config
// in which pointers are different.
// No lock is acquired in this function, so it is assumed that the
// import maps are not changed while this executes.
// Acquires `a` read lock, but `b` is assumed to not be accessed
// by anyone but the caller (`b` is not registered anywhere).
func (a *Account) checkStreamImportsEqual(b *Account) bool {
a.mu.RLock()
defer a.mu.RUnlock()

if len(a.imports.streams) != len(b.imports.streams) {
return false
}
Expand Down Expand Up @@ -3192,6 +3195,9 @@ func (s *Server) updateAccountClaimsWithRefresh(a *Account, ac *jwt.AccountClaim
a.nameTag = ac.Name
a.tags = ac.Tags

// Grab trace label under lock.
tl := a.traceLabel()

// Check for external authorization.
if ac.HasExternalAuthorization() {
a.extAuth = &jwt.ExternalAuthorization{}
Expand All @@ -3212,10 +3218,10 @@ func (s *Server) updateAccountClaimsWithRefresh(a *Account, ac *jwt.AccountClaim
}
if a.imports.services != nil {
old.imports.services = make(map[string]*serviceImport, len(a.imports.services))
}
for k, v := range a.imports.services {
old.imports.services[k] = v
delete(a.imports.services, k)
for k, v := range a.imports.services {
old.imports.services[k] = v
delete(a.imports.services, k)
}
}

alteredScope := map[string]struct{}{}
Expand Down Expand Up @@ -3285,13 +3291,13 @@ func (s *Server) updateAccountClaimsWithRefresh(a *Account, ac *jwt.AccountClaim
for _, e := range ac.Exports {
switch e.Type {
case jwt.Stream:
s.Debugf("Adding stream export %q for %s", e.Subject, a.traceLabel())
s.Debugf("Adding stream export %q for %s", e.Subject, tl)
if err := a.addStreamExportWithAccountPos(
string(e.Subject), authAccounts(e.TokenReq), e.AccountTokenPosition); err != nil {
s.Debugf("Error adding stream export to account [%s]: %v", a.traceLabel(), err.Error())
s.Debugf("Error adding stream export to account [%s]: %v", tl, err.Error())
}
case jwt.Service:
s.Debugf("Adding service export %q for %s", e.Subject, a.traceLabel())
s.Debugf("Adding service export %q for %s", e.Subject, tl)
rt := Singleton
switch e.ResponseType {
case jwt.ResponseTypeStream:
Expand All @@ -3301,7 +3307,7 @@ func (s *Server) updateAccountClaimsWithRefresh(a *Account, ac *jwt.AccountClaim
}
if err := a.addServiceExportWithResponseAndAccountPos(
string(e.Subject), rt, authAccounts(e.TokenReq), e.AccountTokenPosition); err != nil {
s.Debugf("Error adding service export to account [%s]: %v", a.traceLabel(), err)
s.Debugf("Error adding service export to account [%s]: %v", tl, err)
continue
}
sub := string(e.Subject)
Expand All @@ -3311,13 +3317,13 @@ func (s *Server) updateAccountClaimsWithRefresh(a *Account, ac *jwt.AccountClaim
if e.Latency.Sampling == jwt.Headers {
hdrNote = " (using headers)"
}
s.Debugf("Error adding latency tracking%s for service export to account [%s]: %v", hdrNote, a.traceLabel(), err)
s.Debugf("Error adding latency tracking%s for service export to account [%s]: %v", hdrNote, tl, err)
}
}
if e.ResponseThreshold != 0 {
// Response threshold was set in options.
if err := a.SetServiceExportResponseThreshold(sub, e.ResponseThreshold); err != nil {
s.Debugf("Error adding service export response threshold for [%s]: %v", a.traceLabel(), err)
s.Debugf("Error adding service export response threshold for [%s]: %v", tl, err)
}
}
}
Expand Down Expand Up @@ -3362,44 +3368,41 @@ func (s *Server) updateAccountClaimsWithRefresh(a *Account, ac *jwt.AccountClaim
}
var incompleteImports []*jwt.Import
for _, i := range ac.Imports {
// check tmpAccounts with priority
var acc *Account
var err error
if v, ok := s.tmpAccounts.Load(i.Account); ok {
acc = v.(*Account)
} else {
acc, err = s.lookupAccount(i.Account)
}
acc, err := s.lookupAccount(i.Account)
if acc == nil || err != nil {
s.Errorf("Can't locate account [%s] for import of [%v] %s (err=%v)", i.Account, i.Subject, i.Type, err)
incompleteImports = append(incompleteImports, i)
continue
}
from := string(i.Subject)
to := i.GetTo()
// Capture trace labels.
acc.mu.RLock()
atl := acc.traceLabel()
acc.mu.RUnlock()
// Grab from and to
from, to := string(i.Subject), i.GetTo()
switch i.Type {
case jwt.Stream:
if i.LocalSubject != _EMPTY_ {
// set local subject implies to is empty
to = string(i.LocalSubject)
s.Debugf("Adding stream import %s:%q for %s:%q", acc.traceLabel(), from, a.traceLabel(), to)
s.Debugf("Adding stream import %s:%q for %s:%q", atl, from, tl, to)
err = a.AddMappedStreamImportWithClaim(acc, from, to, i)
} else {
s.Debugf("Adding stream import %s:%q for %s:%q", acc.traceLabel(), from, a.traceLabel(), to)
s.Debugf("Adding stream import %s:%q for %s:%q", atl, from, tl, to)
err = a.AddStreamImportWithClaim(acc, from, to, i)
}
if err != nil {
s.Debugf("Error adding stream import to account [%s]: %v", a.traceLabel(), err.Error())
s.Debugf("Error adding stream import to account [%s]: %v", tl, err.Error())
incompleteImports = append(incompleteImports, i)
}
case jwt.Service:
if i.LocalSubject != _EMPTY_ {
from = string(i.LocalSubject)
to = string(i.Subject)
}
s.Debugf("Adding service import %s:%q for %s:%q", acc.traceLabel(), from, a.traceLabel(), to)
s.Debugf("Adding service import %s:%q for %s:%q", atl, from, tl, to)
if err := a.AddServiceImportWithClaim(acc, from, to, i); err != nil {
s.Debugf("Error adding service import to account [%s]: %v", a.traceLabel(), err.Error())
s.Debugf("Error adding service import to account [%s]: %v", tl, err.Error())
incompleteImports = append(incompleteImports, i)
}
}
Expand Down Expand Up @@ -3570,7 +3573,7 @@ func (s *Server) updateAccountClaimsWithRefresh(a *Account, ac *jwt.AccountClaim
// regardless of enabled or disabled. It handles both cases.
if jsEnabled {
if err := s.configJetStream(a); err != nil {
s.Errorf("Error configuring jetstream for account [%s]: %v", a.traceLabel(), err.Error())
s.Errorf("Error configuring jetstream for account [%s]: %v", tl, err.Error())
a.mu.Lock()
// Absent reload of js server cfg, this is going to be broken until js is disabled
a.incomplete = true
Expand Down Expand Up @@ -3707,8 +3710,13 @@ func (s *Server) buildInternalAccount(ac *jwt.AccountClaims) *Account {
// We don't want to register an account that is in the process of
// being built, however, to solve circular import dependencies, we
// need to store it here.
s.tmpAccounts.Store(ac.Subject, acc)
if v, loaded := s.tmpAccounts.LoadOrStore(ac.Subject, acc); loaded {
return v.(*Account)
}

// Update based on claims.
s.UpdateAccountClaims(acc, ac)

return acc
}

Expand Down
5 changes: 4 additions & 1 deletion server/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -2914,8 +2914,11 @@ func (c *client) addShadowSubscriptions(acc *Account, sub *subscription, enact b

// Add in the shadow subscription.
func (c *client) addShadowSub(sub *subscription, ime *ime, enact bool) (*subscription, error) {
im := ime.im
c.mu.Lock()
nsub := *sub // copy
c.mu.Unlock()

im := ime.im
nsub.im = im

if !im.usePub && ime.dyn && im.tr != nil {
Expand Down
195 changes: 192 additions & 3 deletions server/jwt_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ package server

import (
"bufio"
"context"
"encoding/base64"
"encoding/json"
"errors"
Expand Down Expand Up @@ -1991,9 +1992,9 @@ func TestJWTAccountURLResolverPermanentFetchFailure(t *testing.T) {
importErrCnt++
}
case <-tmr.C:
// connecting and updating, each cause 3 traces (2 + 1 on iteration)
if importErrCnt != 6 {
t.Fatalf("Expected 6 debug traces, got %d", importErrCnt)
// connecting and updating, each cause 3 traces (2 + 1 on iteration) + 1 extra fetch
if importErrCnt != 7 {
t.Fatalf("Expected 7 debug traces, got %d", importErrCnt)
}
return
}
Expand Down Expand Up @@ -6842,3 +6843,191 @@ func TestJWTAccountNATSResolverWrongCreds(t *testing.T) {
t.Fatalf("Expected auth error: %v", err)
}
}

// Issue 5480: https://github.com/nats-io/nats-server/issues/5480
//
// TestJWTImportsOnServerRestartAndClientsReconnect starts a server with a
// preloaded MEM resolver holding one "main" account and maxAccounts secondary
// accounts that import from / export to each other, connects one subscriber
// per secondary account with a very short reconnect wait, and then restarts
// the server several times. After every restart at least half of the
// subscribers must still receive messages published by the main account.
func TestJWTImportsOnServerRestartAndClientsReconnect(t *testing.T) {
	type namedCreds struct {
		name  string
		creds nats.Option
	}
	preload := make(map[string]string)
	users := make(map[string]*namedCreds)

	// sys account
	_, sysAcc, sysAccClaim := NewJwtAccountClaim("sys")
	sysAccJWT, err := sysAccClaim.Encode(oKp)
	require_NoError(t, err)
	preload[sysAcc] = sysAccJWT

	// main account, other accounts will import from this.
	mainAccKP, mainAcc, mainAccClaim := NewJwtAccountClaim("main")
	mainAccClaim.Exports.Add(&jwt.Export{
		Type:    jwt.Stream,
		Subject: "city.>",
	})

	// main account user
	mainUserClaim := jwt.NewUserClaims("publisher")
	mainUserClaim.Permissions = jwt.Permissions{
		Pub: jwt.Permission{
			Allow: []string{"city.>"},
		},
	}
	mainCreds := createUserCredsEx(t, mainUserClaim, mainAccKP)

	// The main account will be importing from all other accounts.
	const maxAccounts = 100
	for i := 0; i < maxAccounts; i++ {
		name := fmt.Sprintf("secondary-%d", i)
		accKP, acc, accClaim := NewJwtAccountClaim(name)

		accClaim.Exports.Add(&jwt.Export{
			Type:    jwt.Stream,
			Subject: "internal.*",
		})
		accClaim.Imports.Add(&jwt.Import{
			Type:    jwt.Stream,
			Subject: jwt.Subject(fmt.Sprintf("city.%d-1.*", i)),
			Account: mainAcc,
		})

		// main account imports from the secondary accounts
		mainAccClaim.Imports.Add(&jwt.Import{
			Type:    jwt.Stream,
			Subject: jwt.Subject(fmt.Sprintf("internal.%d", i)),
			Account: acc,
		})

		accJWT, err := accClaim.Encode(oKp)
		require_NoError(t, err)
		preload[acc] = accJWT

		userClaim := jwt.NewUserClaims("subscriber")
		userClaim.Permissions = jwt.Permissions{
			Sub: jwt.Permission{
				Allow: []string{"city.>", "internal.*"},
			},
			Pub: jwt.Permission{
				Allow: []string{"internal.*"},
			},
		}
		userCreds := createUserCredsEx(t, userClaim, accKP)
		users[acc] = &namedCreds{name, userCreds}
	}
	mainAccJWT, err := mainAccClaim.Encode(oKp)
	require_NoError(t, err)
	preload[mainAcc] = mainAccJWT

	// Start the server with the preload. The listen port is fixed on purpose:
	// clients must be able to reconnect to the same URL after every restart.
	resolverPreload, err := json.Marshal(preload)
	require_NoError(t, err)
	conf := createConfFile(t, []byte(fmt.Sprintf(`
listen: 127.0.0.1:4747
http: 127.0.0.1:8222
operator: %s
system_account: %s
resolver: MEM
resolver_preload: %s
`, ojwt, sysAcc, string(resolverPreload))))
	s, _ := RunServerWithConfig(conf)
	// `s` is re-assigned on every restart below, so shut down whichever
	// server is current when the test ends instead of binding the first one.
	defer func() { s.Shutdown() }()

	// Have a connection ready for each one of the accounts.
	type namedSub struct {
		name string
		sub  *nats.Subscription
	}
	subs := make(map[string]*namedSub)
	for acc, user := range users {
		nc := natsConnect(t, s.ClientURL(), user.creds,
			// Make the clients attempt to reconnect too fast,
			// changing this to be above ~200ms mitigates the issue.
			nats.ReconnectWait(15*time.Millisecond),
			nats.Name(user.name),
			nats.MaxReconnects(-1),
		)
		// Deliberately deferred inside the loop: every connection has to
		// stay open for the whole duration of the test.
		defer nc.Close()

		sub, err := nc.SubscribeSync("city.>")
		require_NoError(t, err)
		subs[acc] = &namedSub{user.name, sub}
	}

	nc := natsConnect(t, s.ClientURL(), mainCreds, nats.ReconnectWait(15*time.Millisecond), nats.MaxReconnects(-1))
	defer nc.Close()

	// send publishes one message per secondary account. Errors are ignored
	// on purpose: publishing is best effort while the server is restarting.
	// It must not touch testing.T since it runs on a background goroutine.
	send := func() {
		for i := 0; i < maxAccounts; i++ {
			nc.Publish(fmt.Sprintf("city.%d-1.A4BDB048-69DC-4F10-916C-2B998249DC11", i), []byte(fmt.Sprintf("test:%d", i)))
		}
		nc.Flush()
	}

	// Keep publishing in the background until the test is done. The
	// goroutine returns (and releases its ticker) once ctx is cancelled.
	ctx, done := context.WithCancel(context.Background())
	defer done()
	go func() {
		ticker := time.NewTicker(200 * time.Millisecond)
		defer ticker.Stop()
		for {
			select {
			case <-ctx.Done():
				return
			case <-ticker.C:
				send()
			}
		}
	}()

	// receive checks that at least half of the subscribers get a fresh
	// message, failing the test otherwise.
	receive := func(t *testing.T) {
		t.Helper()
		received := 0
		for _, nsub := range subs {
			// Drain first any pending messages. If Pending fails the
			// count is not positive and there is simply nothing to drain.
			pendingMsgs, _, _ := nsub.sub.Pending()
			for i := 0; i < pendingMsgs; i++ {
				nsub.sub.NextMsg(500 * time.Millisecond)
			}

			if _, err := nsub.sub.NextMsg(500 * time.Millisecond); err != nil {
				t.Logf("WRN: Failed to receive message on account %q: %v", nsub.name, err)
			} else {
				received++
			}
		}
		if received < (maxAccounts / 2) {
			t.Fatalf("Too many missed messages after restart. Received %d", received)
		}
	}
	receive(t)
	time.Sleep(1 * time.Second)

	// restart shuts the current server down, starts a new one from the same
	// config and polls healthz every 2s. It gives up waiting after 5s with a
	// warning, leaving the verdict to the receive check that follows, rather
	// than spinning forever on a server that never becomes healthy.
	restart := func(t *testing.T) {
		t.Helper()
		s.Shutdown()
		s.WaitForShutdown()
		s, _ = RunServerWithConfig(conf)

		hctx, hcancel := context.WithTimeout(context.Background(), 5*time.Second)
		defer hcancel()
		ticker := time.NewTicker(2 * time.Second)
		defer ticker.Stop()
		for {
			select {
			case <-hctx.Done():
				t.Logf("WRN: Timed out waiting for healthz from %s", s)
				return
			case <-ticker.C:
				if status := s.healthz(nil); status.StatusCode == 200 {
					return
				}
			}
		}
	}

	// Takes a few restarts for issue to show up.
	for i := 0; i < 5; i++ {
		restart(t)
		time.Sleep(2 * time.Second)
		receive(t)
	}
}
Loading

0 comments on commit e019d7a

Please sign in to comment.