diff --git a/client/firewall/iptables/manager_linux.go b/client/firewall/iptables/manager_linux.go index a59bd2c602e..adb8f20ef5c 100644 --- a/client/firewall/iptables/manager_linux.go +++ b/client/firewall/iptables/manager_linux.go @@ -83,9 +83,11 @@ func (m *Manager) Init(stateManager *statemanager.Manager) error { } // persist early to ensure cleanup of chains - if err := stateManager.PersistState(context.Background()); err != nil { - log.Errorf("failed to persist state: %v", err) - } + go func() { + if err := stateManager.PersistState(context.Background()); err != nil { + log.Errorf("failed to persist state: %v", err) + } + }() return nil } diff --git a/client/firewall/nftables/manager_linux.go b/client/firewall/nftables/manager_linux.go index ea8912f27f5..3f8fac249a5 100644 --- a/client/firewall/nftables/manager_linux.go +++ b/client/firewall/nftables/manager_linux.go @@ -99,9 +99,11 @@ func (m *Manager) Init(stateManager *statemanager.Manager) error { } // persist early - if err := stateManager.PersistState(context.Background()); err != nil { - log.Errorf("failed to persist state: %v", err) - } + go func() { + if err := stateManager.PersistState(context.Background()); err != nil { + log.Errorf("failed to persist state: %v", err) + } + }() return nil } diff --git a/client/internal/config.go b/client/internal/config.go index ee54c6380c5..ce87835cdd5 100644 --- a/client/internal/config.go +++ b/client/internal/config.go @@ -164,7 +164,7 @@ func UpdateOrCreateConfig(input ConfigInput) (*Config, error) { if err != nil { return nil, err } - err = util.WriteJsonWithRestrictedPermission(input.ConfigPath, cfg) + err = util.WriteJsonWithRestrictedPermission(context.Background(), input.ConfigPath, cfg) return cfg, err } @@ -185,7 +185,7 @@ func CreateInMemoryConfig(input ConfigInput) (*Config, error) { // WriteOutConfig write put the prepared config to the given path func WriteOutConfig(path string, config *Config) error { - return util.WriteJson(path, config) + return util.WriteJson(context.Background(), path, config) } // createNewConfig creates a new config generating a new Wireguard key and saving to file @@ -215,7 +215,7 @@ func update(input ConfigInput) (*Config, error) { } if updated { - if err := util.WriteJson(input.ConfigPath, config); err != nil { + if err := util.WriteJson(context.Background(), input.ConfigPath, config); err != nil { return nil, err } } diff --git a/client/internal/dns/server.go b/client/internal/dns/server.go index 929e1e60c85..6c4dccae74a 100644 --- a/client/internal/dns/server.go +++ b/client/internal/dns/server.go @@ -326,9 +326,13 @@ func (s *DefaultServer) applyConfiguration(update nbdns.Config) error { // persist dns state right away ctx, cancel := context.WithTimeout(s.ctx, 3*time.Second) defer cancel() - if err := s.stateManager.PersistState(ctx); err != nil { - log.Errorf("Failed to persist dns state: %v", err) - } + + // don't block + go func() { + if err := s.stateManager.PersistState(ctx); err != nil { + log.Errorf("Failed to persist dns state: %v", err) + } + }() if s.searchDomainNotifier != nil { s.searchDomainNotifier.onNewSearchDomains(s.SearchDomains()) diff --git a/client/internal/engine.go b/client/internal/engine.go index 190d795cdbe..0f3a5d28af0 100644 --- a/client/internal/engine.go +++ b/client/internal/engine.go @@ -38,7 +38,6 @@ import ( "github.com/netbirdio/netbird/client/internal/routemanager/systemops" "github.com/netbirdio/netbird/client/internal/statemanager" - nbssh "github.com/netbirdio/netbird/client/ssh" "github.com/netbirdio/netbird/client/system" nbdns "github.com/netbirdio/netbird/dns" @@ -171,7 +170,7 @@ type Engine struct { relayManager *relayClient.Manager stateManager *statemanager.Manager - srWatcher *guard.SRWatcher + srWatcher *guard.SRWatcher } // Peer is an instance of the Connection Peer @@ -641,6 +640,10 @@ func (e *Engine) updateSSH(sshConf *mgmProto.SSHConfig) error { } func (e *Engine) updateConfig(conf *mgmProto.PeerConfig) error { + if e.wgInterface == nil { + return errors.New("wireguard interface is not initialized") + } + if e.wgInterface.Address().String() != conf.Address { oldAddr := e.wgInterface.Address().String() log.Debugf("updating peer address from %s to %s", oldAddr, conf.Address) diff --git a/client/internal/routemanager/client_test.go b/client/internal/routemanager/client_test.go index 583156e4d86..56fcf1613c6 100644 --- a/client/internal/routemanager/client_test.go +++ b/client/internal/routemanager/client_test.go @@ -1,6 +1,7 @@ package routemanager import ( + "fmt" "net/netip" "testing" "time" @@ -227,6 +228,64 @@ func TestGetBestrouteFromStatuses(t *testing.T) { currentRoute: "route1", expectedRouteID: "route1", }, + { + name: "relayed routes with latency 0 should maintain previous choice", + statuses: map[route.ID]routerPeerStatus{ + "route1": { + connected: true, + relayed: true, + latency: 0 * time.Millisecond, + }, + "route2": { + connected: true, + relayed: true, + latency: 0 * time.Millisecond, + }, + }, + existingRoutes: map[route.ID]*route.Route{ + "route1": { + ID: "route1", + Metric: route.MaxMetric, + Peer: "peer1", + }, + "route2": { + ID: "route2", + Metric: route.MaxMetric, + Peer: "peer2", + }, + }, + currentRoute: "route1", + expectedRouteID: "route1", + }, + { + name: "p2p routes with latency 0 should maintain previous choice", + statuses: map[route.ID]routerPeerStatus{ + "route1": { + connected: true, + relayed: false, + latency: 0 * time.Millisecond, + }, + "route2": { + connected: true, + relayed: false, + latency: 0 * time.Millisecond, + }, + }, + existingRoutes: map[route.ID]*route.Route{ + "route1": { + ID: "route1", + Metric: route.MaxMetric, + Peer: "peer1", + }, + "route2": { + ID: "route2", + Metric: route.MaxMetric, + Peer: "peer2", + }, + }, + currentRoute: "route1", + expectedRouteID: "route1", + }, { name: "current route with bad score should be changed to route with better score", statuses: map[route.ID]routerPeerStatus{ @@ -287,6 +346,45 @@ func TestGetBestrouteFromStatuses(t *testing.T) { }, } + // fill the test data with random routes + for _, tc := range testCases { + for i := 0; i < 50; i++ { + dummyRoute := &route.Route{ + ID: route.ID(fmt.Sprintf("dummy_p1_%d", i)), + Metric: route.MinMetric, + Peer: fmt.Sprintf("dummy_p1_%d", i), + } + tc.existingRoutes[dummyRoute.ID] = dummyRoute + } + for i := 0; i < 50; i++ { + dummyRoute := &route.Route{ + ID: route.ID(fmt.Sprintf("dummy_p2_%d", i)), + Metric: route.MinMetric, + Peer: fmt.Sprintf("dummy_p1_%d", i), + } + tc.existingRoutes[dummyRoute.ID] = dummyRoute + } + + for i := 0; i < 50; i++ { + id := route.ID(fmt.Sprintf("dummy_p1_%d", i)) + dummyStatus := routerPeerStatus{ + connected: false, + relayed: true, + latency: 0, + } + tc.statuses[id] = dummyStatus + } + for i := 0; i < 50; i++ { + id := route.ID(fmt.Sprintf("dummy_p2_%d", i)) + dummyStatus := routerPeerStatus{ + connected: false, + relayed: true, + latency: 0, + } + tc.statuses[id] = dummyStatus + } + } + for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { currentRoute := &route.Route{ diff --git a/client/internal/statemanager/manager.go b/client/internal/statemanager/manager.go index a5a14f807a2..580ccdfc78a 100644 --- a/client/internal/statemanager/manager.go +++ b/client/internal/statemanager/manager.go @@ -16,6 +16,7 @@ import ( "golang.org/x/exp/maps" nberrors "github.com/netbirdio/netbird/client/errors" + "github.com/netbirdio/netbird/util" ) // State interface defines the methods that all state types must implement @@ -178,25 +179,14 @@ func (m *Manager) PersistState(ctx context.Context) error { return nil } - ctx, cancel := context.WithTimeout(ctx, 3*time.Second) + ctx, cancel := context.WithTimeout(ctx, 5*time.Second) defer cancel() done := make(chan error, 1) + start := time.Now() go func() { - data, err := json.MarshalIndent(m.states, "", " ") - if err != nil { - done <- fmt.Errorf("marshal states: %w", err) - return - } - - // nolint:gosec - if err := os.WriteFile(m.filePath, data, 0640); err != nil { - done <- fmt.Errorf("write state file: %w", err) - return - } - - done <- nil + done <- util.WriteJsonWithRestrictedPermission(ctx, m.filePath, m.states) }() select { @@ -208,7 +198,7 @@ func (m *Manager) PersistState(ctx context.Context) error { } } - log.Debugf("persisted shutdown states: %v", maps.Keys(m.dirty)) + log.Debugf("persisted shutdown states: %v, took %v", maps.Keys(m.dirty), time.Since(start)) clear(m.dirty) diff --git a/client/internal/statemanager/path.go b/client/internal/statemanager/path.go index 96d6a9f12d3..6cfd79a1212 100644 --- a/client/internal/statemanager/path.go +++ b/client/internal/statemanager/path.go @@ -4,32 +4,20 @@ import ( "os" "path/filepath" "runtime" - - log "github.com/sirupsen/logrus" ) // GetDefaultStatePath returns the path to the state file based on the operating system -// It returns an empty string if the path cannot be determined. It also creates the directory if it does not exist. +// It returns an empty string if the path cannot be determined. func GetDefaultStatePath() string { - var path string - switch runtime.GOOS { case "windows": - path = filepath.Join(os.Getenv("PROGRAMDATA"), "Netbird", "state.json") + return filepath.Join(os.Getenv("PROGRAMDATA"), "Netbird", "state.json") case "darwin", "linux": - path = "/var/lib/netbird/state.json" + return "/var/lib/netbird/state.json" case "freebsd", "openbsd", "netbsd", "dragonfly": - path = "/var/db/netbird/state.json" - // ios/android don't need state - default: - return "" + return "/var/db/netbird/state.json" } - dir := filepath.Dir(path) - if err := os.MkdirAll(dir, 0755); err != nil { - log.Errorf("Error creating directory %s: %v. Continuing without state support.", dir, err) - return "" - } + return "" - return path } diff --git a/management/server/file_store.go b/management/server/file_store.go index 561e133cec8..f375fb99062 100644 --- a/management/server/file_store.go +++ b/management/server/file_store.go @@ -223,7 +223,7 @@ func restore(ctx context.Context, file string) (*FileStore, error) { // It is recommended to call it with locking FileStore.mux func (s *FileStore) persist(ctx context.Context, file string) error { start := time.Now() - err := util.WriteJson(file, s) + err := util.WriteJson(context.Background(), file, s) if err != nil { return err } diff --git a/management/server/group.go b/management/server/group.go index 154a33b1350..a36213f0493 100644 --- a/management/server/group.go +++ b/management/server/group.go @@ -6,11 +6,12 @@ import ( "fmt" "slices" - nbdns "github.com/netbirdio/netbird/dns" - "github.com/netbirdio/netbird/route" "github.com/rs/xid" log "github.com/sirupsen/logrus" + nbdns "github.com/netbirdio/netbird/dns" + "github.com/netbirdio/netbird/route" + "github.com/netbirdio/netbird/management/server/activity" nbgroup "github.com/netbirdio/netbird/management/server/group" "github.com/netbirdio/netbird/management/server/status" @@ -27,11 +28,6 @@ func (e *GroupLinkError) Error() string { // CheckGroupPermissions validates if a user has the necessary permissions to view groups func (am *DefaultAccountManager) CheckGroupPermissions(ctx context.Context, accountID, userID string) error { - settings, err := am.Store.GetAccountSettings(ctx, LockingStrengthShare, accountID) - if err != nil { - return err - } - user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID) if err != nil { return err @@ -41,7 +37,7 @@ func (am *DefaultAccountManager) CheckGroupPermissions(ctx context.Context, acco return status.NewUserNotPartOfAccountError() } - if user.IsRegularUser() && settings.RegularUsersViewBlocked { + if user.IsRegularUser() { return status.NewAdminPermissionError() } diff --git a/management/server/http/peers_handler.go b/management/server/http/peers_handler.go index a5856a0e43c..f5027cd7798 100644 --- a/management/server/http/peers_handler.go +++ b/management/server/http/peers_handler.go @@ -184,14 +184,26 @@ func (h *PeersHandler) GetAllPeers(w http.ResponseWriter, r *http.Request) { dnsDomain := h.accountManager.GetDNSDomain() - respBody := make([]*api.PeerBatch, 0, len(account.Peers)) - for _, peer := range account.Peers { + peers, err := h.accountManager.GetPeers(r.Context(), accountID, userID) + if err != nil { + util.WriteError(r.Context(), err, w) + return + } + + groupsMap := map[string]*nbgroup.Group{} + groups, _ := h.accountManager.GetAllGroups(r.Context(), accountID, userID) + for _, group := range groups { + groupsMap[group.ID] = group + } + + respBody := make([]*api.PeerBatch, 0, len(peers)) + for _, peer := range peers { peerToReturn, err := h.checkPeerStatus(peer) if err != nil { util.WriteError(r.Context(), err, w) return } - groupMinimumInfo := toGroupsInfo(account.Groups, peer.ID) + groupMinimumInfo := toGroupsInfo(groupsMap, peer.ID) respBody = append(respBody, toPeerListItemResponse(peerToReturn, groupMinimumInfo, dnsDomain, 0)) } @@ -304,7 +316,7 @@ func peerToAccessiblePeer(peer *nbpeer.Peer, dnsDomain string) api.AccessiblePee } func toGroupsInfo(groups map[string]*nbgroup.Group, peerID string) []api.GroupMinimum { - var groupsInfo []api.GroupMinimum + groupsInfo := []api.GroupMinimum{} groupsChecked := make(map[string]struct{}) for _, group := range groups { _, ok := groupsChecked[group.ID] diff --git a/relay/server/peer.go b/relay/server/peer.go index c909c35d542..f65fb786afc 100644 --- a/relay/server/peer.go +++ b/relay/server/peer.go @@ -16,6 +16,8 @@ import ( const ( bufferSize = 8820 + + errCloseConn = "failed to close connection to peer: %s" ) // Peer represents a peer connection @@ -46,6 +48,12 @@ func NewPeer(metrics *metrics.Metrics, id []byte, conn net.Conn, store *Store) * // It manages the protocol (healthcheck, transport, close). Read the message and determine the message type and handle // the message accordingly. func (p *Peer) Work() { + defer func() { + if err := p.conn.Close(); err != nil && !errors.Is(err, net.ErrClosed) { + p.log.Errorf(errCloseConn, err) + } + }() + ctx, cancel := context.WithCancel(context.Background()) defer cancel() @@ -97,7 +105,7 @@ func (p *Peer) handleMsgType(ctx context.Context, msgType messages.MsgType, hc * case messages.MsgTypeClose: p.log.Infof("peer exited gracefully") if err := p.conn.Close(); err != nil { - log.Errorf("failed to close connection to peer: %s", err) + log.Errorf(errCloseConn, err) } default: p.log.Warnf("received unexpected message type: %s", msgType) @@ -121,9 +129,8 @@ func (p *Peer) CloseGracefully(ctx context.Context) { p.log.Errorf("failed to send close message to peer: %s", p.String()) } - err = p.conn.Close() - if err != nil { - p.log.Errorf("failed to close connection to peer: %s", err) + if err := p.conn.Close(); err != nil { + p.log.Errorf(errCloseConn, err) } } @@ -132,7 +139,7 @@ func (p *Peer) Close() { defer p.connMu.Unlock() if err := p.conn.Close(); err != nil { - p.log.Errorf("failed to close connection to peer: %s", err) + p.log.Errorf(errCloseConn, err) } } diff --git a/util/file.go b/util/file.go index ecaecd22260..4641cc1b825 100644 --- a/util/file.go +++ b/util/file.go @@ -15,7 +15,7 @@ import ( ) // WriteJsonWithRestrictedPermission writes JSON config object to a file. Enforces permission on the parent directory -func WriteJsonWithRestrictedPermission(file string, obj interface{}) error { +func WriteJsonWithRestrictedPermission(ctx context.Context, file string, obj interface{}) error { configDir, configFileName, err := prepareConfigFileDir(file) if err != nil { return err @@ -26,18 +26,18 @@ func WriteJsonWithRestrictedPermission(file string, obj interface{}) error { return err } - return writeJson(file, obj, configDir, configFileName) + return writeJson(ctx, file, obj, configDir, configFileName) } // WriteJson writes JSON config object to a file creating parent directories if required // The output JSON is pretty-formatted -func WriteJson(file string, obj interface{}) error { +func WriteJson(ctx context.Context, file string, obj interface{}) error { configDir, configFileName, err := prepareConfigFileDir(file) if err != nil { return err } - return writeJson(file, obj, configDir, configFileName) + return writeJson(ctx, file, obj, configDir, configFileName) } // DirectWriteJson writes JSON config object to a file creating parent directories if required without creating a temporary file @@ -79,7 +79,11 @@ func DirectWriteJson(ctx context.Context, file string, obj interface{}) error { return nil } -func writeJson(file string, obj interface{}, configDir string, configFileName string) error { +func writeJson(ctx context.Context, file string, obj interface{}, configDir string, configFileName string) error { + // Check context before expensive operations + if ctx.Err() != nil { + return ctx.Err() + } // make it pretty bs, err := json.MarshalIndent(obj, "", " ") @@ -87,6 +91,10 @@ func writeJson(file string, obj interface{}, configDir string, configFileName st return err } + if ctx.Err() != nil { + return ctx.Err() + } + tempFile, err := os.CreateTemp(configDir, ".*"+configFileName) if err != nil { return err @@ -111,6 +119,11 @@ func writeJson(file string, obj interface{}, configDir string, configFileName st return err } + // Check context again + if ctx.Err() != nil { + return ctx.Err() + } + err = os.Rename(tempFileName, file) if err != nil { return err diff --git a/util/file_test.go b/util/file_test.go index 566d8eda6fb..f8c9dfabbc3 100644 --- a/util/file_test.go +++ b/util/file_test.go @@ -1,6 +1,7 @@ package util import ( + "context" "crypto/md5" "encoding/hex" "io" @@ -39,7 +40,7 @@ func TestConfigJSON(t *testing.T) { t.Run(tt.name, func(t *testing.T) { tmpDir := t.TempDir() - err := WriteJson(tmpDir+"/testconfig.json", tt.config) + err := WriteJson(context.Background(), tmpDir+"/testconfig.json", tt.config) require.NoError(t, err) read, err := ReadJson(tmpDir+"/testconfig.json", &TestConfig{}) @@ -73,7 +74,7 @@ func TestCopyFileContents(t *testing.T) { src := tmpDir + "/copytest_src" dst := tmpDir + "/copytest_dst" - err := WriteJson(src, tt.srcContent) + err := WriteJson(context.Background(), src, tt.srcContent) require.NoError(t, err) err = CopyFileContents(src, dst) @@ -127,7 +128,7 @@ func TestHandleConfigFileWithoutFullPath(t *testing.T) { _ = os.Remove(cfgFile) }() - err := WriteJson(cfgFile, tt.config) + err := WriteJson(context.Background(), cfgFile, tt.config) require.NoError(t, err) read, err := ReadJson(cfgFile, &TestConfig{})