diff --git a/management/server/account.go b/management/server/account.go index e2293e5081d..ac00462fab1 100644 --- a/management/server/account.go +++ b/management/server/account.go @@ -119,7 +119,8 @@ type DefaultAccountManager struct { // singleAccountModeDomain is a domain to use in singleAccountMode setup singleAccountModeDomain string // dnsDomain is used for peer resolution. This is appended to the peer's name - dnsDomain string + dnsDomain string + peerLoginExpiry Scheduler } // Settings represents Account settings structure that can be modified via API and Dashboard @@ -307,6 +308,58 @@ func (a *Account) GetGroup(groupID string) *Group { return a.Groups[groupID] } +// GetExpiredPeers returns peers that have been expired +func (a *Account) GetExpiredPeers() []*Peer { + var peers []*Peer + for _, peer := range a.GetPeersWithExpiration() { + expired, _ := peer.LoginExpired(a.Settings.PeerLoginExpiration) + if expired { + peers = append(peers, peer) + } + } + + return peers +} + +// GetNextPeerExpiration returns the minimum duration in which the next peer of the account will expire if it was found. +// If there is no peer that expires this function returns false and a duration of 0. +// This function only considers peers that haven't been expired yet and that are connected. +func (a *Account) GetNextPeerExpiration() (time.Duration, bool) { + + peersWithExpiry := a.GetPeersWithExpiration() + if len(peersWithExpiry) == 0 { + return 0, false + } + var nextExpiry *time.Duration + for _, peer := range peersWithExpiry { + // consider only connected peers because others will require login on connecting to the management server + if peer.Status.LoginExpired || !peer.Status.Connected { + continue + } + _, duration := peer.LoginExpired(a.Settings.PeerLoginExpiration) + if nextExpiry == nil || duration < *nextExpiry { + nextExpiry = &duration + } + } + + if nextExpiry == nil { + return 0, false + } + + return *nextExpiry, true +} + +// GetPeersWithExpiration returns a list of peers that have Peer.LoginExpirationEnabled set to true +func (a *Account) GetPeersWithExpiration() []*Peer { + peers := make([]*Peer, 0) + for _, peer := range a.Peers { + if peer.LoginExpirationEnabled { + peers = append(peers, peer) + } + } + return peers +} + // GetPeers returns a list of all Account peers func (a *Account) GetPeers() []*Peer { var peers []*Peer @@ -550,13 +603,14 @@ func BuildManager(store Store, peersUpdateManager *PeersUpdateManager, idpManage cacheLoading: map[string]chan struct{}{}, dnsDomain: dnsDomain, eventStore: eventStore, + peerLoginExpiry: NewDefaultScheduler(), } allAccounts := store.GetAllAccounts() // enable single account mode only if configured by user and number of existing accounts is not grater than 1 am.singleAccountMode = singleAccountModeDomain != "" && len(allAccounts) <= 1 if am.singleAccountMode { if !isDomainValid(singleAccountModeDomain) { - return nil, status.Errorf(status.InvalidArgument, "invalid domain \"%s\" provided for single accound mode. Please review your input for --single-account-mode-domain", singleAccountModeDomain) + return nil, status.Errorf(status.InvalidArgument, "invalid domain \"%s\" provided for a single account mode. Please review your input for --single-account-mode-domain", singleAccountModeDomain) } am.singleAccountModeDomain = singleAccountModeDomain log.Infof("single account mode enabled, accounts number %d", len(allAccounts)) @@ -640,12 +694,16 @@ func (am *DefaultAccountManager) UpdateAccountSettings(accountID, userID string, event := activity.AccountPeerLoginExpirationEnabled if !newSettings.PeerLoginExpirationEnabled { event = activity.AccountPeerLoginExpirationDisabled + am.peerLoginExpiry.Cancel([]string{accountID}) + } else { + am.checkAndSchedulePeerLoginExpiration(account) } am.storeEvent(userID, accountID, accountID, event, nil) } if oldSettings.PeerLoginExpiration != newSettings.PeerLoginExpiration { am.storeEvent(userID, accountID, accountID, activity.AccountPeerLoginExpirationDurationUpdated, nil) + am.checkAndSchedulePeerLoginExpiration(account) } updatedAccount := account.UpdateSettings(newSettings) @@ -658,6 +716,54 @@ func (am *DefaultAccountManager) UpdateAccountSettings(accountID, userID string, return updatedAccount, nil } +func (am *DefaultAccountManager) peerLoginExpirationJob(accountID string) func() (time.Duration, bool) { + return func() (time.Duration, bool) { + unlock := am.Store.AcquireAccountLock(accountID) + defer unlock() + + account, err := am.Store.GetAccount(accountID) + if err != nil { + log.Errorf("failed getting account %s expiring peers", account.Id) + return account.GetNextPeerExpiration() + } + + var peerIDs []string + for _, peer := range account.GetExpiredPeers() { + if peer.Status.LoginExpired { + continue + } + peerIDs = append(peerIDs, peer.ID) + peer.MarkLoginExpired(true) + account.UpdatePeer(peer) + err = am.Store.SavePeerStatus(account.Id, peer.ID, *peer.Status) + if err != nil { + log.Errorf("failed saving peer status while expiring peer %s", peer.ID) + return account.GetNextPeerExpiration() + } + } + + log.Debugf("discovered %d peers to expire for account %s", len(peerIDs), account.Id) + + if len(peerIDs) != 0 { + // this will trigger peer disconnect from the management service + am.peersUpdateManager.CloseChannels(peerIDs) + err := am.updateAccountPeers(account) + if err != nil { + log.Errorf("failed updating account peers while expiring peers for account %s", accountID) + return account.GetNextPeerExpiration() + } + } + return account.GetNextPeerExpiration() + } +} + +func (am *DefaultAccountManager) checkAndSchedulePeerLoginExpiration(account *Account) { + am.peerLoginExpiry.Cancel([]string{account.Id}) + if nextRun, ok := account.GetNextPeerExpiration(); ok { + go am.peerLoginExpiry.Schedule(nextRun, account.Id, am.peerLoginExpirationJob(account.Id)) + } +} + // newAccount creates a new Account with a generated ID and generated default setup keys. // If ID is already in use (due to collision) we try one more time before returning error func (am *DefaultAccountManager) newAccount(userID, domain string) (*Account, error) { diff --git a/management/server/account_test.go b/management/server/account_test.go index 979c41c86dd..1d672e1b780 100644 --- a/management/server/account_test.go +++ b/management/server/account_test.go @@ -1294,6 +1294,147 @@ func TestDefaultAccountManager_DefaultAccountSettings(t *testing.T) { assert.Equal(t, account.Settings.PeerLoginExpirationEnabled, true) assert.Equal(t, account.Settings.PeerLoginExpiration, 24*time.Hour) } +func TestDefaultAccountManager_UpdatePeer_PeerLoginExpiration(t *testing.T) { + manager, err := createManager(t) + require.NoError(t, err, "unable to create account manager") + account, err := manager.GetAccountByUserOrAccountID(userID, "", "") + require.NoError(t, err, "unable to create an account") + + key, err := wgtypes.GenerateKey() + require.NoError(t, err, "unable to generate WireGuard key") + peer, err := manager.AddPeer("", userID, &Peer{ + Key: key.PublicKey().String(), + Meta: PeerSystemMeta{}, + Name: "test-peer", + LoginExpirationEnabled: true, + }) + require.NoError(t, err, "unable to add peer") + err = manager.MarkPeerConnected(key.PublicKey().String(), true) + require.NoError(t, err, "unable to mark peer connected") + account, err = manager.UpdateAccountSettings(account.Id, userID, &Settings{ + PeerLoginExpiration: time.Hour, + PeerLoginExpirationEnabled: true}) + require.NoError(t, err, "expecting to update account settings successfully but got error") + + wg := &sync.WaitGroup{} + wg.Add(2) + manager.peerLoginExpiry = &MockScheduler{ + CancelFunc: func(IDs []string) { + wg.Done() + }, + ScheduleFunc: func(in time.Duration, ID string, job func() (nextRunIn time.Duration, reschedule bool)) { + wg.Done() + }, + } + + // disable expiration first + update := peer.Copy() + update.LoginExpirationEnabled = false + _, err = manager.UpdatePeer(account.Id, userID, update) + require.NoError(t, err, "unable to update peer") + // enabling expiration should trigger the routine + update.LoginExpirationEnabled = true + _, err = manager.UpdatePeer(account.Id, userID, update) + require.NoError(t, err, "unable to update peer") + + failed := waitTimeout(wg, time.Second) + if failed { + t.Fatal("timeout while waiting for test to finish") + } +} + +func TestDefaultAccountManager_MarkPeerConnected_PeerLoginExpiration(t *testing.T) { + manager, err := createManager(t) + require.NoError(t, err, "unable to create account manager") + account, err := manager.GetAccountByUserOrAccountID(userID, "", "") + require.NoError(t, err, "unable to create an account") + + key, err := wgtypes.GenerateKey() + require.NoError(t, err, "unable to generate WireGuard key") + _, err = manager.AddPeer("", userID, &Peer{ + Key: key.PublicKey().String(), + Meta: PeerSystemMeta{}, + Name: "test-peer", + LoginExpirationEnabled: true, + }) + require.NoError(t, err, "unable to add peer") + _, err = manager.UpdateAccountSettings(account.Id, userID, &Settings{ + PeerLoginExpiration: time.Hour, + PeerLoginExpirationEnabled: true}) + require.NoError(t, err, "expecting to update account settings successfully but got error") + + wg := &sync.WaitGroup{} + wg.Add(2) + manager.peerLoginExpiry = &MockScheduler{ + CancelFunc: func(IDs []string) { + wg.Done() + }, + ScheduleFunc: func(in time.Duration, ID string, job func() (nextRunIn time.Duration, reschedule bool)) { + wg.Done() + }, + } + + // when we mark peer as connected, the peer login expiration routine should trigger + err = manager.MarkPeerConnected(key.PublicKey().String(), true) + require.NoError(t, err, "unable to mark peer connected") + + failed := waitTimeout(wg, time.Second) + if failed { + t.Fatal("timeout while waiting for test to finish") + } + +} + +func TestDefaultAccountManager_UpdateAccountSettings_PeerLoginExpiration(t *testing.T) { + manager, err := createManager(t) + require.NoError(t, err, "unable to create account manager") + account, err := manager.GetAccountByUserOrAccountID(userID, "", "") + require.NoError(t, err, "unable to create an account") + + key, err := wgtypes.GenerateKey() + require.NoError(t, err, "unable to generate WireGuard key") + _, err = manager.AddPeer("", userID, &Peer{ + Key: key.PublicKey().String(), + Meta: PeerSystemMeta{}, + Name: "test-peer", + LoginExpirationEnabled: true, + }) + require.NoError(t, err, "unable to add peer") + err = manager.MarkPeerConnected(key.PublicKey().String(), true) + require.NoError(t, err, "unable to mark peer connected") + + wg := &sync.WaitGroup{} + wg.Add(2) + manager.peerLoginExpiry = &MockScheduler{ + CancelFunc: func(IDs []string) { + wg.Done() + }, + ScheduleFunc: func(in time.Duration, ID string, job func() (nextRunIn time.Duration, reschedule bool)) { + wg.Done() + }, + } + // enabling PeerLoginExpirationEnabled should trigger the expiration job + account, err = manager.UpdateAccountSettings(account.Id, userID, &Settings{ + PeerLoginExpiration: time.Hour, + PeerLoginExpirationEnabled: true}) + require.NoError(t, err, "expecting to update account settings successfully but got error") + + failed := waitTimeout(wg, time.Second) + if failed { + t.Fatal("timeout while waiting for test to finish") + } + wg.Add(1) + + // disabling PeerLoginExpirationEnabled should trigger cancel + _, err = manager.UpdateAccountSettings(account.Id, userID, &Settings{ + PeerLoginExpiration: time.Hour, + PeerLoginExpirationEnabled: false}) + require.NoError(t, err, "expecting to update account settings successfully but got error") + failed = waitTimeout(wg, time.Second) + if failed { + t.Fatal("timeout while waiting for test to finish") + } +} func TestDefaultAccountManager_UpdateAccountSettings(t *testing.T) { manager, err := createManager(t) @@ -1326,6 +1467,286 @@ func TestDefaultAccountManager_UpdateAccountSettings(t *testing.T) { require.Error(t, err, "expecting to fail when providing PeerLoginExpiration more than 180 days") } +func TestAccount_GetExpiredPeers(t *testing.T) { + type test struct { + name string + peers map[string]*Peer + expectedPeers map[string]struct{} + } + testCases := []test{ + { + name: "Peers with login expiration disabled, no expired peers", + peers: map[string]*Peer{ + "peer-1": { + LoginExpirationEnabled: false, + }, + "peer-2": { + LoginExpirationEnabled: false, + }, + }, + expectedPeers: map[string]struct{}{}, + }, + { + name: "Two peers expired", + peers: map[string]*Peer{ + "peer-1": { + ID: "peer-1", + LoginExpirationEnabled: true, + Status: &PeerStatus{ + LastSeen: time.Now(), + Connected: true, + LoginExpired: false, + }, + LastLogin: time.Now().Add(-30 * time.Minute), + }, + "peer-2": { + ID: "peer-2", + LoginExpirationEnabled: true, + Status: &PeerStatus{ + LastSeen: time.Now(), + Connected: true, + LoginExpired: false, + }, + LastLogin: time.Now().Add(-2 * time.Hour), + }, + + "peer-3": { + ID: "peer-3", + LoginExpirationEnabled: true, + Status: &PeerStatus{ + LastSeen: time.Now(), + Connected: true, + LoginExpired: false, + }, + LastLogin: time.Now().Add(-1 * time.Hour), + }, + }, + expectedPeers: map[string]struct{}{ + "peer-2": {}, + "peer-3": {}, + }, + }, + } + + for _, testCase := range testCases { + t.Run(testCase.name, func(t *testing.T) { + account := &Account{ + Peers: testCase.peers, + Settings: &Settings{ + PeerLoginExpirationEnabled: true, + PeerLoginExpiration: time.Hour, + }, + } + + expiredPeers := account.GetExpiredPeers() + assert.Len(t, expiredPeers, len(testCase.expectedPeers)) + for _, peer := range expiredPeers { + if _, ok := testCase.expectedPeers[peer.ID]; !ok { + t.Fatalf("expected to have peer %s expired", peer.ID) + } + } + }) + } + +} + +func TestAccount_GetPeersWithExpiration(t *testing.T) { + type test struct { + name string + peers map[string]*Peer + expectedPeers map[string]struct{} + } + + testCases := []test{ + { + name: "No account peers, no peers with expiration", + peers: map[string]*Peer{}, + expectedPeers: map[string]struct{}{}, + }, + { + name: "Peers with login expiration disabled, no peers with expiration", + peers: map[string]*Peer{ + "peer-1": { + LoginExpirationEnabled: false, + }, + "peer-2": { + LoginExpirationEnabled: false, + }, + }, + expectedPeers: map[string]struct{}{}, + }, + { + name: "Peers with login expiration enabled, return peers with expiration", + peers: map[string]*Peer{ + "peer-1": { + ID: "peer-1", + LoginExpirationEnabled: true, + }, + "peer-2": { + LoginExpirationEnabled: false, + }, + }, + expectedPeers: map[string]struct{}{ + "peer-1": {}, + }, + }, + } + + for _, testCase := range testCases { + t.Run(testCase.name, func(t *testing.T) { + account := &Account{ + Peers: testCase.peers, + } + + actual := account.GetPeersWithExpiration() + assert.Len(t, actual, len(testCase.expectedPeers)) + if len(testCase.expectedPeers) > 0 { + for k := range testCase.expectedPeers { + contains := false + for _, peer := range actual { + if k == peer.ID { + contains = true + } + } + assert.True(t, contains) + } + } + }) + } + +} + +func TestAccount_GetNextPeerExpiration(t *testing.T) { + + type test struct { + name string + peers map[string]*Peer + expiration time.Duration + expirationEnabled bool + expectedNextRun bool + expectedNextExpiration time.Duration + } + + expectedNextExpiration := time.Minute + testCases := []test{ + { + name: "No peers, no expiration", + peers: map[string]*Peer{}, + expiration: time.Second, + expirationEnabled: false, + expectedNextRun: false, + expectedNextExpiration: time.Duration(0), + }, + { + name: "No connected peers, no expiration", + peers: map[string]*Peer{ + "peer-1": { + Status: &PeerStatus{ + Connected: false, + }, + LoginExpirationEnabled: true, + }, + "peer-2": { + Status: &PeerStatus{ + Connected: true, + }, + LoginExpirationEnabled: false, + }, + }, + expiration: time.Second, + expirationEnabled: false, + expectedNextRun: false, + expectedNextExpiration: time.Duration(0), + }, + { + name: "Connected peers with disabled expiration, no expiration", + peers: map[string]*Peer{ + "peer-1": { + Status: &PeerStatus{ + Connected: true, + }, + LoginExpirationEnabled: false, + }, + "peer-2": { + Status: &PeerStatus{ + Connected: true, + }, + LoginExpirationEnabled: false, + }, + }, + expiration: time.Second, + expirationEnabled: false, + expectedNextRun: false, + expectedNextExpiration: time.Duration(0), + }, + { + name: "Expired peers, no expiration", + peers: map[string]*Peer{ + "peer-1": { + Status: &PeerStatus{ + Connected: true, + LoginExpired: true, + }, + LoginExpirationEnabled: true, + }, + "peer-2": { + Status: &PeerStatus{ + Connected: true, + LoginExpired: true, + }, + LoginExpirationEnabled: true, + }, + }, + expiration: time.Second, + expirationEnabled: false, + expectedNextRun: false, + expectedNextExpiration: time.Duration(0), + }, + { + name: "To be expired peer, return expiration", + peers: map[string]*Peer{ + "peer-1": { + Status: &PeerStatus{ + Connected: true, + LoginExpired: false, + }, + LoginExpirationEnabled: true, + LastLogin: time.Now(), + }, + "peer-2": { + Status: &PeerStatus{ + Connected: true, + LoginExpired: true, + }, + LoginExpirationEnabled: true, + }, + }, + expiration: time.Minute, + expirationEnabled: false, + expectedNextRun: true, + expectedNextExpiration: expectedNextExpiration, + }, + } + for _, testCase := range testCases { + t.Run(testCase.name, func(t *testing.T) { + account := &Account{ + Peers: testCase.peers, + Settings: &Settings{PeerLoginExpiration: testCase.expiration, PeerLoginExpirationEnabled: testCase.expirationEnabled}, + } + + expiration, ok := account.GetNextPeerExpiration() + assert.Equal(t, ok, testCase.expectedNextRun) + if testCase.expectedNextRun { + assert.True(t, expiration >= 0 && expiration <= testCase.expectedNextExpiration) + } else { + assert.Equal(t, expiration, testCase.expectedNextExpiration) + } + + }) + } + +} + func createManager(t *testing.T) (*DefaultAccountManager, error) { store, err := createStore(t) if err != nil { @@ -1344,3 +1765,17 @@ func createStore(t *testing.T) (Store, error) { return store, nil } + +func waitTimeout(wg *sync.WaitGroup, timeout time.Duration) bool { + c := make(chan struct{}) + go func() { + defer close(c) + wg.Wait() + }() + select { + case <-c: + return false + case <-time.After(timeout): + return true + } +} diff --git a/management/server/grpcserver.go b/management/server/grpcserver.go index 7d4d60207ad..0ee9e0715b3 100644 --- a/management/server/grpcserver.go +++ b/management/server/grpcserver.go @@ -136,7 +136,8 @@ func (s *GRPCServer) Sync(req *proto.EncryptedMessage, srv proto.ManagementServi if err != nil { return status.Error(codes.Internal, "internal server error") } - expired, left := peer.LoginExpired(account.Settings) + expired, left := peer.LoginExpired(account.Settings.PeerLoginExpiration) + expired = account.Settings.PeerLoginExpirationEnabled && expired if peer.UserID != "" && (expired || peer.Status.LoginExpired) { err = s.accountManager.MarkPeerLoginExpired(peerKey.String(), true) if err != nil { @@ -380,7 +381,9 @@ func (s *GRPCServer) Login(ctx context.Context, req *proto.EncryptedMessage) (*p if err != nil { return nil, status.Error(codes.Internal, "internal server error") } - expired, left := peer.LoginExpired(account.Settings) + + expired, left := peer.LoginExpired(account.Settings.PeerLoginExpiration) + expired = account.Settings.PeerLoginExpirationEnabled && expired if peer.UserID != "" && (expired || peer.Status.LoginExpired) { // it might be that peer expired but user has logged in already, check token then if loginReq.GetJwtToken() == "" { diff --git a/management/server/peer.go b/management/server/peer.go index 5e3f5e69bf2..49732421f09 100644 --- a/management/server/peer.go +++ b/management/server/peer.go @@ -93,17 +93,24 @@ func (p *Peer) Copy() *Peer { } } +// MarkLoginExpired marks peer's status expired or not +func (p *Peer) MarkLoginExpired(expired bool) { + newStatus := p.Status.Copy() + newStatus.LastSeen = time.Now() + newStatus.LoginExpired = expired + p.Status = newStatus +} + // LoginExpired indicates whether the peer's login has expired or not. // If Peer.LastLogin plus the expiresIn duration has happened already; then login has expired. // Return true if a login has expired, false otherwise, and time left to expiration (negative when expired). // Login expiration can be disabled/enabled on a Peer level via Peer.LoginExpirationEnabled property. -// Login expiration can also be disabled/enabled globally on the Account level via Settings.PeerLoginExpirationEnabled -// and if disabled on the Account level, then Peer.LoginExpirationEnabled is ineffective. -func (p *Peer) LoginExpired(accountSettings *Settings) (bool, time.Duration) { - expiresAt := p.LastLogin.Add(accountSettings.PeerLoginExpiration) +// Login expiration can also be disabled/enabled globally on the Account level via Settings.PeerLoginExpirationEnabled. +func (p *Peer) LoginExpired(expiresIn time.Duration) (bool, time.Duration) { + expiresAt := p.LastLogin.Add(expiresIn) now := time.Now() timeLeft := expiresAt.Sub(now) - return accountSettings.PeerLoginExpirationEnabled && p.LoginExpirationEnabled && (timeLeft <= 0), timeLeft + return p.LoginExpirationEnabled && (timeLeft <= 0), timeLeft } // FQDN returns peers FQDN combined of the peer's DNS label and the system's DNS domain @@ -202,13 +209,10 @@ func (am *DefaultAccountManager) MarkPeerLoginExpired(peerPubKey string, loginEx return err } - newStatus := peer.Status.Copy() - newStatus.LastSeen = time.Now() - newStatus.LoginExpired = loginExpired - peer.Status = newStatus + peer.MarkLoginExpired(loginExpired) account.UpdatePeer(peer) - err = am.Store.SavePeerStatus(account.Id, peer.ID, *newStatus) + err = am.Store.SavePeerStatus(account.Id, peer.ID, *peer.Status) if err != nil { return err } @@ -237,7 +241,8 @@ func (am *DefaultAccountManager) MarkPeerConnected(peerPubKey string, connected return err } - newStatus := peer.Status.Copy() + oldStatus := peer.Status.Copy() + newStatus := oldStatus newStatus.LastSeen = time.Now() newStatus.Connected = connected // whenever peer got connected that means that it logged in successfully @@ -251,6 +256,20 @@ func (am *DefaultAccountManager) MarkPeerConnected(peerPubKey string, connected if err != nil { return err } + + if peer.AddedWithSSOLogin() && peer.LoginExpirationEnabled && account.Settings.PeerLoginExpirationEnabled { + am.checkAndSchedulePeerLoginExpiration(account) + } + + if oldStatus.LoginExpired { + // we need to update other peers because when peer login expires all other peers are notified to disconnect from + //the expired one. Here we notify them that connection is now allowed again. + err = am.updateAccountPeers(account) + if err != nil { + return err + } + } + return nil } @@ -307,6 +326,10 @@ func (am *DefaultAccountManager) UpdatePeer(accountID, userID string, update *Pe event = activity.PeerLoginExpirationDisabled } am.storeEvent(userID, peer.IP.String(), accountID, event, peer.EventMeta(am.GetDNSDomain())) + + if peer.AddedWithSSOLogin() && peer.LoginExpirationEnabled && account.Settings.PeerLoginExpirationEnabled { + am.checkAndSchedulePeerLoginExpiration(account) + } } account.UpdatePeer(peer) @@ -529,7 +552,7 @@ func (am *DefaultAccountManager) AddPeer(setupKey, userID string, peer *Peer) (* SSHEnabled: false, SSHKey: peer.SSHKey, LastLogin: time.Now(), - LoginExpirationEnabled: false, + LoginExpirationEnabled: true, } // add peer to 'All' group @@ -775,6 +798,10 @@ func (a *Account) getPeersByACL(peerID string) []*Peer { ) continue } + expired, _ := peer.LoginExpired(a.Settings.PeerLoginExpiration) + if expired { + continue + } // exclude original peer if _, ok := peersSet[peer.ID]; peer.ID != peerID && !ok { peersSet[peer.ID] = struct{}{} diff --git a/management/server/peer_test.go b/management/server/peer_test.go index eb503d2184b..5ebbad4ecca 100644 --- a/management/server/peer_test.go +++ b/management/server/peer_test.go @@ -57,7 +57,7 @@ func TestPeer_LoginExpired(t *testing.T) { LastLogin: c.lastLogin, } - expired, _ := peer.LoginExpired(c.accountSettings) + expired, _ := peer.LoginExpired(c.accountSettings.PeerLoginExpiration) assert.Equal(t, expired, c.expected) }) } diff --git a/management/server/scheduler.go b/management/server/scheduler.go new file mode 100644 index 00000000000..a35bdc30ce1 --- /dev/null +++ b/management/server/scheduler.go @@ -0,0 +1,114 @@ +package server + +import ( + log "github.com/sirupsen/logrus" + "sync" + "time" +) + +// Scheduler is an interface which implementations can schedule and cancel jobs +type Scheduler interface { + Cancel(IDs []string) + Schedule(in time.Duration, ID string, job func() (nextRunIn time.Duration, reschedule bool)) +} + +// MockScheduler is a mock implementation of Scheduler +type MockScheduler struct { + CancelFunc func(IDs []string) + ScheduleFunc func(in time.Duration, ID string, job func() (nextRunIn time.Duration, reschedule bool)) +} + +// Cancel mocks the Cancel function of the Scheduler interface +func (mock *MockScheduler) Cancel(IDs []string) { + if mock.CancelFunc != nil { + mock.CancelFunc(IDs) + return + } + log.Errorf("MockScheduler doesn't have Cancel function defined ") +} + +// Schedule mocks the Schedule function of the Scheduler interface +func (mock *MockScheduler) Schedule(in time.Duration, ID string, job func() (nextRunIn time.Duration, reschedule bool)) { + if mock.ScheduleFunc != nil { + mock.ScheduleFunc(in, ID, job) + return + } + log.Errorf("MockScheduler doesn't have Schedule function defined") +} + +// DefaultScheduler is a generic structure that allows to schedule jobs (functions) to run in the future and cancel them. +type DefaultScheduler struct { + // jobs map holds cancellation channels indexed by the job ID + jobs map[string]chan struct{} + mu *sync.Mutex +} + +// NewDefaultScheduler creates an instance of a DefaultScheduler +func NewDefaultScheduler() *DefaultScheduler { + return &DefaultScheduler{ + jobs: make(map[string]chan struct{}), + mu: &sync.Mutex{}, + } +} + +func (wm *DefaultScheduler) cancel(ID string) bool { + cancel, ok := wm.jobs[ID] + if ok { + delete(wm.jobs, ID) + select { + case cancel <- struct{}{}: + log.Debugf("cancelled scheduled job %s", ID) + default: + log.Warnf("couldn't cancel job %s because there was no routine listening on the cancel event", ID) + return false + } + + } + return ok +} + +// Cancel cancels the scheduled job by ID if present. +// If job wasn't found the function returns false. +func (wm *DefaultScheduler) Cancel(IDs []string) { + wm.mu.Lock() + defer wm.mu.Unlock() + + for _, id := range IDs { + wm.cancel(id) + } +} + +// Schedule a job to run in some time in the future. If job returns true then it will be scheduled one more time. +// If job with the provided ID already exists, a new one won't be scheduled. +func (wm *DefaultScheduler) Schedule(in time.Duration, ID string, job func() (nextRunIn time.Duration, reschedule bool)) { + wm.mu.Lock() + defer wm.mu.Unlock() + cancel := make(chan struct{}) + if _, ok := wm.jobs[ID]; ok { + log.Debugf("couldn't schedule a job %s because it already exists. There are %d total jobs scheduled.", + ID, len(wm.jobs)) + return + } + + wm.jobs[ID] = cancel + log.Debugf("scheduled a job %s to run in %s. There are %d total jobs scheduled.", ID, in.String(), len(wm.jobs)) + go func() { + select { + case <-time.After(in): + log.Debugf("time to do a scheduled job %s", ID) + runIn, reschedule := job() + wm.mu.Lock() + defer wm.mu.Unlock() + delete(wm.jobs, ID) + if reschedule { + go wm.Schedule(runIn, ID, job) + } + case <-cancel: + log.Debugf("stopped scheduled job %s ", ID) + wm.mu.Lock() + defer wm.mu.Unlock() + delete(wm.jobs, ID) + return + } + }() +} diff --git a/management/server/scheduler_test.go b/management/server/scheduler_test.go new file mode 100644 index 00000000000..0c0cef99b35 --- /dev/null +++ b/management/server/scheduler_test.go @@ -0,0 +1,94 @@ +package server + +import ( + "fmt" + "github.com/stretchr/testify/assert" + "math/rand" + "sync" + "testing" + "time" +) + +func TestScheduler_Performance(t *testing.T) { + scheduler := NewDefaultScheduler() + n := 500 + wg := &sync.WaitGroup{} + wg.Add(n) + maxMs := 500 + minMs := 50 + for i := 0; i < n; i++ { + millis := time.Duration(rand.Intn(maxMs-minMs)+minMs) * time.Millisecond + go scheduler.Schedule(millis, fmt.Sprintf("test-scheduler-job-%d", i), func() (nextRunIn time.Duration, reschedule bool) { + time.Sleep(millis) + wg.Done() + return 0, false + }) + } + failed := waitTimeout(wg, 3*time.Second) + if failed { + t.Fatal("timed out while waiting for test to finish") + return + } + assert.Len(t, scheduler.jobs, 0) +} + +func TestScheduler_Cancel(t *testing.T) { + jobID1 := "test-scheduler-job-1" + jobID2 := "test-scheduler-job-2" + scheduler := NewDefaultScheduler() + scheduler.Schedule(2*time.Second, jobID1, func() (nextRunIn time.Duration, reschedule bool) { + return 0, false + }) + scheduler.Schedule(2*time.Second, jobID2, func() (nextRunIn time.Duration, reschedule bool) { + return 0, false + }) + + assert.Len(t, scheduler.jobs, 2) + scheduler.Cancel([]string{jobID1}) + assert.Len(t, scheduler.jobs, 1) + assert.NotNil(t, scheduler.jobs[jobID2]) +} + +func TestScheduler_Schedule(t *testing.T) { + jobID := "test-scheduler-job-1" + scheduler := NewDefaultScheduler() + wg := &sync.WaitGroup{} + wg.Add(1) + // job without reschedule should be triggered once + job := func() (nextRunIn time.Duration, reschedule bool) { + wg.Done() + return 0, false + } + scheduler.Schedule(300*time.Millisecond, jobID, job) + failed := waitTimeout(wg, time.Second) + if failed { + t.Fatal("timed out while waiting for test to finish") + return + } + + // job with reschedule should be triggered at least twice + wg = &sync.WaitGroup{} + mx := &sync.Mutex{} + scheduledTimes := 0 + wg.Add(2) + job = func() (nextRunIn time.Duration, reschedule bool) { + mx.Lock() + defer mx.Unlock() + // ensure we repeat only twice + if scheduledTimes < 2 { + wg.Done() + scheduledTimes++ + return 300 * time.Millisecond, true + } + return 0, false + } + + scheduler.Schedule(300*time.Millisecond, jobID, job) + failed = waitTimeout(wg, time.Second) + if failed { + t.Fatal("timed out while waiting for test to finish") + return + } + scheduler.cancel(jobID) + +} diff --git a/management/server/turncredentials.go b/management/server/turncredentials.go index dcfab57dd6d..752376767b4 100644 --- a/management/server/turncredentials.go +++ b/management/server/turncredentials.go @@ -115,6 +115,7 @@ func (m *TimeBasedAuthSecretsManager) SetupRefresh(peerID string) { Turns: turns, }, } + log.Debugf("sending new TURN credentials to peer %s", peerID) err := m.updateManager.SendUpdate(peerID, &UpdateMessage{Update: update}) if err != nil { log.Errorf("error while sending TURN update to peer %s %v", peerID, err) diff --git a/management/server/updatechannel.go b/management/server/updatechannel.go index 4b4d6e3d198..6cc10ad246c 100644 --- a/management/server/updatechannel.go +++ b/management/server/updatechannel.go @@ -60,10 +60,7 @@ func (p *PeersUpdateManager) CreateChannel(peerID string) chan *UpdateMessage { return channel } -// CloseChannel closes updates channel of a given peer -func (p *PeersUpdateManager) CloseChannel(peerID string) { - p.channelsMux.Lock() - defer p.channelsMux.Unlock() +func (p *PeersUpdateManager) closeChannel(peerID string) { if channel, ok := p.peerChannels[peerID]; ok { delete(p.peerChannels, peerID) close(channel) @@ -72,6 +69,22 @@ func (p *PeersUpdateManager) CloseChannel(peerID string) { log.Debugf("closed updates channel of a peer %s", peerID) } +// CloseChannels closes updates channel for each given peer +func (p *PeersUpdateManager) CloseChannels(peerIDs []string) { + p.channelsMux.Lock() + defer p.channelsMux.Unlock() + for _, id := range peerIDs { + p.closeChannel(id) + } +} + +// CloseChannel closes updates channel of a given peer +func (p *PeersUpdateManager) CloseChannel(peerID string) { + p.channelsMux.Lock() + defer p.channelsMux.Unlock() + p.closeChannel(peerID) +} + // GetAllConnectedPeers returns a copy of the connected peers map func (p *PeersUpdateManager) GetAllConnectedPeers() map[string]struct{} { p.channelsMux.Lock()