From f9723c9266e17aeafcc7d05d6b0e63f605bbea2f Mon Sep 17 00:00:00 2001 From: Maycon Santos Date: Fri, 29 Nov 2024 17:50:35 +0100 Subject: [PATCH 01/10] [client] Account different policiy rules for routes firewall rules (#2939) * Account different policies rules for routes firewall rules This change ensures that route firewall rules will consider source group peers in the rules generation for access control policies. This fixes the behavior where multiple policies with different levels of access was being applied to all peers in a distribution group * split function * avoid unnecessary allocation Co-authored-by: Viktor Liu <17948409+lixmal@users.noreply.github.com> --------- Co-authored-by: Viktor Liu <17948409+lixmal@users.noreply.github.com> --- management/server/route.go | 79 ++++++++++++++++--- management/server/route_test.go | 132 +++++++++++++++++++++++++++++++- 2 files changed, 196 insertions(+), 15 deletions(-) diff --git a/management/server/route.go b/management/server/route.go index ecb562645e6..23bea87e3b8 100644 --- a/management/server/route.go +++ b/management/server/route.go @@ -417,25 +417,82 @@ func (a *Account) getPeerRoutesFirewallRules(ctx context.Context, peerID string, continue } - policies := getAllRoutePoliciesFromGroups(a, route.AccessControlGroups) - for _, policy := range policies { - if !policy.Enabled { + distributionPeers := a.getDistributionGroupsPeers(route) + + for _, accessGroup := range route.AccessControlGroups { + policies := getAllRoutePoliciesFromGroups(a, []string{accessGroup}) + rules := a.getRouteFirewallRules(ctx, peerID, policies, route, validatedPeersMap, distributionPeers) + routesFirewallRules = append(routesFirewallRules, rules...) + } + } + + return routesFirewallRules +} + +func (a *Account) getRouteFirewallRules(ctx context.Context, peerID string, policies []*Policy, route *route.Route, validatedPeersMap map[string]struct{}, distributionPeers map[string]struct{}) []*RouteFirewallRule { + var fwRules []*RouteFirewallRule + for _, policy := range policies { + if !policy.Enabled { + continue + } + + for _, rule := range policy.Rules { + if !rule.Enabled { continue } - for _, rule := range policy.Rules { - if !rule.Enabled { - continue - } + rulePeers := a.getRulePeers(rule, peerID, distributionPeers, validatedPeersMap) + rules := generateRouteFirewallRules(ctx, route, rule, rulePeers, firewallRuleDirectionIN) + fwRules = append(fwRules, rules...) + } + } + return fwRules +} - distributionGroupPeers, _ := a.getAllPeersFromGroups(ctx, route.Groups, peerID, nil, validatedPeersMap) - rules := generateRouteFirewallRules(ctx, route, rule, distributionGroupPeers, firewallRuleDirectionIN) - routesFirewallRules = append(routesFirewallRules, rules...) +func (a *Account) getRulePeers(rule *PolicyRule, peerID string, distributionPeers map[string]struct{}, validatedPeersMap map[string]struct{}) []*nbpeer.Peer { + distPeersWithPolicy := make(map[string]struct{}) + for _, id := range rule.Sources { + group := a.Groups[id] + if group == nil { + continue + } + + for _, pID := range group.Peers { + if pID == peerID { + continue + } + _, distPeer := distributionPeers[pID] + _, valid := validatedPeersMap[pID] + if distPeer && valid { + distPeersWithPolicy[pID] = struct{}{} } } } - return routesFirewallRules + distributionGroupPeers := make([]*nbpeer.Peer, 0, len(distPeersWithPolicy)) + for pID := range distPeersWithPolicy { + peer := a.Peers[pID] + if peer == nil { + continue + } + distributionGroupPeers = append(distributionGroupPeers, peer) + } + return distributionGroupPeers +} + +func (a *Account) getDistributionGroupsPeers(route *route.Route) map[string]struct{} { + distPeers := make(map[string]struct{}) + for _, id := range route.Groups { + group := a.Groups[id] + if group == nil { + continue + } + + for _, pID := range group.Peers { + distPeers[pID] = struct{}{} + } + } + return distPeers } func getDefaultPermit(route *route.Route) []*RouteFirewallRule { diff --git a/management/server/route_test.go b/management/server/route_test.go index 108f791e02c..8bf9a3aebb3 100644 --- a/management/server/route_test.go +++ b/management/server/route_test.go @@ -5,6 +5,7 @@ import ( "fmt" "net" "net/netip" + "sort" "testing" "time" @@ -1486,6 +1487,8 @@ func TestAccount_getPeersRoutesFirewall(t *testing.T) { peerBIp = "100.65.80.39" peerCIp = "100.65.254.139" peerHIp = "100.65.29.55" + peerJIp = "100.65.29.65" + peerKIp = "100.65.29.66" ) account := &Account{ @@ -1541,6 +1544,16 @@ func TestAccount_getPeersRoutesFirewall(t *testing.T) { IP: net.ParseIP(peerHIp), Status: &nbpeer.PeerStatus{}, }, + "peerJ": { + ID: "peerJ", + IP: net.ParseIP(peerJIp), + Status: &nbpeer.PeerStatus{}, + }, + "peerK": { + ID: "peerK", + IP: net.ParseIP(peerKIp), + Status: &nbpeer.PeerStatus{}, + }, }, Groups: map[string]*nbgroup.Group{ "routingPeer1": { @@ -1567,6 +1580,11 @@ func TestAccount_getPeersRoutesFirewall(t *testing.T) { Name: "Route2", Peers: []string{}, }, + "route4": { + ID: "route4", + Name: "route4", + Peers: []string{}, + }, "finance": { ID: "finance", Name: "Finance", @@ -1584,6 +1602,28 @@ func TestAccount_getPeersRoutesFirewall(t *testing.T) { "peerB", }, }, + "qa": { + ID: "qa", + Name: "QA", + Peers: []string{ + "peerJ", + "peerK", + }, + }, + "restrictQA": { + ID: "restrictQA", + Name: "restrictQA", + Peers: []string{ + "peerJ", + }, + }, + "unrestrictedQA": { + ID: "unrestrictedQA", + Name: "unrestrictedQA", + Peers: []string{ + "peerK", + }, + }, "contractors": { ID: "contractors", Name: "Contractors", @@ -1631,6 +1671,19 @@ func TestAccount_getPeersRoutesFirewall(t *testing.T) { Groups: []string{"contractors"}, AccessControlGroups: []string{}, }, + "route4": { + ID: "route4", + Network: netip.MustParsePrefix("192.168.10.0/16"), + NetID: "route4", + NetworkType: route.IPv4Network, + PeerGroups: []string{"routingPeer1"}, + Description: "Route4", + Masquerade: false, + Metric: 9999, + Enabled: true, + Groups: []string{"qa"}, + AccessControlGroups: []string{"route4"}, + }, }, Policies: []*Policy{ { @@ -1685,6 +1738,49 @@ func TestAccount_getPeersRoutesFirewall(t *testing.T) { }, }, }, + { + ID: "RuleRoute4", + Name: "RuleRoute4", + Enabled: true, + Rules: []*PolicyRule{ + { + ID: "RuleRoute4", + Name: "RuleRoute4", + Bidirectional: true, + Enabled: true, + Protocol: PolicyRuleProtocolTCP, + Action: PolicyTrafficActionAccept, + Ports: []string{"80"}, + Sources: []string{ + "restrictQA", + }, + Destinations: []string{ + "route4", + }, + }, + }, + }, + { + ID: "RuleRoute5", + Name: "RuleRoute5", + Enabled: true, + Rules: []*PolicyRule{ + { + ID: "RuleRoute5", + Name: "RuleRoute5", + Bidirectional: true, + Enabled: true, + Protocol: PolicyRuleProtocolALL, + Action: PolicyTrafficActionAccept, + Sources: []string{ + "unrestrictedQA", + }, + Destinations: []string{ + "route4", + }, + }, + }, + }, }, } @@ -1709,7 +1805,7 @@ func TestAccount_getPeersRoutesFirewall(t *testing.T) { t.Run("check peer routes firewall rules", func(t *testing.T) { routesFirewallRules := account.getPeerRoutesFirewallRules(context.Background(), "peerA", validatedPeers) - assert.Len(t, routesFirewallRules, 2) + assert.Len(t, routesFirewallRules, 4) expectedRoutesFirewallRules := []*RouteFirewallRule{ { @@ -1735,12 +1831,32 @@ func TestAccount_getPeersRoutesFirewall(t *testing.T) { Port: 320, }, } - assert.ElementsMatch(t, routesFirewallRules, expectedRoutesFirewallRules) + additionalFirewallRule := []*RouteFirewallRule{ + { + SourceRanges: []string{ + fmt.Sprintf(AllowedIPsFormat, peerJIp), + }, + Action: "accept", + Destination: "192.168.10.0/16", + Protocol: "tcp", + Port: 80, + }, + { + SourceRanges: []string{ + fmt.Sprintf(AllowedIPsFormat, peerKIp), + }, + Action: "accept", + Destination: "192.168.10.0/16", + Protocol: "all", + }, + } + + assert.ElementsMatch(t, orderRuleSourceRanges(routesFirewallRules), orderRuleSourceRanges(append(expectedRoutesFirewallRules, additionalFirewallRule...))) // peerD is also the routing peer for route1, should contain same routes firewall rules as peerA routesFirewallRules = account.getPeerRoutesFirewallRules(context.Background(), "peerD", validatedPeers) assert.Len(t, routesFirewallRules, 2) - assert.ElementsMatch(t, routesFirewallRules, expectedRoutesFirewallRules) + assert.ElementsMatch(t, orderRuleSourceRanges(routesFirewallRules), orderRuleSourceRanges(expectedRoutesFirewallRules)) // peerE is a single routing peer for route 2 and route 3 routesFirewallRules = account.getPeerRoutesFirewallRules(context.Background(), "peerE", validatedPeers) @@ -1769,7 +1885,7 @@ func TestAccount_getPeersRoutesFirewall(t *testing.T) { IsDynamic: true, }, } - assert.ElementsMatch(t, routesFirewallRules, expectedRoutesFirewallRules) + assert.ElementsMatch(t, orderRuleSourceRanges(routesFirewallRules), orderRuleSourceRanges(expectedRoutesFirewallRules)) // peerC is part of route1 distribution groups but should not receive the routes firewall rules routesFirewallRules = account.getPeerRoutesFirewallRules(context.Background(), "peerC", validatedPeers) @@ -1778,6 +1894,14 @@ func TestAccount_getPeersRoutesFirewall(t *testing.T) { } +// orderList is a helper function to sort a list of strings +func orderRuleSourceRanges(ruleList []*RouteFirewallRule) []*RouteFirewallRule { + for _, rule := range ruleList { + sort.Strings(rule.SourceRanges) + } + return ruleList +} + func TestRouteAccountPeersUpdate(t *testing.T) { manager, err := createRouterManager(t) require.NoError(t, err, "failed to create account manager") From e52d352a48aac55caada063ef44bd7e79e173bdb Mon Sep 17 00:00:00 2001 From: v1rusnl <18641204+v1rusnl@users.noreply.github.com> Date: Sat, 30 Nov 2024 10:26:31 +0100 Subject: [PATCH 02/10] Update Caddyfile and Docker Compose to support HTTP3 (#2822) --- infrastructure_files/getting-started-with-zitadel.sh | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/infrastructure_files/getting-started-with-zitadel.sh b/infrastructure_files/getting-started-with-zitadel.sh index 0b2b651429e..7793d1fda1d 100644 --- a/infrastructure_files/getting-started-with-zitadel.sh +++ b/infrastructure_files/getting-started-with-zitadel.sh @@ -530,7 +530,7 @@ renderCaddyfile() { { debug servers :80,:443 { - protocols h1 h2c + protocols h1 h2c h3 } } @@ -788,6 +788,7 @@ services: networks: [ netbird ] ports: - '443:443' + - '443:443/udp' - '80:80' - '8080:8080' volumes: From e4a5fb3e91bfbe202d3f13f66599f6ebfd9fe77e Mon Sep 17 00:00:00 2001 From: victorserbu2709 Date: Sat, 30 Nov 2024 11:34:52 +0200 Subject: [PATCH 03/10] Unspecified address: default NetworkTypeUDP4+NetworkTypeUDP6 (#2804) --- client/iface/bind/udp_mux.go | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/client/iface/bind/udp_mux.go b/client/iface/bind/udp_mux.go index 12f7a812970..00a91f0ecad 100644 --- a/client/iface/bind/udp_mux.go +++ b/client/iface/bind/udp_mux.go @@ -162,12 +162,13 @@ func NewUDPMuxDefault(params UDPMuxParams) *UDPMuxDefault { params.Logger.Warn("UDPMuxDefault should not listening on unspecified address, use NewMultiUDPMuxFromPort instead") var networks []ice.NetworkType switch { - case addr.IP.To4() != nil: - networks = []ice.NetworkType{ice.NetworkTypeUDP4} case addr.IP.To16() != nil: networks = []ice.NetworkType{ice.NetworkTypeUDP4, ice.NetworkTypeUDP6} + case addr.IP.To4() != nil: + networks = []ice.NetworkType{ice.NetworkTypeUDP4} + default: params.Logger.Errorf("LocalAddr expected IPV4 or IPV6, got %T", params.UDPConn.LocalAddr()) } From ecb44ff3065156ba0d548b4c4fe04c927b0d947e Mon Sep 17 00:00:00 2001 From: Zoltan Papp Date: Sun, 1 Dec 2024 19:22:52 +0100 Subject: [PATCH 04/10] [client] Add pprof build tag (#2964) * Add pprof build tag * Change env handling --- client/cmd/pprof.go | 33 +++++++++++++++++++++++++++++++++ 1 file changed, 33 insertions(+) create mode 100644 client/cmd/pprof.go diff --git a/client/cmd/pprof.go b/client/cmd/pprof.go new file mode 100644 index 00000000000..37efd35f0cd --- /dev/null +++ b/client/cmd/pprof.go @@ -0,0 +1,33 @@ +//go:build pprof +// +build pprof + +package cmd + +import ( + "net/http" + _ "net/http/pprof" + "os" + + log "github.com/sirupsen/logrus" +) + +func init() { + addr := pprofAddr() + go pprof(addr) +} + +func pprofAddr() string { + listenAddr := os.Getenv("NB_PPROF_ADDR") + if listenAddr == "" { + return "localhost:6969" + } + + return listenAddr +} + +func pprof(listenAddr string) { + log.Infof("listening pprof on: %s\n", listenAddr) + if err := http.ListenAndServe(listenAddr, nil); err != nil { + log.Fatalf("Failed to start pprof: %v", err) + } +} From 5142dc52c11ad9841fe8ae229ed585d27f976503 Mon Sep 17 00:00:00 2001 From: Viktor Liu <17948409+lixmal@users.noreply.github.com> Date: Mon, 2 Dec 2024 17:55:02 +0100 Subject: [PATCH 05/10] [client] Persist route selection (#2810) --- client/firewall/iptables/rulestore_linux.go | 10 ++ client/internal/engine.go | 13 +- client/internal/engine_test.go | 5 +- client/internal/routemanager/manager.go | 38 ++++- client/internal/routemanager/manager_test.go | 4 +- client/internal/routemanager/mock.go | 2 +- .../routemanager/refcounter/refcounter.go | 13 ++ client/internal/routemanager/state.go | 19 +++ .../internal/routeselector/routeselector.go | 67 +++++++++ client/internal/statemanager/manager.go | 142 ++++++++++++++---- 10 files changed, 273 insertions(+), 40 deletions(-) create mode 100644 client/internal/routemanager/state.go diff --git a/client/firewall/iptables/rulestore_linux.go b/client/firewall/iptables/rulestore_linux.go index bfd08bee27d..004c512a4b4 100644 --- a/client/firewall/iptables/rulestore_linux.go +++ b/client/firewall/iptables/rulestore_linux.go @@ -37,6 +37,11 @@ func (s *ipList) UnmarshalJSON(data []byte) error { return err } s.ips = temp.IPs + + if temp.IPs == nil { + temp.IPs = make(map[string]struct{}) + } + return nil } @@ -89,5 +94,10 @@ func (s *ipsetStore) UnmarshalJSON(data []byte) error { return err } s.ipsets = temp.IPSets + + if temp.IPSets == nil { + temp.IPSets = make(map[string]*ipList) + } + return nil } diff --git a/client/internal/engine.go b/client/internal/engine.go index dc4499e17f3..920c295cdde 100644 --- a/client/internal/engine.go +++ b/client/internal/engine.go @@ -349,8 +349,17 @@ func (e *Engine) Start() error { } e.dnsServer = dnsServer - e.routeManager = routemanager.NewManager(e.ctx, e.config.WgPrivateKey.PublicKey().String(), e.config.DNSRouteInterval, e.wgInterface, e.statusRecorder, e.relayManager, initialRoutes) - beforePeerHook, afterPeerHook, err := e.routeManager.Init(e.stateManager) + e.routeManager = routemanager.NewManager( + e.ctx, + e.config.WgPrivateKey.PublicKey().String(), + e.config.DNSRouteInterval, + e.wgInterface, + e.statusRecorder, + e.relayManager, + initialRoutes, + e.stateManager, + ) + beforePeerHook, afterPeerHook, err := e.routeManager.Init() if err != nil { log.Errorf("Failed to initialize route manager: %s", err) } else { diff --git a/client/internal/engine_test.go b/client/internal/engine_test.go index b6c6186ea83..b58c1f7e93a 100644 --- a/client/internal/engine_test.go +++ b/client/internal/engine_test.go @@ -245,12 +245,15 @@ func TestEngine_UpdateNetworkMap(t *testing.T) { nil) wgIface := &iface.MockWGIface{ + NameFunc: func() string { return "utun102" }, RemovePeerFunc: func(peerKey string) error { return nil }, } engine.wgInterface = wgIface - engine.routeManager = routemanager.NewManager(ctx, key.PublicKey().String(), time.Minute, engine.wgInterface, engine.statusRecorder, relayMgr, nil) + engine.routeManager = routemanager.NewManager(ctx, key.PublicKey().String(), time.Minute, engine.wgInterface, engine.statusRecorder, relayMgr, nil, nil) + _, _, err = engine.routeManager.Init() + require.NoError(t, err) engine.dnsServer = &dns.MockServer{ UpdateDNSServerFunc: func(serial uint64, update nbdns.Config) error { return nil }, } diff --git a/client/internal/routemanager/manager.go b/client/internal/routemanager/manager.go index 0a1c7dc56b8..f1c4ae5ef75 100644 --- a/client/internal/routemanager/manager.go +++ b/client/internal/routemanager/manager.go @@ -32,7 +32,7 @@ import ( // Manager is a route manager interface type Manager interface { - Init(*statemanager.Manager) (nbnet.AddHookFunc, nbnet.RemoveHookFunc, error) + Init() (nbnet.AddHookFunc, nbnet.RemoveHookFunc, error) UpdateRoutes(updateSerial uint64, newRoutes []*route.Route) (map[route.ID]*route.Route, route.HAMap, error) TriggerSelection(route.HAMap) GetRouteSelector() *routeselector.RouteSelector @@ -59,6 +59,7 @@ type DefaultManager struct { routeRefCounter *refcounter.RouteRefCounter allowedIPsRefCounter *refcounter.AllowedIPsRefCounter dnsRouteInterval time.Duration + stateManager *statemanager.Manager } func NewManager( @@ -69,6 +70,7 @@ func NewManager( statusRecorder *peer.Status, relayMgr *relayClient.Manager, initialRoutes []*route.Route, + stateManager *statemanager.Manager, ) *DefaultManager { mCTX, cancel := context.WithCancel(ctx) notifier := notifier.NewNotifier() @@ -80,12 +82,12 @@ func NewManager( dnsRouteInterval: dnsRouteInterval, clientNetworks: make(map[route.HAUniqueID]*clientNetwork), relayMgr: relayMgr, - routeSelector: routeselector.NewRouteSelector(), sysOps: sysOps, statusRecorder: statusRecorder, wgInterface: wgInterface, pubKey: pubKey, notifier: notifier, + stateManager: stateManager, } dm.routeRefCounter = refcounter.New( @@ -121,7 +123,7 @@ func NewManager( } // Init sets up the routing -func (m *DefaultManager) Init(stateManager *statemanager.Manager) (nbnet.AddHookFunc, nbnet.RemoveHookFunc, error) { +func (m *DefaultManager) Init() (nbnet.AddHookFunc, nbnet.RemoveHookFunc, error) { if nbnet.CustomRoutingDisabled() { return nil, nil, nil } @@ -137,14 +139,38 @@ func (m *DefaultManager) Init(stateManager *statemanager.Manager) (nbnet.AddHook ips := resolveURLsToIPs(initialAddresses) - beforePeerHook, afterPeerHook, err := m.sysOps.SetupRouting(ips, stateManager) + beforePeerHook, afterPeerHook, err := m.sysOps.SetupRouting(ips, m.stateManager) if err != nil { return nil, nil, fmt.Errorf("setup routing: %w", err) } + + m.routeSelector = m.initSelector() + log.Info("Routing setup complete") return beforePeerHook, afterPeerHook, nil } +func (m *DefaultManager) initSelector() *routeselector.RouteSelector { + var state *SelectorState + m.stateManager.RegisterState(state) + + // restore selector state if it exists + if err := m.stateManager.LoadState(state); err != nil { + log.Warnf("failed to load state: %v", err) + return routeselector.NewRouteSelector() + } + + if state := m.stateManager.GetState(state); state != nil { + if selector, ok := state.(*SelectorState); ok { + return (*routeselector.RouteSelector)(selector) + } + + log.Warnf("failed to convert state with type %T to SelectorState", state) + } + + return routeselector.NewRouteSelector() +} + func (m *DefaultManager) EnableServerRouter(firewall firewall.Manager) error { var err error m.serverRouter, err = newServerRouter(m.ctx, m.wgInterface, firewall, m.statusRecorder) @@ -252,6 +278,10 @@ func (m *DefaultManager) TriggerSelection(networks route.HAMap) { go clientNetworkWatcher.peersStateAndUpdateWatcher() clientNetworkWatcher.sendUpdateToClientNetworkWatcher(routesUpdate{routes: routes}) } + + if err := m.stateManager.UpdateState((*SelectorState)(m.routeSelector)); err != nil { + log.Errorf("failed to update state: %v", err) + } } // stopObsoleteClients stops the client network watcher for the networks that are not in the new list diff --git a/client/internal/routemanager/manager_test.go b/client/internal/routemanager/manager_test.go index e669bc44a08..07dac21b819 100644 --- a/client/internal/routemanager/manager_test.go +++ b/client/internal/routemanager/manager_test.go @@ -424,9 +424,9 @@ func TestManagerUpdateRoutes(t *testing.T) { statusRecorder := peer.NewRecorder("https://mgm") ctx := context.TODO() - routeManager := NewManager(ctx, localPeerKey, 0, wgInterface, statusRecorder, nil, nil) + routeManager := NewManager(ctx, localPeerKey, 0, wgInterface, statusRecorder, nil, nil, nil) - _, _, err = routeManager.Init(nil) + _, _, err = routeManager.Init() require.NoError(t, err, "should init route manager") defer routeManager.Stop(nil) diff --git a/client/internal/routemanager/mock.go b/client/internal/routemanager/mock.go index 503185f0311..556a6235138 100644 --- a/client/internal/routemanager/mock.go +++ b/client/internal/routemanager/mock.go @@ -21,7 +21,7 @@ type MockManager struct { StopFunc func(manager *statemanager.Manager) } -func (m *MockManager) Init(*statemanager.Manager) (net.AddHookFunc, net.RemoveHookFunc, error) { +func (m *MockManager) Init() (net.AddHookFunc, net.RemoveHookFunc, error) { return nil, nil, nil } diff --git a/client/internal/routemanager/refcounter/refcounter.go b/client/internal/routemanager/refcounter/refcounter.go index f2f0a169df0..27a724f5062 100644 --- a/client/internal/routemanager/refcounter/refcounter.go +++ b/client/internal/routemanager/refcounter/refcounter.go @@ -71,11 +71,14 @@ func New[Key comparable, I, O any](add AddFunc[Key, I, O], remove RemoveFunc[Key } // LoadData loads the data from the existing counter +// The passed counter should not be used any longer after calling this function. func (rm *Counter[Key, I, O]) LoadData( existingCounter *Counter[Key, I, O], ) { rm.mu.Lock() defer rm.mu.Unlock() + existingCounter.mu.Lock() + defer existingCounter.mu.Unlock() rm.refCountMap = existingCounter.refCountMap rm.idMap = existingCounter.idMap @@ -231,6 +234,9 @@ func (rm *Counter[Key, I, O]) MarshalJSON() ([]byte, error) { // UnmarshalJSON implements the json.Unmarshaler interface for Counter. func (rm *Counter[Key, I, O]) UnmarshalJSON(data []byte) error { + rm.mu.Lock() + defer rm.mu.Unlock() + var temp struct { RefCountMap map[Key]Ref[O] `json:"refCountMap"` IDMap map[string][]Key `json:"idMap"` @@ -241,6 +247,13 @@ func (rm *Counter[Key, I, O]) UnmarshalJSON(data []byte) error { rm.refCountMap = temp.RefCountMap rm.idMap = temp.IDMap + if temp.RefCountMap == nil { + temp.RefCountMap = map[Key]Ref[O]{} + } + if temp.IDMap == nil { + temp.IDMap = map[string][]Key{} + } + return nil } diff --git a/client/internal/routemanager/state.go b/client/internal/routemanager/state.go new file mode 100644 index 00000000000..a45c32b506d --- /dev/null +++ b/client/internal/routemanager/state.go @@ -0,0 +1,19 @@ +package routemanager + +import ( + "github.com/netbirdio/netbird/client/internal/routeselector" +) + +type SelectorState routeselector.RouteSelector + +func (s *SelectorState) Name() string { + return "routeselector_state" +} + +func (s *SelectorState) MarshalJSON() ([]byte, error) { + return (*routeselector.RouteSelector)(s).MarshalJSON() +} + +func (s *SelectorState) UnmarshalJSON(data []byte) error { + return (*routeselector.RouteSelector)(s).UnmarshalJSON(data) +} diff --git a/client/internal/routeselector/routeselector.go b/client/internal/routeselector/routeselector.go index 00128a27b03..2874604fdd3 100644 --- a/client/internal/routeselector/routeselector.go +++ b/client/internal/routeselector/routeselector.go @@ -1,8 +1,10 @@ package routeselector import ( + "encoding/json" "fmt" "slices" + "sync" "github.com/hashicorp/go-multierror" "golang.org/x/exp/maps" @@ -12,6 +14,7 @@ import ( ) type RouteSelector struct { + mu sync.RWMutex selectedRoutes map[route.NetID]struct{} selectAll bool } @@ -26,6 +29,9 @@ func NewRouteSelector() *RouteSelector { // SelectRoutes updates the selected routes based on the provided route IDs. func (rs *RouteSelector) SelectRoutes(routes []route.NetID, appendRoute bool, allRoutes []route.NetID) error { + rs.mu.Lock() + defer rs.mu.Unlock() + if !appendRoute { rs.selectedRoutes = map[route.NetID]struct{}{} } @@ -46,6 +52,9 @@ func (rs *RouteSelector) SelectRoutes(routes []route.NetID, appendRoute bool, al // SelectAllRoutes sets the selector to select all routes. func (rs *RouteSelector) SelectAllRoutes() { + rs.mu.Lock() + defer rs.mu.Unlock() + rs.selectAll = true rs.selectedRoutes = map[route.NetID]struct{}{} } @@ -53,6 +62,9 @@ func (rs *RouteSelector) SelectAllRoutes() { // DeselectRoutes removes specific routes from the selection. // If the selector is in "select all" mode, it will transition to "select specific" mode. func (rs *RouteSelector) DeselectRoutes(routes []route.NetID, allRoutes []route.NetID) error { + rs.mu.Lock() + defer rs.mu.Unlock() + if rs.selectAll { rs.selectAll = false rs.selectedRoutes = map[route.NetID]struct{}{} @@ -76,12 +88,18 @@ func (rs *RouteSelector) DeselectRoutes(routes []route.NetID, allRoutes []route. // DeselectAllRoutes deselects all routes, effectively disabling route selection. func (rs *RouteSelector) DeselectAllRoutes() { + rs.mu.Lock() + defer rs.mu.Unlock() + rs.selectAll = false rs.selectedRoutes = map[route.NetID]struct{}{} } // IsSelected checks if a specific route is selected. func (rs *RouteSelector) IsSelected(routeID route.NetID) bool { + rs.mu.RLock() + defer rs.mu.RUnlock() + if rs.selectAll { return true } @@ -91,6 +109,9 @@ func (rs *RouteSelector) IsSelected(routeID route.NetID) bool { // FilterSelected removes unselected routes from the provided map. func (rs *RouteSelector) FilterSelected(routes route.HAMap) route.HAMap { + rs.mu.RLock() + defer rs.mu.RUnlock() + if rs.selectAll { return maps.Clone(routes) } @@ -103,3 +124,49 @@ func (rs *RouteSelector) FilterSelected(routes route.HAMap) route.HAMap { } return filtered } + +// MarshalJSON implements the json.Marshaler interface +func (rs *RouteSelector) MarshalJSON() ([]byte, error) { + rs.mu.RLock() + defer rs.mu.RUnlock() + + return json.Marshal(struct { + SelectedRoutes map[route.NetID]struct{} `json:"selected_routes"` + SelectAll bool `json:"select_all"` + }{ + SelectAll: rs.selectAll, + SelectedRoutes: rs.selectedRoutes, + }) +} + +// UnmarshalJSON implements the json.Unmarshaler interface +// If the JSON is empty or null, it will initialize like a NewRouteSelector. +func (rs *RouteSelector) UnmarshalJSON(data []byte) error { + rs.mu.Lock() + defer rs.mu.Unlock() + + // Check for null or empty JSON + if len(data) == 0 || string(data) == "null" { + rs.selectedRoutes = map[route.NetID]struct{}{} + rs.selectAll = true + return nil + } + + var temp struct { + SelectedRoutes map[route.NetID]struct{} `json:"selected_routes"` + SelectAll bool `json:"select_all"` + } + + if err := json.Unmarshal(data, &temp); err != nil { + return err + } + + rs.selectedRoutes = temp.SelectedRoutes + rs.selectAll = temp.SelectAll + + if rs.selectedRoutes == nil { + rs.selectedRoutes = map[route.NetID]struct{}{} + } + + return nil +} diff --git a/client/internal/statemanager/manager.go b/client/internal/statemanager/manager.go index da6dd022fc2..aae73b79f5e 100644 --- a/client/internal/statemanager/manager.go +++ b/client/internal/statemanager/manager.go @@ -22,9 +22,28 @@ import ( // State interface defines the methods that all state types must implement type State interface { Name() string +} + +// CleanableState interface extends State with cleanup capability +type CleanableState interface { + State Cleanup() error } +// RawState wraps raw JSON data for unregistered states +type RawState struct { + data json.RawMessage +} + +func (r *RawState) Name() string { + return "" // This is a placeholder implementation +} + +// MarshalJSON implements json.Marshaler to preserve the original JSON +func (r *RawState) MarshalJSON() ([]byte, error) { + return r.data, nil +} + // Manager handles the persistence and management of various states type Manager struct { mu sync.Mutex @@ -209,15 +228,15 @@ func (m *Manager) PersistState(ctx context.Context) error { return nil } -// loadState loads the existing state from the state file -func (m *Manager) loadState() error { +// loadStateFile reads and unmarshals the state file into a map of raw JSON messages +func (m *Manager) loadStateFile() (map[string]json.RawMessage, error) { data, err := os.ReadFile(m.filePath) if err != nil { if errors.Is(err, fs.ErrNotExist) { log.Debug("state file does not exist") - return nil + return nil, nil // nolint:nilnil } - return fmt.Errorf("read state file: %w", err) + return nil, fmt.Errorf("read state file: %w", err) } var rawStates map[string]json.RawMessage @@ -228,37 +247,69 @@ func (m *Manager) loadState() error { } else { log.Info("State file deleted") } - return fmt.Errorf("unmarshal states: %w", err) + return nil, fmt.Errorf("unmarshal states: %w", err) } - var merr *multierror.Error + return rawStates, nil +} - for name, rawState := range rawStates { - stateType, ok := m.stateTypes[name] - if !ok { - merr = multierror.Append(merr, fmt.Errorf("unknown state type: %s", name)) - continue - } +// loadSingleRawState unmarshals a raw state into a concrete state object +func (m *Manager) loadSingleRawState(name string, rawState json.RawMessage) (State, error) { + stateType, ok := m.stateTypes[name] + if !ok { + return nil, fmt.Errorf("state %s not registered", name) + } - if string(rawState) == "null" { - continue - } + if string(rawState) == "null" { + return nil, nil //nolint:nilnil + } - statePtr := reflect.New(stateType).Interface().(State) - if err := json.Unmarshal(rawState, statePtr); err != nil { - merr = multierror.Append(merr, fmt.Errorf("unmarshal state %s: %w", name, err)) - continue - } + statePtr := reflect.New(stateType).Interface().(State) + if err := json.Unmarshal(rawState, statePtr); err != nil { + return nil, fmt.Errorf("unmarshal state %s: %w", name, err) + } + + return statePtr, nil +} + +// LoadState loads a specific state from the state file +func (m *Manager) LoadState(state State) error { + if m == nil { + return nil + } + + m.mu.Lock() + defer m.mu.Unlock() - m.states[name] = statePtr + rawStates, err := m.loadStateFile() + if err != nil { + return err + } + if rawStates == nil { + return nil + } + + name := state.Name() + rawState, exists := rawStates[name] + if !exists { + return nil + } + + loadedState, err := m.loadSingleRawState(name, rawState) + if err != nil { + return err + } + + m.states[name] = loadedState + if loadedState != nil { log.Debugf("loaded state: %s", name) } - return nberrors.FormatErrorOrNil(merr) + return nil } -// PerformCleanup retrieves all states from the state file for the registered states and calls Cleanup on them. -// If the cleanup is successful, the state is marked for deletion. +// PerformCleanup retrieves all states from the state file and calls Cleanup on registered states that support it. +// Unregistered states are preserved in their original state. func (m *Manager) PerformCleanup() error { if m == nil { return nil @@ -267,22 +318,53 @@ func (m *Manager) PerformCleanup() error { m.mu.Lock() defer m.mu.Unlock() - if err := m.loadState(); err != nil { + // Load raw states from file + rawStates, err := m.loadStateFile() + if err != nil { log.Warnf("Failed to load state during cleanup: %v", err) + return err + } + if rawStates == nil { + return nil } var merr *multierror.Error - for name, state := range m.states { - if state == nil { - // If no state was found in the state file, we don't mark the state dirty nor return an error + + // Process each state in the file + for name, rawState := range rawStates { + // For unregistered states, preserve the raw JSON + if _, registered := m.stateTypes[name]; !registered { + m.states[name] = &RawState{data: rawState} + continue + } + + // Load the registered state + loadedState, err := m.loadSingleRawState(name, rawState) + if err != nil { + merr = multierror.Append(merr, err) + continue + } + + if loadedState == nil { + continue + } + + // Check if state supports cleanup + cleanableState, isCleanable := loadedState.(CleanableState) + if !isCleanable { + // If it doesn't support cleanup, keep it as-is + m.states[name] = loadedState continue } + // Perform cleanup for cleanable states log.Infof("client was not shut down properly, cleaning up %s", name) - if err := state.Cleanup(); err != nil { + if err := cleanableState.Cleanup(); err != nil { merr = multierror.Append(merr, fmt.Errorf("cleanup state for %s: %w", name, err)) + // On cleanup error, preserve the state + m.states[name] = loadedState } else { - // mark for deletion on cleanup success + // Successfully cleaned up - mark for deletion m.states[name] = nil m.dirty[name] = struct{}{} } From c7e7ad5030cd695f1ba1ad2c819b2a8a551a90d8 Mon Sep 17 00:00:00 2001 From: Viktor Liu <17948409+lixmal@users.noreply.github.com> Date: Mon, 2 Dec 2024 18:04:02 +0100 Subject: [PATCH 06/10] [client] Add state file to debug bundle (#2969) --- client/anonymize/anonymize.go | 4 +- client/anonymize/anonymize_test.go | 8 + client/server/debug.go | 178 +++++++++++++++++--- client/server/debug_test.go | 258 +++++++++++++++++++++++++++++ 4 files changed, 425 insertions(+), 23 deletions(-) create mode 100644 client/server/debug_test.go diff --git a/client/anonymize/anonymize.go b/client/anonymize/anonymize.go index 7ebe0442dfa..ad31682f25a 100644 --- a/client/anonymize/anonymize.go +++ b/client/anonymize/anonymize.go @@ -152,9 +152,9 @@ func (a *Anonymizer) AnonymizeString(str string) string { return str } -// AnonymizeSchemeURI finds and anonymizes URIs with stun, stuns, turn, and turns schemes. +// AnonymizeSchemeURI finds and anonymizes URIs with ws, wss, rel, rels, stun, stuns, turn, and turns schemes. func (a *Anonymizer) AnonymizeSchemeURI(text string) string { - re := regexp.MustCompile(`(?i)\b(stuns?:|turns?:|https?://)\S+\b`) + re := regexp.MustCompile(`(?i)\b(wss?://|rels?://|stuns?:|turns?:|https?://)\S+\b`) return re.ReplaceAllStringFunc(text, a.AnonymizeURI) } diff --git a/client/anonymize/anonymize_test.go b/client/anonymize/anonymize_test.go index e660749ec5d..605788ab54a 100644 --- a/client/anonymize/anonymize_test.go +++ b/client/anonymize/anonymize_test.go @@ -140,8 +140,16 @@ func TestAnonymizeSchemeURI(t *testing.T) { expect string }{ {"STUN URI in text", "Connection made via stun:example.com", `Connection made via stun:anon-[a-zA-Z0-9]+\.domain`}, + {"STUNS URI in message", "Secure connection to stuns:example.com:443", `Secure connection to stuns:anon-[a-zA-Z0-9]+\.domain:443`}, {"TURN URI in log", "Failed attempt turn:some.example.com:3478?transport=tcp: retrying", `Failed attempt turn:some.anon-[a-zA-Z0-9]+\.domain:3478\?transport=tcp: retrying`}, + {"TURNS URI in message", "Secure connection to turns:example.com:5349", `Secure connection to turns:anon-[a-zA-Z0-9]+\.domain:5349`}, + {"HTTP URI in text", "Visit http://example.com for more", `Visit http://anon-[a-zA-Z0-9]+\.domain for more`}, + {"HTTPS URI in CAPS", "Visit HTTPS://example.com for more", `Visit https://anon-[a-zA-Z0-9]+\.domain for more`}, {"HTTPS URI in message", "Visit https://example.com for more", `Visit https://anon-[a-zA-Z0-9]+\.domain for more`}, + {"WS URI in log", "Connection established to ws://example.com:8080", `Connection established to ws://anon-[a-zA-Z0-9]+\.domain:8080`}, + {"WSS URI in message", "Secure connection to wss://example.com", `Secure connection to wss://anon-[a-zA-Z0-9]+\.domain`}, + {"Rel URI in text", "Relaying to rel://example.com", `Relaying to rel://anon-[a-zA-Z0-9]+\.domain`}, + {"Rels URI in message", "Relaying to rels://example.com", `Relaying to rels://anon-[a-zA-Z0-9]+\.domain`}, } for _, tc := range tests { diff --git a/client/server/debug.go b/client/server/debug.go index 5ed43293b4a..1bad907ba56 100644 --- a/client/server/debug.go +++ b/client/server/debug.go @@ -5,9 +5,13 @@ package server import ( "archive/zip" "bufio" + "bytes" "context" + "encoding/json" + "errors" "fmt" "io" + "io/fs" "net" "net/netip" "os" @@ -20,6 +24,7 @@ import ( "github.com/netbirdio/netbird/client/anonymize" "github.com/netbirdio/netbird/client/internal/peer" "github.com/netbirdio/netbird/client/internal/routemanager/systemops" + "github.com/netbirdio/netbird/client/internal/statemanager" "github.com/netbirdio/netbird/client/proto" ) @@ -31,6 +36,7 @@ client.log: Most recent, anonymized log file of the NetBird client. routes.txt: Anonymized system routes, if --system-info flag was provided. interfaces.txt: Anonymized network interface information, if --system-info flag was provided. config.txt: Anonymized configuration information of the NetBird client. +state.json: Anonymized client state dump containing netbird states. Anonymization Process @@ -50,8 +56,22 @@ Domains All domain names (except for the netbird domains) are replaced with randomly generated strings ending in ".domain". Anonymized domains are consistent across all files in the bundle. Reoccuring domain names are replaced with the same anonymized domain. +State File +The state.json file contains anonymized internal state information of the NetBird client, including: +- DNS settings and configuration +- Firewall rules +- Exclusion routes +- Route selection +- Other internal states that may be present + +The state file follows the same anonymization rules as other files: +- IP addresses (both individual and CIDR ranges) are anonymized while preserving their structure +- Domain names are consistently anonymized +- Technical identifiers and non-sensitive data remain unchanged + Routes For anonymized routes, the IP addresses are replaced as described above. The prefix length remains unchanged. Note that for prefixes, the anonymized IP might not be a network address, but the prefix length is still correct. + Network Interfaces The interfaces.txt file contains information about network interfaces, including: - Interface name @@ -132,6 +152,10 @@ func (s *Server) createArchive(bundlePath *os.File, req *proto.DebugBundleReques } } + if err := s.addStateFile(req, anonymizer, archive); err != nil { + log.Errorf("Failed to add state file to debug bundle: %v", err) + } + if err := s.addLogfile(req, anonymizer, archive); err != nil { return fmt.Errorf("add log file: %w", err) } @@ -248,6 +272,44 @@ func (s *Server) addInterfaces(req *proto.DebugBundleRequest, anonymizer *anonym return nil } +func (s *Server) addStateFile(req *proto.DebugBundleRequest, anonymizer *anonymize.Anonymizer, archive *zip.Writer) error { + path := statemanager.GetDefaultStatePath() + if path == "" { + return nil + } + + data, err := os.ReadFile(path) + if err != nil { + if errors.Is(err, fs.ErrNotExist) { + return nil + } + return fmt.Errorf("read state file: %w", err) + } + + if req.GetAnonymize() { + var rawStates map[string]json.RawMessage + if err := json.Unmarshal(data, &rawStates); err != nil { + return fmt.Errorf("unmarshal states: %w", err) + } + + if err := anonymizeStateFile(&rawStates, anonymizer); err != nil { + return fmt.Errorf("anonymize state file: %w", err) + } + + bs, err := json.MarshalIndent(rawStates, "", " ") + if err != nil { + return fmt.Errorf("marshal states: %w", err) + } + data = bs + } + + if err := addFileToZip(archive, bytes.NewReader(data), "state.json"); err != nil { + return fmt.Errorf("add state file to zip: %w", err) + } + + return nil +} + func (s *Server) addLogfile(req *proto.DebugBundleRequest, anonymizer *anonymize.Anonymizer, archive *zip.Writer) (err error) { logFile, err := os.Open(s.logFile) if err != nil { @@ -264,7 +326,7 @@ func (s *Server) addLogfile(req *proto.DebugBundleRequest, anonymizer *anonymize var writer *io.PipeWriter logReader, writer = io.Pipe() - go s.anonymize(logFile, writer, anonymizer) + go anonymizeLog(logFile, writer, anonymizer) } else { logReader = logFile } @@ -275,26 +337,6 @@ func (s *Server) addLogfile(req *proto.DebugBundleRequest, anonymizer *anonymize return nil } -func (s *Server) anonymize(reader io.Reader, writer *io.PipeWriter, anonymizer *anonymize.Anonymizer) { - defer func() { - // always nil - _ = writer.Close() - }() - - scanner := bufio.NewScanner(reader) - for scanner.Scan() { - line := anonymizer.AnonymizeString(scanner.Text()) - if _, err := writer.Write([]byte(line + "\n")); err != nil { - writer.CloseWithError(fmt.Errorf("anonymize write: %w", err)) - return - } - } - if err := scanner.Err(); err != nil { - writer.CloseWithError(fmt.Errorf("anonymize scan: %w", err)) - return - } -} - // GetLogLevel gets the current logging level for the server. func (s *Server) GetLogLevel(_ context.Context, _ *proto.GetLogLevelRequest) (*proto.GetLogLevelResponse, error) { level := ParseLogLevel(log.GetLevel().String()) @@ -458,6 +500,26 @@ func formatInterfaces(interfaces []net.Interface, anonymize bool, anonymizer *an return builder.String() } +func anonymizeLog(reader io.Reader, writer *io.PipeWriter, anonymizer *anonymize.Anonymizer) { + defer func() { + // always nil + _ = writer.Close() + }() + + scanner := bufio.NewScanner(reader) + for scanner.Scan() { + line := anonymizer.AnonymizeString(scanner.Text()) + if _, err := writer.Write([]byte(line + "\n")); err != nil { + writer.CloseWithError(fmt.Errorf("anonymize write: %w", err)) + return + } + } + if err := scanner.Err(); err != nil { + writer.CloseWithError(fmt.Errorf("anonymize scan: %w", err)) + return + } +} + func anonymizeNATExternalIPs(ips []string, anonymizer *anonymize.Anonymizer) []string { anonymizedIPs := make([]string, len(ips)) for i, ip := range ips { @@ -484,3 +546,77 @@ func anonymizeNATExternalIPs(ips []string, anonymizer *anonymize.Anonymizer) []s } return anonymizedIPs } + +func anonymizeStateFile(rawStates *map[string]json.RawMessage, anonymizer *anonymize.Anonymizer) error { + for name, rawState := range *rawStates { + if string(rawState) == "null" { + continue + } + + var state map[string]any + if err := json.Unmarshal(rawState, &state); err != nil { + return fmt.Errorf("unmarshal state %s: %w", name, err) + } + + state = anonymizeValue(state, anonymizer).(map[string]any) + + bs, err := json.Marshal(state) + if err != nil { + return fmt.Errorf("marshal state %s: %w", name, err) + } + + (*rawStates)[name] = bs + } + + return nil +} + +func anonymizeValue(value any, anonymizer *anonymize.Anonymizer) any { + switch v := value.(type) { + case string: + return anonymizeString(v, anonymizer) + case map[string]any: + return anonymizeMap(v, anonymizer) + case []any: + return anonymizeSlice(v, anonymizer) + } + return value +} + +func anonymizeString(v string, anonymizer *anonymize.Anonymizer) string { + if prefix, err := netip.ParsePrefix(v); err == nil { + anonIP := anonymizer.AnonymizeIP(prefix.Addr()) + return fmt.Sprintf("%s/%d", anonIP, prefix.Bits()) + } + if ip, err := netip.ParseAddr(v); err == nil { + return anonymizer.AnonymizeIP(ip).String() + } + return anonymizer.AnonymizeString(v) +} + +func anonymizeMap(v map[string]any, anonymizer *anonymize.Anonymizer) map[string]any { + result := make(map[string]any, len(v)) + for key, val := range v { + newKey := anonymizeMapKey(key, anonymizer) + result[newKey] = anonymizeValue(val, anonymizer) + } + return result +} + +func anonymizeMapKey(key string, anonymizer *anonymize.Anonymizer) string { + if prefix, err := netip.ParsePrefix(key); err == nil { + anonIP := anonymizer.AnonymizeIP(prefix.Addr()) + return fmt.Sprintf("%s/%d", anonIP, prefix.Bits()) + } + if ip, err := netip.ParseAddr(key); err == nil { + return anonymizer.AnonymizeIP(ip).String() + } + return key +} + +func anonymizeSlice(v []any, anonymizer *anonymize.Anonymizer) []any { + for i, val := range v { + v[i] = anonymizeValue(val, anonymizer) + } + return v +} diff --git a/client/server/debug_test.go b/client/server/debug_test.go new file mode 100644 index 00000000000..303e5e66166 --- /dev/null +++ b/client/server/debug_test.go @@ -0,0 +1,258 @@ +package server + +import ( + "encoding/json" + "strings" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/netbirdio/netbird/client/anonymize" +) + +func TestAnonymizeStateFile(t *testing.T) { + testState := map[string]json.RawMessage{ + "null_state": json.RawMessage("null"), + "test_state": mustMarshal(map[string]any{ + // Test simple fields + "public_ip": "203.0.113.1", + "private_ip": "192.168.1.1", + "protected_ip": "100.64.0.1", + "well_known_ip": "8.8.8.8", + "ipv6_addr": "2001:db8::1", + "private_ipv6": "fd00::1", + "domain": "test.example.com", + "uri": "stun:stun.example.com:3478", + "uri_with_ip": "turn:203.0.113.1:3478", + "netbird_domain": "device.netbird.cloud", + + // Test CIDR ranges + "public_cidr": "203.0.113.0/24", + "private_cidr": "192.168.0.0/16", + "protected_cidr": "100.64.0.0/10", + "ipv6_cidr": "2001:db8::/32", + "private_ipv6_cidr": "fd00::/8", + + // Test nested structures + "nested": map[string]any{ + "ip": "203.0.113.2", + "domain": "nested.example.com", + "more_nest": map[string]any{ + "ip": "203.0.113.3", + "domain": "deep.example.com", + }, + }, + + // Test arrays + "string_array": []any{ + "203.0.113.4", + "test1.example.com", + "test2.example.com", + }, + "object_array": []any{ + map[string]any{ + "ip": "203.0.113.5", + "domain": "array1.example.com", + }, + map[string]any{ + "ip": "203.0.113.6", + "domain": "array2.example.com", + }, + }, + + // Test multiple occurrences of same value + "duplicate_ip": "203.0.113.1", // Same as public_ip + "duplicate_domain": "test.example.com", // Same as domain + + // Test URIs with various schemes + "stun_uri": "stun:stun.example.com:3478", + "turns_uri": "turns:turns.example.com:5349", + "http_uri": "http://web.example.com:80", + "https_uri": "https://secure.example.com:443", + + // Test strings that might look like IPs but aren't + "not_ip": "300.300.300.300", + "partial_ip": "192.168", + "ip_like_string": "1234.5678", + + // Test mixed content strings + "mixed_content": "Server at 203.0.113.1 (test.example.com) on port 80", + + // Test empty and special values + "empty_string": "", + "null_value": nil, + "numeric_value": 42, + "boolean_value": true, + }), + "route_state": mustMarshal(map[string]any{ + "routes": []any{ + map[string]any{ + "network": "203.0.113.0/24", + "gateway": "203.0.113.1", + "domains": []any{ + "route1.example.com", + "route2.example.com", + }, + }, + map[string]any{ + "network": "2001:db8::/32", + "gateway": "2001:db8::1", + "domains": []any{ + "route3.example.com", + "route4.example.com", + }, + }, + }, + // Test map with IP/CIDR keys + "refCountMap": map[string]any{ + "203.0.113.1/32": map[string]any{ + "Count": 1, + "Out": map[string]any{ + "IP": "192.168.0.1", + "Intf": map[string]any{ + "Name": "eth0", + "Index": 1, + }, + }, + }, + "2001:db8::1/128": map[string]any{ + "Count": 1, + "Out": map[string]any{ + "IP": "fe80::1", + "Intf": map[string]any{ + "Name": "eth0", + "Index": 1, + }, + }, + }, + "10.0.0.1/32": map[string]any{ // private IP should remain unchanged + "Count": 1, + "Out": map[string]any{ + "IP": "192.168.0.1", + }, + }, + }, + }), + } + + anonymizer := anonymize.NewAnonymizer(anonymize.DefaultAddresses()) + err := anonymizeStateFile(&testState, anonymizer) + require.NoError(t, err) + + // Helper function to unmarshal and get nested values + var state map[string]any + err = json.Unmarshal(testState["test_state"], &state) + require.NoError(t, err) + + // Test null state remains unchanged + require.Equal(t, "null", string(testState["null_state"])) + + // Basic assertions + assert.NotEqual(t, "203.0.113.1", state["public_ip"]) + assert.Equal(t, "192.168.1.1", state["private_ip"]) // Private IP unchanged + assert.Equal(t, "100.64.0.1", state["protected_ip"]) // Protected IP unchanged + assert.Equal(t, "8.8.8.8", state["well_known_ip"]) // Well-known IP unchanged + assert.NotEqual(t, "2001:db8::1", state["ipv6_addr"]) + assert.Equal(t, "fd00::1", state["private_ipv6"]) // Private IPv6 unchanged + assert.NotEqual(t, "test.example.com", state["domain"]) + assert.True(t, strings.HasSuffix(state["domain"].(string), ".domain")) + assert.Equal(t, "device.netbird.cloud", state["netbird_domain"]) // Netbird domain unchanged + + // CIDR ranges + assert.NotEqual(t, "203.0.113.0/24", state["public_cidr"]) + assert.Contains(t, state["public_cidr"], "/24") // Prefix preserved + assert.Equal(t, "192.168.0.0/16", state["private_cidr"]) // Private CIDR unchanged + assert.Equal(t, "100.64.0.0/10", state["protected_cidr"]) // Protected CIDR unchanged + assert.NotEqual(t, "2001:db8::/32", state["ipv6_cidr"]) + assert.Contains(t, state["ipv6_cidr"], "/32") // IPv6 prefix preserved + + // Nested structures + nested := state["nested"].(map[string]any) + assert.NotEqual(t, "203.0.113.2", nested["ip"]) + assert.NotEqual(t, "nested.example.com", nested["domain"]) + moreNest := nested["more_nest"].(map[string]any) + assert.NotEqual(t, "203.0.113.3", moreNest["ip"]) + assert.NotEqual(t, "deep.example.com", moreNest["domain"]) + + // Arrays + strArray := state["string_array"].([]any) + assert.NotEqual(t, "203.0.113.4", strArray[0]) + assert.NotEqual(t, "test1.example.com", strArray[1]) + assert.True(t, strings.HasSuffix(strArray[1].(string), ".domain")) + + objArray := state["object_array"].([]any) + firstObj := objArray[0].(map[string]any) + assert.NotEqual(t, "203.0.113.5", firstObj["ip"]) + assert.NotEqual(t, "array1.example.com", firstObj["domain"]) + + // Duplicate values should be anonymized consistently + assert.Equal(t, state["public_ip"], state["duplicate_ip"]) + assert.Equal(t, state["domain"], state["duplicate_domain"]) + + // URIs + assert.NotContains(t, state["stun_uri"], "stun.example.com") + assert.NotContains(t, state["turns_uri"], "turns.example.com") + assert.NotContains(t, state["http_uri"], "web.example.com") + assert.NotContains(t, state["https_uri"], "secure.example.com") + + // Non-IP strings should remain unchanged + assert.Equal(t, "300.300.300.300", state["not_ip"]) + assert.Equal(t, "192.168", state["partial_ip"]) + assert.Equal(t, "1234.5678", state["ip_like_string"]) + + // Mixed content should have IPs and domains replaced + mixedContent := state["mixed_content"].(string) + assert.NotContains(t, mixedContent, "203.0.113.1") + assert.NotContains(t, mixedContent, "test.example.com") + assert.Contains(t, mixedContent, "Server at ") + assert.Contains(t, mixedContent, " on port 80") + + // Special values should remain unchanged + assert.Equal(t, "", state["empty_string"]) + assert.Nil(t, state["null_value"]) + assert.Equal(t, float64(42), state["numeric_value"]) + assert.Equal(t, true, state["boolean_value"]) + + // Check route state + var routeState map[string]any + err = json.Unmarshal(testState["route_state"], &routeState) + require.NoError(t, err) + + routes := routeState["routes"].([]any) + route1 := routes[0].(map[string]any) + assert.NotEqual(t, "203.0.113.0/24", route1["network"]) + assert.Contains(t, route1["network"], "/24") + assert.NotEqual(t, "203.0.113.1", route1["gateway"]) + domains := route1["domains"].([]any) + assert.True(t, strings.HasSuffix(domains[0].(string), ".domain")) + assert.True(t, strings.HasSuffix(domains[1].(string), ".domain")) + + // Check map keys are anonymized + refCountMap := routeState["refCountMap"].(map[string]any) + hasPublicIPKey := false + hasIPv6Key := false + hasPrivateIPKey := false + for key := range refCountMap { + if strings.Contains(key, "203.0.113.1") { + hasPublicIPKey = true + } + if strings.Contains(key, "2001:db8::1") { + hasIPv6Key = true + } + if key == "10.0.0.1/32" { + hasPrivateIPKey = true + } + } + assert.False(t, hasPublicIPKey, "public IP in key should be anonymized") + assert.False(t, hasIPv6Key, "IPv6 in key should be anonymized") + assert.True(t, hasPrivateIPKey, "private IP in key should remain unchanged") +} + +func mustMarshal(v any) json.RawMessage { + data, err := json.Marshal(v) + if err != nil { + panic(err) + } + return data +} From dffce78a8c8e78161e8fe1f8d7f01ba232c0d8e7 Mon Sep 17 00:00:00 2001 From: Viktor Liu <17948409+lixmal@users.noreply.github.com> Date: Mon, 2 Dec 2024 20:19:34 +0100 Subject: [PATCH 07/10] [client] Fix debug bundle state anonymization test (#2976) --- client/server/debug_test.go | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/client/server/debug_test.go b/client/server/debug_test.go index 303e5e66166..1515036bd02 100644 --- a/client/server/debug_test.go +++ b/client/server/debug_test.go @@ -137,6 +137,13 @@ func TestAnonymizeStateFile(t *testing.T) { } anonymizer := anonymize.NewAnonymizer(anonymize.DefaultAddresses()) + + // Pre-seed the domains we need to verify in the test assertions + anonymizer.AnonymizeDomain("test.example.com") + anonymizer.AnonymizeDomain("nested.example.com") + anonymizer.AnonymizeDomain("deep.example.com") + anonymizer.AnonymizeDomain("array1.example.com") + err := anonymizeStateFile(&testState, anonymizer) require.NoError(t, err) From a0bf0bdcc077e54294c6e31a24efcd9047f3f97f Mon Sep 17 00:00:00 2001 From: Zoltan Papp Date: Tue, 3 Dec 2024 10:13:27 +0100 Subject: [PATCH 08/10] Pass IP instead of net to Rosenpass (#2975) --- client/internal/peer/conn.go | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/client/internal/peer/conn.go b/client/internal/peer/conn.go index 81c456db747..a8de2fccb73 100644 --- a/client/internal/peer/conn.go +++ b/client/internal/peer/conn.go @@ -83,7 +83,6 @@ type Conn struct { signaler *Signaler relayManager *relayClient.Manager allowedIP net.IP - allowedNet string handshaker *Handshaker onConnected func(remoteWireGuardKey string, remoteRosenpassPubKey []byte, wireGuardIP string, remoteRosenpassAddr string) @@ -111,7 +110,7 @@ type Conn struct { // NewConn creates a new not opened Conn to the remote peer. // To establish a connection run Conn.Open func NewConn(engineCtx context.Context, config ConnConfig, statusRecorder *Status, signaler *Signaler, iFaceDiscover stdnet.ExternalIFaceDiscover, relayManager *relayClient.Manager, srWatcher *guard.SRWatcher) (*Conn, error) { - allowedIP, allowedNet, err := net.ParseCIDR(config.WgConfig.AllowedIps) + allowedIP, _, err := net.ParseCIDR(config.WgConfig.AllowedIps) if err != nil { log.Errorf("failed to parse allowedIPS: %v", err) return nil, err @@ -129,7 +128,6 @@ func NewConn(engineCtx context.Context, config ConnConfig, statusRecorder *Statu signaler: signaler, relayManager: relayManager, allowedIP: allowedIP, - allowedNet: allowedNet.String(), statusRelay: NewAtomicConnStatus(), statusICE: NewAtomicConnStatus(), } @@ -594,7 +592,7 @@ func (conn *Conn) doOnConnected(remoteRosenpassPubKey []byte, remoteRosenpassAdd } if conn.onConnected != nil { - conn.onConnected(conn.config.Key, remoteRosenpassPubKey, conn.allowedNet, remoteRosenpassAddr) + conn.onConnected(conn.config.Key, remoteRosenpassPubKey, conn.allowedIP.String(), remoteRosenpassAddr) } } From a4826cfb5fb5e0510f485644583bb14da1aa2ae8 Mon Sep 17 00:00:00 2001 From: Maycon Santos Date: Tue, 3 Dec 2024 10:22:04 +0100 Subject: [PATCH 09/10] [client] Get static system info once (#2965) Get static system info once for Windows, Darwin, and Linux nodes This should improve startup and peer authentication times --- client/system/info.go | 8 ++++ client/system/info_darwin.go | 20 +++++----- client/system/info_linux.go | 61 +++++++++++++---------------- client/system/info_windows.go | 53 +++++++++++++------------ client/system/static_info.go | 46 ++++++++++++++++++++++ client/system/sysinfo_linux_test.go | 5 ++- 6 files changed, 123 insertions(+), 70 deletions(-) create mode 100644 client/system/static_info.go diff --git a/client/system/info.go b/client/system/info.go index 2af2e637b92..200d835df31 100644 --- a/client/system/info.go +++ b/client/system/info.go @@ -61,6 +61,14 @@ type Info struct { Files []File // for posture checks } +// StaticInfo is an object that contains machine information that does not change +type StaticInfo struct { + SystemSerialNumber string + SystemProductName string + SystemManufacturer string + Environment Environment +} + // extractUserAgent extracts Netbird's agent (client) name and version from the outgoing context func extractUserAgent(ctx context.Context) string { md, hasMeta := metadata.FromOutgoingContext(ctx) diff --git a/client/system/info_darwin.go b/client/system/info_darwin.go index 6f4ed173b4b..13b0a446bd3 100644 --- a/client/system/info_darwin.go +++ b/client/system/info_darwin.go @@ -10,13 +10,12 @@ import ( "os/exec" "runtime" "strings" + "time" "golang.org/x/sys/unix" log "github.com/sirupsen/logrus" - "github.com/netbirdio/netbird/client/system/detect_cloud" - "github.com/netbirdio/netbird/client/system/detect_platform" "github.com/netbirdio/netbird/version" ) @@ -41,11 +40,10 @@ func GetInfo(ctx context.Context) *Info { log.Warnf("failed to discover network addresses: %s", err) } - serialNum, prodName, manufacturer := sysInfo() - - env := Environment{ - Cloud: detect_cloud.Detect(ctx), - Platform: detect_platform.Detect(ctx), + start := time.Now() + si := updateStaticInfo() + if time.Since(start) > 1*time.Second { + log.Warnf("updateStaticInfo took %s", time.Since(start)) } gio := &Info{ @@ -57,10 +55,10 @@ func GetInfo(ctx context.Context) *Info { CPUs: runtime.NumCPU(), KernelVersion: release, NetworkAddresses: addrs, - SystemSerialNumber: serialNum, - SystemProductName: prodName, - SystemManufacturer: manufacturer, - Environment: env, + SystemSerialNumber: si.SystemSerialNumber, + SystemProductName: si.SystemProductName, + SystemManufacturer: si.SystemManufacturer, + Environment: si.Environment, } systemHostname, _ := os.Hostname() diff --git a/client/system/info_linux.go b/client/system/info_linux.go index b6a142bce28..bfc77be1915 100644 --- a/client/system/info_linux.go +++ b/client/system/info_linux.go @@ -1,5 +1,4 @@ //go:build !android -// +build !android package system @@ -16,30 +15,13 @@ import ( log "github.com/sirupsen/logrus" "github.com/zcalusic/sysinfo" - "github.com/netbirdio/netbird/client/system/detect_cloud" - "github.com/netbirdio/netbird/client/system/detect_platform" "github.com/netbirdio/netbird/version" ) -type SysInfoGetter interface { - GetSysInfo() SysInfo -} - -type SysInfoWrapper struct { - si sysinfo.SysInfo -} - -func (s SysInfoWrapper) GetSysInfo() SysInfo { - s.si.GetSysInfo() - return SysInfo{ - ChassisSerial: s.si.Chassis.Serial, - ProductSerial: s.si.Product.Serial, - BoardSerial: s.si.Board.Serial, - ProductName: s.si.Product.Name, - BoardName: s.si.Board.Name, - ProductVendor: s.si.Product.Vendor, - } -} +var ( + // it is override in tests + getSystemInfo = defaultSysInfoImplementation +) // GetInfo retrieves and parses the system information func GetInfo(ctx context.Context) *Info { @@ -65,12 +47,10 @@ func GetInfo(ctx context.Context) *Info { log.Warnf("failed to discover network addresses: %s", err) } - si := SysInfoWrapper{} - serialNum, prodName, manufacturer := sysInfo(si.GetSysInfo()) - - env := Environment{ - Cloud: detect_cloud.Detect(ctx), - Platform: detect_platform.Detect(ctx), + start := time.Now() + si := updateStaticInfo() + if time.Since(start) > 1*time.Second { + log.Warnf("updateStaticInfo took %s", time.Since(start)) } gio := &Info{ @@ -85,10 +65,10 @@ func GetInfo(ctx context.Context) *Info { UIVersion: extractUserAgent(ctx), KernelVersion: osInfo[1], NetworkAddresses: addrs, - SystemSerialNumber: serialNum, - SystemProductName: prodName, - SystemManufacturer: manufacturer, - Environment: env, + SystemSerialNumber: si.SystemSerialNumber, + SystemProductName: si.SystemProductName, + SystemManufacturer: si.SystemManufacturer, + Environment: si.Environment, } return gio @@ -108,9 +88,9 @@ func _getInfo() string { return out.String() } -func sysInfo(si SysInfo) (string, string, string) { +func sysInfo() (string, string, string) { isascii := regexp.MustCompile("^[[:ascii:]]+$") - + si := getSystemInfo() serials := []string{si.ChassisSerial, si.ProductSerial} serial := "" @@ -141,3 +121,16 @@ func sysInfo(si SysInfo) (string, string, string) { } return serial, name, manufacturer } + +func defaultSysInfoImplementation() SysInfo { + si := sysinfo.SysInfo{} + si.GetSysInfo() + return SysInfo{ + ChassisSerial: si.Chassis.Serial, + ProductSerial: si.Product.Serial, + BoardSerial: si.Board.Serial, + ProductName: si.Product.Name, + BoardName: si.Board.Name, + ProductVendor: si.Product.Vendor, + } +} diff --git a/client/system/info_windows.go b/client/system/info_windows.go index 68631fe164d..28bd3d3007c 100644 --- a/client/system/info_windows.go +++ b/client/system/info_windows.go @@ -6,13 +6,12 @@ import ( "os" "runtime" "strings" + "time" log "github.com/sirupsen/logrus" "github.com/yusufpapurcu/wmi" "golang.org/x/sys/windows/registry" - "github.com/netbirdio/netbird/client/system/detect_cloud" - "github.com/netbirdio/netbird/client/system/detect_platform" "github.com/netbirdio/netbird/version" ) @@ -42,24 +41,10 @@ func GetInfo(ctx context.Context) *Info { log.Warnf("failed to discover network addresses: %s", err) } - serialNum, err := sysNumber() - if err != nil { - log.Warnf("failed to get system serial number: %s", err) - } - - prodName, err := sysProductName() - if err != nil { - log.Warnf("failed to get system product name: %s", err) - } - - manufacturer, err := sysManufacturer() - if err != nil { - log.Warnf("failed to get system manufacturer: %s", err) - } - - env := Environment{ - Cloud: detect_cloud.Detect(ctx), - Platform: detect_platform.Detect(ctx), + start := time.Now() + si := updateStaticInfo() + if time.Since(start) > 1*time.Second { + log.Warnf("updateStaticInfo took %s", time.Since(start)) } gio := &Info{ @@ -71,10 +56,10 @@ func GetInfo(ctx context.Context) *Info { CPUs: runtime.NumCPU(), KernelVersion: buildVersion, NetworkAddresses: addrs, - SystemSerialNumber: serialNum, - SystemProductName: prodName, - SystemManufacturer: manufacturer, - Environment: env, + SystemSerialNumber: si.SystemSerialNumber, + SystemProductName: si.SystemProductName, + SystemManufacturer: si.SystemManufacturer, + Environment: si.Environment, } systemHostname, _ := os.Hostname() @@ -85,6 +70,26 @@ func GetInfo(ctx context.Context) *Info { return gio } +func sysInfo() (serialNumber string, productName string, manufacturer string) { + var err error + serialNumber, err = sysNumber() + if err != nil { + log.Warnf("failed to get system serial number: %s", err) + } + + productName, err = sysProductName() + if err != nil { + log.Warnf("failed to get system product name: %s", err) + } + + manufacturer, err = sysManufacturer() + if err != nil { + log.Warnf("failed to get system manufacturer: %s", err) + } + + return serialNumber, productName, manufacturer +} + func getOSNameAndVersion() (string, string) { var dst []Win32_OperatingSystem query := wmi.CreateQuery(&dst, "") diff --git a/client/system/static_info.go b/client/system/static_info.go new file mode 100644 index 00000000000..fabe65a6806 --- /dev/null +++ b/client/system/static_info.go @@ -0,0 +1,46 @@ +//go:build (linux && !android) || windows || (darwin && !ios) + +package system + +import ( + "context" + "sync" + "time" + + "github.com/netbirdio/netbird/client/system/detect_cloud" + "github.com/netbirdio/netbird/client/system/detect_platform" +) + +var ( + staticInfo StaticInfo + once sync.Once +) + +func init() { + go func() { + _ = updateStaticInfo() + }() +} + +func updateStaticInfo() StaticInfo { + once.Do(func() { + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + wg := sync.WaitGroup{} + wg.Add(3) + go func() { + staticInfo.SystemSerialNumber, staticInfo.SystemProductName, staticInfo.SystemManufacturer = sysInfo() + wg.Done() + }() + go func() { + staticInfo.Environment.Cloud = detect_cloud.Detect(ctx) + wg.Done() + }() + go func() { + staticInfo.Environment.Platform = detect_platform.Detect(ctx) + wg.Done() + }() + wg.Wait() + }) + return staticInfo +} diff --git a/client/system/sysinfo_linux_test.go b/client/system/sysinfo_linux_test.go index f6a0b70587b..ae89bfcf974 100644 --- a/client/system/sysinfo_linux_test.go +++ b/client/system/sysinfo_linux_test.go @@ -183,7 +183,10 @@ func Test_sysInfo(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - gotSerialNum, gotProdName, gotManufacturer := sysInfo(tt.sysInfo) + getSystemInfo = func() SysInfo { + return tt.sysInfo + } + gotSerialNum, gotProdName, gotManufacturer := sysInfo() if gotSerialNum != tt.wantSerialNum { t.Errorf("sysInfo() gotSerialNum = %v, want %v", gotSerialNum, tt.wantSerialNum) } From 6285e0d23ef95d98abc815315c67a24371f7a2a4 Mon Sep 17 00:00:00 2001 From: Viktor Liu <17948409+lixmal@users.noreply.github.com> Date: Tue, 3 Dec 2024 12:43:17 +0100 Subject: [PATCH 10/10] [client] Add netbird.err and netbird.out to debug bundle (#2971) --- client/server/debug.go | 45 +++++++++++++++++++++++++++++++++++------- 1 file changed, 38 insertions(+), 7 deletions(-) diff --git a/client/server/debug.go b/client/server/debug.go index 1bad907ba56..d87fc9b6a2d 100644 --- a/client/server/debug.go +++ b/client/server/debug.go @@ -15,6 +15,7 @@ import ( "net" "net/netip" "os" + "path/filepath" "sort" "strings" "time" @@ -32,7 +33,9 @@ const readmeContent = `Netbird debug bundle This debug bundle contains the following files: status.txt: Anonymized status information of the NetBird client. -client.log: Most recent, anonymized log file of the NetBird client. +client.log: Most recent, anonymized client log file of the NetBird client. +netbird.err: Most recent, anonymized stderr log file of the NetBird client. +netbird.out: Most recent, anonymized stdout log file of the NetBird client. routes.txt: Anonymized system routes, if --system-info flag was provided. interfaces.txt: Anonymized network interface information, if --system-info flag was provided. config.txt: Anonymized configuration information of the NetBird client. @@ -92,6 +95,12 @@ The config.txt file contains anonymized configuration information of the NetBird Other non-sensitive configuration options are included without anonymization. ` +const ( + clientLogFile = "client.log" + errorLogFile = "netbird.err" + stdoutLogFile = "netbird.out" +) + // DebugBundle creates a debug bundle and returns the location. func (s *Server) DebugBundle(_ context.Context, req *proto.DebugBundleRequest) (resp *proto.DebugBundleResponse, err error) { s.mutex.Lock() @@ -310,14 +319,35 @@ func (s *Server) addStateFile(req *proto.DebugBundleRequest, anonymizer *anonymi return nil } -func (s *Server) addLogfile(req *proto.DebugBundleRequest, anonymizer *anonymize.Anonymizer, archive *zip.Writer) (err error) { - logFile, err := os.Open(s.logFile) +func (s *Server) addLogfile(req *proto.DebugBundleRequest, anonymizer *anonymize.Anonymizer, archive *zip.Writer) error { + logDir := filepath.Dir(s.logFile) + + if err := s.addSingleLogfile(s.logFile, clientLogFile, req, anonymizer, archive); err != nil { + return fmt.Errorf("add client log file to zip: %w", err) + } + + errLogPath := filepath.Join(logDir, errorLogFile) + if err := s.addSingleLogfile(errLogPath, errorLogFile, req, anonymizer, archive); err != nil { + log.Warnf("Failed to add %s to zip: %v", errorLogFile, err) + } + + stdoutLogPath := filepath.Join(logDir, stdoutLogFile) + if err := s.addSingleLogfile(stdoutLogPath, stdoutLogFile, req, anonymizer, archive); err != nil { + log.Warnf("Failed to add %s to zip: %v", stdoutLogFile, err) + } + + return nil +} + +// addSingleLogfile adds a single log file to the archive +func (s *Server) addSingleLogfile(logPath, targetName string, req *proto.DebugBundleRequest, anonymizer *anonymize.Anonymizer, archive *zip.Writer) error { + logFile, err := os.Open(logPath) if err != nil { - return fmt.Errorf("open log file: %w", err) + return fmt.Errorf("open log file %s: %w", targetName, err) } defer func() { if err := logFile.Close(); err != nil { - log.Errorf("Failed to close original log file: %v", err) + log.Errorf("Failed to close log file %s: %v", targetName, err) } }() @@ -330,8 +360,9 @@ func (s *Server) addLogfile(req *proto.DebugBundleRequest, anonymizer *anonymize } else { logReader = logFile } - if err := addFileToZip(archive, logReader, "client.log"); err != nil { - return fmt.Errorf("add log file to zip: %w", err) + + if err := addFileToZip(archive, logReader, targetName); err != nil { + return fmt.Errorf("add %s to zip: %w", targetName, err) } return nil