Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[FIXED] Fix imports not being available for a client sometimes after a server restart #5589

Merged
merged 2 commits into from
Jun 25, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
66 changes: 37 additions & 29 deletions server/accounts.go
Original file line number Diff line number Diff line change
Expand Up @@ -2889,9 +2889,12 @@ func (a *Account) isIssuerClaimTrusted(claims *jwt.ActivationClaims) bool {
// check is done with the account's name, not the pointer. This is used
// during config reload where we are comparing current and new config
// in which pointers are different.
// No lock is acquired in this function, so it is assumed that the
// import maps are not changed while this executes.
// Acquires `a` read lock, but `b` is assumed to not be accessed
// by anyone but the caller (`b` is not registered anywhere).
func (a *Account) checkStreamImportsEqual(b *Account) bool {
a.mu.RLock()
defer a.mu.RUnlock()

if len(a.imports.streams) != len(b.imports.streams) {
return false
}
Expand Down Expand Up @@ -3264,6 +3267,9 @@ func (s *Server) updateAccountClaimsWithRefresh(a *Account, ac *jwt.AccountClaim
a.nameTag = ac.Name
a.tags = ac.Tags

// Grab trace label under lock.
tl := a.traceLabel()

var td string
var tds int
if ac.Trace != nil {
Expand Down Expand Up @@ -3297,10 +3303,10 @@ func (s *Server) updateAccountClaimsWithRefresh(a *Account, ac *jwt.AccountClaim
}
if a.imports.services != nil {
old.imports.services = make(map[string]*serviceImport, len(a.imports.services))
}
for k, v := range a.imports.services {
old.imports.services[k] = v
delete(a.imports.services, k)
for k, v := range a.imports.services {
old.imports.services[k] = v
delete(a.imports.services, k)
}
}

alteredScope := map[string]struct{}{}
Expand Down Expand Up @@ -3370,13 +3376,13 @@ func (s *Server) updateAccountClaimsWithRefresh(a *Account, ac *jwt.AccountClaim
for _, e := range ac.Exports {
switch e.Type {
case jwt.Stream:
s.Debugf("Adding stream export %q for %s", e.Subject, a.traceLabel())
s.Debugf("Adding stream export %q for %s", e.Subject, tl)
if err := a.addStreamExportWithAccountPos(
string(e.Subject), authAccounts(e.TokenReq), e.AccountTokenPosition); err != nil {
s.Debugf("Error adding stream export to account [%s]: %v", a.traceLabel(), err.Error())
s.Debugf("Error adding stream export to account [%s]: %v", tl, err.Error())
}
case jwt.Service:
s.Debugf("Adding service export %q for %s", e.Subject, a.traceLabel())
s.Debugf("Adding service export %q for %s", e.Subject, tl)
rt := Singleton
switch e.ResponseType {
case jwt.ResponseTypeStream:
Expand All @@ -3386,7 +3392,7 @@ func (s *Server) updateAccountClaimsWithRefresh(a *Account, ac *jwt.AccountClaim
}
if err := a.addServiceExportWithResponseAndAccountPos(
string(e.Subject), rt, authAccounts(e.TokenReq), e.AccountTokenPosition); err != nil {
s.Debugf("Error adding service export to account [%s]: %v", a.traceLabel(), err)
s.Debugf("Error adding service export to account [%s]: %v", tl, err)
continue
}
sub := string(e.Subject)
Expand All @@ -3396,13 +3402,13 @@ func (s *Server) updateAccountClaimsWithRefresh(a *Account, ac *jwt.AccountClaim
if e.Latency.Sampling == jwt.Headers {
hdrNote = " (using headers)"
}
s.Debugf("Error adding latency tracking%s for service export to account [%s]: %v", hdrNote, a.traceLabel(), err)
s.Debugf("Error adding latency tracking%s for service export to account [%s]: %v", hdrNote, tl, err)
}
}
if e.ResponseThreshold != 0 {
// Response threshold was set in options.
if err := a.SetServiceExportResponseThreshold(sub, e.ResponseThreshold); err != nil {
s.Debugf("Error adding service export response threshold for [%s]: %v", a.traceLabel(), err)
s.Debugf("Error adding service export response threshold for [%s]: %v", tl, err)
}
}
if err := a.SetServiceExportAllowTrace(sub, e.AllowTrace); err != nil {
Expand Down Expand Up @@ -3450,44 +3456,41 @@ func (s *Server) updateAccountClaimsWithRefresh(a *Account, ac *jwt.AccountClaim
}
var incompleteImports []*jwt.Import
for _, i := range ac.Imports {
// check tmpAccounts with priority
var acc *Account
var err error
if v, ok := s.tmpAccounts.Load(i.Account); ok {
acc = v.(*Account)
} else {
acc, err = s.lookupAccount(i.Account)
}
acc, err := s.lookupAccount(i.Account)
if acc == nil || err != nil {
s.Errorf("Can't locate account [%s] for import of [%v] %s (err=%v)", i.Account, i.Subject, i.Type, err)
incompleteImports = append(incompleteImports, i)
continue
}
from := string(i.Subject)
to := i.GetTo()
// Capture trace labels.
acc.mu.RLock()
atl := acc.traceLabel()
acc.mu.RUnlock()
// Grab from and to
from, to := string(i.Subject), i.GetTo()
switch i.Type {
case jwt.Stream:
if i.LocalSubject != _EMPTY_ {
// set local subject implies to is empty
to = string(i.LocalSubject)
s.Debugf("Adding stream import %s:%q for %s:%q", acc.traceLabel(), from, a.traceLabel(), to)
s.Debugf("Adding stream import %s:%q for %s:%q", atl, from, tl, to)
err = a.AddMappedStreamImportWithClaim(acc, from, to, i)
} else {
s.Debugf("Adding stream import %s:%q for %s:%q", acc.traceLabel(), from, a.traceLabel(), to)
s.Debugf("Adding stream import %s:%q for %s:%q", atl, from, tl, to)
err = a.AddStreamImportWithClaim(acc, from, to, i)
}
if err != nil {
s.Debugf("Error adding stream import to account [%s]: %v", a.traceLabel(), err.Error())
s.Debugf("Error adding stream import to account [%s]: %v", tl, err.Error())
incompleteImports = append(incompleteImports, i)
}
case jwt.Service:
if i.LocalSubject != _EMPTY_ {
from = string(i.LocalSubject)
to = string(i.Subject)
}
s.Debugf("Adding service import %s:%q for %s:%q", acc.traceLabel(), from, a.traceLabel(), to)
s.Debugf("Adding service import %s:%q for %s:%q", atl, from, tl, to)
if err := a.AddServiceImportWithClaim(acc, from, to, i); err != nil {
s.Debugf("Error adding service import to account [%s]: %v", a.traceLabel(), err.Error())
s.Debugf("Error adding service import to account [%s]: %v", tl, err.Error())
incompleteImports = append(incompleteImports, i)
}
}
Expand Down Expand Up @@ -3663,7 +3666,7 @@ func (s *Server) updateAccountClaimsWithRefresh(a *Account, ac *jwt.AccountClaim
// regardless of enabled or disabled. It handles both cases.
if jsEnabled {
if err := s.configJetStream(a); err != nil {
s.Errorf("Error configuring jetstream for account [%s]: %v", a.traceLabel(), err.Error())
s.Errorf("Error configuring jetstream for account [%s]: %v", tl, err.Error())
a.mu.Lock()
// Absent reload of js server cfg, this is going to be broken until js is disabled
a.incomplete = true
Expand Down Expand Up @@ -3802,8 +3805,13 @@ func (s *Server) buildInternalAccount(ac *jwt.AccountClaims) *Account {
// We don't want to register an account that is in the process of
// being built, however, to solve circular import dependencies, we
// need to store it here.
s.tmpAccounts.Store(ac.Subject, acc)
if v, loaded := s.tmpAccounts.LoadOrStore(ac.Subject, acc); loaded {
return v.(*Account)
}

// Update based on claims.
s.UpdateAccountClaims(acc, ac)

return acc
}

Expand Down
5 changes: 4 additions & 1 deletion server/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -2959,8 +2959,11 @@ func (c *client) addShadowSubscriptions(acc *Account, sub *subscription, enact b

// Add in the shadow subscription.
func (c *client) addShadowSub(sub *subscription, ime *ime, enact bool) (*subscription, error) {
im := ime.im
c.mu.Lock()
nsub := *sub // copy
c.mu.Unlock()

im := ime.im
nsub.im = im

if !im.usePub && ime.dyn && im.tr != nil {
Expand Down
195 changes: 192 additions & 3 deletions server/jwt_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ package server

import (
"bufio"
"context"
"encoding/base64"
"encoding/json"
"errors"
Expand Down Expand Up @@ -1991,9 +1992,9 @@ func TestJWTAccountURLResolverPermanentFetchFailure(t *testing.T) {
importErrCnt++
}
case <-tmr.C:
// connecting and updating, each cause 3 traces (2 + 1 on iteration)
if importErrCnt != 6 {
t.Fatalf("Expected 6 debug traces, got %d", importErrCnt)
// connecting and updating, each cause 3 traces (2 + 1 on iteration) + 1 extra fetch
if importErrCnt != 7 {
t.Fatalf("Expected 7 debug traces, got %d", importErrCnt)
}
return
}
Expand Down Expand Up @@ -6842,3 +6843,191 @@ func TestJWTAccountNATSResolverWrongCreds(t *testing.T) {
t.Fatalf("Expected auth error: %v", err)
}
}

// Issue 5480: https://github.com/nats-io/nats-server/issues/5480
//
// Reproduces stream imports intermittently not being visible to clients
// that reconnect very quickly after a server restart. A "main" account
// exports "city.>" which many secondary accounts import (and the main
// account imports an "internal.*" stream back from each of them). The
// server is restarted several times while a publisher keeps sending and
// each secondary account's subscriber checks it still receives messages
// across the import.
func TestJWTImportsOnServerRestartAndClientsReconnect(t *testing.T) {
	type namedCreds struct {
		name  string
		creds nats.Option
	}
	preload := make(map[string]string)
	users := make(map[string]*namedCreds)

	// sys account
	_, sysAcc, sysAccClaim := NewJwtAccountClaim("sys")
	sysAccJWT, err := sysAccClaim.Encode(oKp)
	require_NoError(t, err)
	preload[sysAcc] = sysAccJWT

	// main account, other accounts will import from this.
	mainAccKP, mainAcc, mainAccClaim := NewJwtAccountClaim("main")
	mainAccClaim.Exports.Add(&jwt.Export{
		Type:    jwt.Stream,
		Subject: "city.>",
	})

	// main account user
	mainUserClaim := jwt.NewUserClaims("publisher")
	mainUserClaim.Permissions = jwt.Permissions{
		Pub: jwt.Permission{
			Allow: []string{"city.>"},
		},
	}
	mainCreds := createUserCredsEx(t, mainUserClaim, mainAccKP)

	// The main account will be importing from all other accounts.
	maxAccounts := 100
	for i := 0; i < maxAccounts; i++ {
		name := fmt.Sprintf("secondary-%d", i)
		accKP, acc, accClaim := NewJwtAccountClaim(name)

		accClaim.Exports.Add(&jwt.Export{
			Type:    jwt.Stream,
			Subject: "internal.*",
		})
		accClaim.Imports.Add(&jwt.Import{
			Type:    jwt.Stream,
			Subject: jwt.Subject(fmt.Sprintf("city.%d-1.*", i)),
			Account: mainAcc,
		})

		// main account imports from the secondary accounts
		mainAccClaim.Imports.Add(&jwt.Import{
			Type:    jwt.Stream,
			Subject: jwt.Subject(fmt.Sprintf("internal.%d", i)),
			Account: acc,
		})

		accJWT, err := accClaim.Encode(oKp)
		require_NoError(t, err)
		preload[acc] = accJWT

		userClaim := jwt.NewUserClaims("subscriber")
		userClaim.Permissions = jwt.Permissions{
			Sub: jwt.Permission{
				Allow: []string{"city.>", "internal.*"},
			},
			Pub: jwt.Permission{
				Allow: []string{"internal.*"},
			},
		}
		userCreds := createUserCredsEx(t, userClaim, accKP)
		users[acc] = &namedCreds{name, userCreds}
	}
	mainAccJWT, err := mainAccClaim.Encode(oKp)
	require_NoError(t, err)
	preload[mainAcc] = mainAccJWT

	// Start the server with the preload.
	resolverPreload, err := json.Marshal(preload)
	require_NoError(t, err)
	conf := createConfFile(t, []byte(fmt.Sprintf(`
	listen: 127.0.0.1:4747
	http: 127.0.0.1:8222
	operator: %s
	system_account: %s
	resolver: MEM
	resolver_preload: %s
	`, ojwt, sysAcc, string(resolverPreload))))
	s, _ := RunServerWithConfig(conf)
	defer s.Shutdown()

	// Have a connection ready for each one of the accounts.
	type namedSub struct {
		name string
		sub  *nats.Subscription
	}
	subs := make(map[string]*namedSub)
	for acc, user := range users {
		nc := natsConnect(t, s.ClientURL(), user.creds,
			// Make the clients attempt to reconnect too fast,
			// changing this to be above ~200ms mitigates the issue.
			nats.ReconnectWait(15*time.Millisecond),
			nats.Name(user.name),
			nats.MaxReconnects(-1),
		)
		defer nc.Close()

		sub, err := nc.SubscribeSync("city.>")
		require_NoError(t, err)
		subs[acc] = &namedSub{user.name, sub}
	}

	nc := natsConnect(t, s.ClientURL(), mainCreds, nats.ReconnectWait(15*time.Millisecond), nats.MaxReconnects(-1))
	defer nc.Close()

	// send publishes one message per secondary account import subject.
	send := func(t *testing.T) {
		t.Helper()
		for i := 0; i < maxAccounts; i++ {
			nc.Publish(fmt.Sprintf("city.%d-1.A4BDB048-69DC-4F10-916C-2B998249DC11", i), []byte(fmt.Sprintf("test:%d", i)))
		}
		nc.Flush()
	}

	// Keep publishing in the background for the duration of the test.
	ctx, done := context.WithCancel(context.Background())
	defer done()
	go func() {
		ticker := time.NewTicker(200 * time.Millisecond)
		defer ticker.Stop()
		for {
			select {
			case <-ctx.Done():
				// Exit on test teardown; previously the Done case was
				// empty so the goroutine (and ticker) leaked.
				return
			case <-ticker.C:
				send(t)
			}
		}
	}()

	// receive checks that most subscribers still get messages through
	// the import; tolerates up to half of the accounts missing one.
	receive := func(t *testing.T) {
		t.Helper()
		received := 0
		for _, nsub := range subs {
			// Drain first any pending messages.
			pendingMsgs, _, _ := nsub.sub.Pending()
			for i := 0; i < pendingMsgs; i++ {
				nsub.sub.NextMsg(500 * time.Millisecond)
			}

			if _, err := nsub.sub.NextMsg(500 * time.Millisecond); err != nil {
				t.Logf("WRN: Failed to receive message on account %q: %v", nsub.name, err)
			} else {
				received++
			}
		}
		if received < (maxAccounts / 2) {
			t.Fatalf("Too many missed messages after restart. Received %d", received)
		}
	}
	receive(t)
	time.Sleep(1 * time.Second)

	// restart bounces the server and waits for it to report healthy.
	restart := func(t *testing.T) *Server {
		t.Helper()
		s.Shutdown()
		s.WaitForShutdown()
		s, _ = RunServerWithConfig(conf)

		// Poll healthz until OK or the deadline passes. Previously the
		// timeout case only logged and the loop could spin forever, and
		// a nil server returned on timeout would panic the caller.
		deadline := time.Now().Add(5 * time.Second)
		for time.Now().Before(deadline) {
			if status := s.healthz(nil); status.StatusCode == 200 {
				return s
			}
			time.Sleep(100 * time.Millisecond)
		}
		t.Fatalf("Timed out waiting for healthz from %s", s)
		return nil
	}

	// Takes a few restarts for issue to show up.
	for i := 0; i < 5; i++ {
		s := restart(t)
		defer s.Shutdown()
		time.Sleep(2 * time.Second)
		receive(t)
	}
}
Loading